Я пытаюсь запустить https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb локально, на MacBook M2 с Sonoma 14.5. Однако на шаге 11 я продолжаю сталкиваться со следующей ошибкой:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[75], line 1
----> 1 masks = mask_generator.generate(image)
File ~/opt/anaconda3/envs/ve_env/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/opt/anaconda3/envs/ve_env/lib/python3.9/site-packages/segment_anything/automatic_mask_generator.py:163, in SamAutomaticMaskGenerator.generate(self, image)
138 """
139 Generates masks for the given image.
140
(...)
159 the mask, given in XYWH format.
160 """
162 # Generate masks
--> 163 mask_data = self._generate_masks(image)
165 # Filter small disconnected regions and holes in masks
166 if self.min_mask_region_area > 0:
File ~/opt/anaconda3/envs/ve_env/lib/python3.9/site-packages/segment_anything/automatic_mask_generator.py:206, in SamAutomaticMaskGenerator._generate_masks(self, image)
204 data = MaskData()
205 for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
--> 206 crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
207 data.cat(crop_data)
209 # Remove duplicate masks between crops
File ~/opt/anaconda3/envs/ve_env/lib/python3.9/site-packages/segment_anything/automatic_mask_generator.py:236, in SamAutomaticMaskGenerator._process_crop(self, image, crop_box, crop_layer_idx, orig_size)
234 cropped_im = image[y0:y1, x0:x1, :]
235 cropped_im_size = cropped_im.shape[:2]
--> 236 self.predictor.set_image(cropped_im)
238 # Get points for this crop
239 points_scale = np.array(cropped_im_size)[None, ::-1]
File ~/opt/anaconda3/envs/ve_env/lib/python3.9/site-packages/segment_anything/predictor.py:57, in SamPredictor.set_image(self, image, image_format)
55 # Transform the image to the form expected by the model
56 input_image = self.transform.apply_image(image)
---> 57 input_image_torch = torch.as_tensor(input_image, device=self.device)
58 input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
60 self.set_torch_image(input_image_torch, image.shape[:2])
RuntimeError: Could not infer dtype of numpy.uint8
Я использую среду conda с Python 3.9.19, а также тестировался с Python 3.11. Основываясь на онлайн-комментариях, я подозревал, что это проблема с numpy-версиями, но, попробовав несколько версий, я не могу найти правильную комбинацию. В настоящее время я пытаюсь сделать следующее:
numpy==1.24.4
torch==1.9.0
torchvision==0.10.0
opencv-python==4.10.0.84
Запуск того же блокнота в Google Colab работает нормально, и там указаны следующие версии:
import numpy as np
import torch
import cv2
print(np.__version__)
print(torch.__version__)
print(cv2.__version__)
1.25.2
2.3.0+cu121
4.8.0
Здесь используется Python 3.10.12. Эти версии недоступны на Mac, поэтому я застрял.
Как узнать, почему numpy.uint8 не распознается, и как исправить эту ошибку? Большинство онлайн-комментариев указывают на обновление numpy, но я безуспешно пробовал несколько версий numpy. Любая помощь приветствуется.
🤔 А знаете ли вы, что...
С Python можно создавать веб-скраперы для извлечения данных из веб-сайтов.
Для кого-то еще, кто столкнулся с той же проблемой, причиной может быть проблема с поддержкой Jupyter Notebook в JetBrains PyCharm. Я также отправляю им отчет об ошибке. Запуск jupyter-notebook извне показал, что используется правильная версия numpy, и код работает как положено.