Note
Click here to download the full example code
01. Segment anything
The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.
Note
The SAM checkpoint files must be downloaded before running this example; the script loads them from ./objects/main01/.
Out:
Computing masks....
47 # https://pypi.org/project/segment-anything-py/
48
49 # Libraries
50 import numpy as np
51 import torch
52 import matplotlib.pyplot as plt
53
54 # .. note: The notebook uses cv2 and does some alteration to the image.
55 # import cv2
56
57 # Library
58 from segment_anything import SamAutomaticMaskGenerator
59 from segment_anything import sam_model_registry
60
61
def show_anns(anns, ax=None):
    """Overlay every annotation mask on an axes, each in a random colour.

    Masks are drawn in order of decreasing ``'area'`` so that the largest
    ones are painted first and the smaller ones end up on top of them.

    Args:
        anns: Sequence of dicts. Each dict provides an ``'area'`` value
            (used only for ordering) and a ``'segmentation'`` 2-D array
            whose values are scaled by 0.35 to form the alpha channel.
        ax: Axes to draw on. When ``None`` the current matplotlib axes
            (``plt.gca()``) is used.

    Returns:
        None. Nothing is drawn when ``anns`` is empty.
    """
    if len(anns) == 0:
        return
    if ax is None:
        ax = plt.gca()
    # Freeze the axis limits so the overlays cannot rescale the view.
    ax.set_autoscale_on(False)
    for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
        mask = ann['segmentation']
        # One random RGB colour per mask, broadcast over every pixel.
        rgb = np.ones((mask.shape[0], mask.shape[1], 3)) * np.random.random(3)
        # RGBA image: constant colour, alpha 0.35 inside the mask, 0 outside.
        ax.imshow(np.dstack((rgb, mask * 0.35)))
78
# Constant
# Local paths of the downloaded SAM checkpoints, keyed by model type.
CHECKPOINTS = {
    'vit_b': './objects/main01/sam_vit_b_01ec64.pth',  # 0.37 GB
    'vit_l': './objects/main01/sam_vit_l_0b3195.pth',  # 1.2 GB
    'vit_h': './objects/main01/sam_vit_h_4b8939.pth',  # 2.4 GB
}

# Variables
model = 'vit_b'

# Use the GPU when one is available; mask generation on the CPU is slow.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load image
image = plt.imread('./objects/main01/photo-1.jpg')

# Load model and move it to the selected device
sam = sam_model_registry[model](checkpoint=CHECKPOINTS[model])
sam.to(device=device)

# Create mask generator
mask_generator = SamAutomaticMaskGenerator(sam)

# Show
print("Computing masks....")

# Compute masks
masks = mask_generator.generate(image)

# Display: original image on the left, image with mask overlay on the right
_, axs = plt.subplots(1, 2, figsize=(20, 20), sharey=True)
for ax in axs:
    ax.imshow(image)
    ax.axis('off')
show_anns(masks, ax=axs[1])

plt.tight_layout()
plt.show()
Total running time of the script: (1 minute 29.265 seconds)