DEV Community

future-158
future-158

Posted on

Simply run Segment Anything on a single image

The Segment Anything Model (SAM) has been popular lately.
To try out its interactive usage, I made a simple matplotlib script.

The script below uses the WebAgg backend, so after you run it a new browser tab will open.
You can draw a bounding box by dragging the mouse; when you release the button, the SAM model automatically runs on that box.

installation

# Create and activate an isolated Python 3.10 environment for the demo.
conda create --name sam pip python=3.10
conda activate sam
conda install -y pytorch torchvision torchaudio -c pytorch
pip install transformers
pip install git+https://github.com/facebookresearch/segment-anything.git
pip install datasets
pip install matplotlib
# The script does `import cv2`; OpenCV is not pulled in by the packages above.
pip install opencv-python
# tornado is what matplotlib's WebAgg backend serves the figure with.
pip install tornado
pip install jupyter
pip install notebook
Enter fullscreen mode Exit fullscreen mode

load library

import matplotlib
import numpy as np
from datasets import load_dataset

matplotlib.use('WebAgg')
import random

import cv2
import datasets
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.widgets import RectangleSelector
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
Enter fullscreen mode Exit fullscreen mode

download weight

# SAM ViT-B weights; the saved file name must match sam_checkpoint in the "load model" step.
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
Enter fullscreen mode Exit fullscreen mode

load model

#  %% load model
# Checkpoint downloaded in the "download weight" step. model_type must name the
# same architecture as the checkpoint file (vit_b <-> sam_vit_b_01ec64.pth);
# to use a larger model, change BOTH values to a matching pair.
sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"

# Pick the best available accelerator instead of hard-coding "mps", which
# fails on machines without Apple-silicon GPU support.
if torch.cuda.is_available():
    device = "cuda"  # linux/windows with an NVIDIA gpu
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"  # m1/m2 mac
else:
    device = "cpu"

# Build the model from the registry, move it to the chosen device and wrap it
# in the predictor used by the interactive callback.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)


Enter fullscreen mode Exit fullscreen mode

load image

# Load a 10-example slice of the COCO 2017 train split and take its first entry.
ds = datasets.load_dataset("ydshieh/coco_dataset_script", "2017", data_dir="./dummy_data/", split='train[:10]')
example = ds[0]
# Force 3-channel RGB: a file that decodes as grayscale, palette or RGBA would
# otherwise become a 2-D or 4-channel array, while the predictor is fed an
# HxWx3 uint8 image.
img = Image.open(example['image_path']).convert("RGB")
image = np.array(img)
# Bare expression: displays the image when run as a notebook cell (no effect in a plain script).
img



Enter fullscreen mode Exit fullscreen mode

run single image


# Hand the image to the predictor once, before any box is drawn; the
# on_select callback below only calls predictor.predict.
predictor.set_image(image)

# %% select function
def on_select(eclick, erelease):
    """Run SAM on the dragged rectangle and overlay the predicted mask.

    Callback for ``RectangleSelector``.

    Args:
        eclick: Mouse-press event of the drag; its ``xdata``/``ydata`` give
            one corner of the box in data (image pixel) coordinates.
        erelease: Mouse-release event of the drag; gives the opposite corner.

    Returns:
        None. The mask is drawn as a semi-transparent, randomly coloured
        image on top of the axes the drag happened in.
    """
    x1, y1 = eclick.xdata, eclick.ydata
    x2, y2 = erelease.xdata, erelease.ydata
    print(f"Rectangle selected from ({x1}, {y1}) to ({x2}, {y2})")

    # An event without data coordinates (e.g. outside the axes) gives no
    # usable box; bail out instead of feeding None to the predictor.
    if any(v is None for v in (x1, y1, x2, y2)):
        return

    # The box prompt is (x_min, y_min, x_max, y_max); order the corners so the
    # drag direction does not matter.
    input_box = np.array([min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)])
    masks, _scores, _logits = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=False,
    )
    mask = masks[0]

    # Random RGB colour with a fixed 0.6 alpha; multiplying by the mask leaves
    # pixels outside the mask fully transparent (all zeros).
    color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    # Draw on the axes the selection was made in, falling back to the current
    # axes, and explicitly request a redraw: the selector only blits its own
    # rectangle, so the new overlay would otherwise not show up.
    target_ax = eclick.inaxes if eclick.inaxes is not None else plt.gca()
    target_ax.imshow(mask_image)
    target_ax.figure.canvas.draw_idle()


# Drop any figures left over from earlier runs, then show the image in a new one.
plt.close('all')
fig, ax = plt.subplots()
ax.imshow(image)

# Box selection with mouse button 1; on_select is the selection callback.
# The selector is bound to a name so that a reference to it stays alive
# while the figure is shown.
rect_selector = RectangleSelector(
    ax,
    on_select,
    useblit=True,
    button=[1],
    minspanx=5,
    minspany=5,
    spancoords='pixels',
    interactive=True,
)

# With the WebAgg backend this serves the figure to a browser tab.
plt.show()


Enter fullscreen mode Exit fullscreen mode

Enter fullscreen mode Exit fullscreen mode

Top comments (0)