01. Segment anything

The Segment Anything Model (SAM) produces high-quality object masks from input prompts such as points or boxes, and it can also generate masks for all objects in an image. It was trained on a dataset of 11 million images and 1.1 billion masks and has strong zero-shot performance on a variety of segmentation tasks.



Note

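The model checkpoints must be downloaded before running this example! The script expects them in ./objects/main01/ (see the CHECKPOINTS paths in the code below).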

plot main01 segment anything

Out:

Computing masks....

 47 # https://pypi.org/project/segment-anything-py/
 48
 49 # Libraries
 50 import numpy as np
 51 import torch
 52 import matplotlib.pyplot as plt
 53
 54 # .. note: The notebook uses cv2 and does some alteration to the image.
 55 # import cv2
 56
 57 # Library
 58 from segment_anything import SamAutomaticMaskGenerator
 59 from segment_anything import sam_model_registry
 60
 61
 62 def show_anns(anns, ax=None):
 63     if len(anns) == 0:
 64         return
 65     sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
 66     if ax is None:
 67         ax = plt.gca()
 68     ax.set_autoscale_on(False)
 69     polygons = []
 70     color = []
 71     for ann in sorted_anns:
 72         m = ann['segmentation']
 73         img = np.ones((m.shape[0], m.shape[1], 3))
 74         color_mask = np.random.random((1, 3)).tolist()[0]
 75         for i in range(3):
 76             img[:,:,i] = color_mask[i]
 77         ax.imshow(np.dstack((img, m*0.35)))
 78
 79 # Constant
 80 CHECKPOINTS = {
 81     'vit_b': './objects/main01/sam_vit_b_01ec64.pth', # 0.37 GB
 82     'vit_l': './objects/main01/sam_vit_l_0b3195.pth', # 1.2 GB
 83     'vit_h': './objects/main01/sam_vit_h_4b8939.pth', # 2.4 GB
 84 }
 85
 86 # Variables
 87 model = 'vit_b'
 88
 89 # Load image
 90 image = plt.imread('./objects/main01/photo-1.jpg')
 91
 92 # Load model
 93 sam = sam_model_registry[model](checkpoint=CHECKPOINTS[model])
 94
 95 # Create mask generator
 96 mask_generator = SamAutomaticMaskGenerator(sam)
 97
 98 # Show
 99 print("Computing masks....")
100
101 # Compute masks
102 masks = mask_generator.generate(image)
103
104 # Display
105 _, axs = plt.subplots(1, 2, figsize=(20,20), sharey=True)
106 axs[0].imshow(image)
107 axs[1].imshow(image)
108 axs[0].axis('off')
109 axs[1].axis('off')
110 show_anns(masks, ax=axs[1])
111
112 plt.tight_layout()
113 plt.show()

Total running time of the script: ( 1 minutes 29.265 seconds)

Gallery generated by Sphinx-Gallery