Our method leverages the pre-trained SAM model with only marginal parameter increments and computational requirements.

<img width="1096" alt="image" src="figures/architecture.jpg">

**Disclaimer**: Content from **this** model card has been written by the Hugging Face team, and parts of it were copy-pasted from the original [SAM model card](https://github.com/facebookresearch/segment-anything).

# Model Details

The RobustSAM model is made up of 4 modules (a short inspection sketch follows the list):
- The `VisionEncoder`: a ViT-based image encoder. It computes the image embeddings using attention on patches of the image. Relative positional embeddings are used.
- The `PromptEncoder`: generates embeddings for points and bounding boxes.
- The `MaskDecoder`: a two-way transformer which performs cross-attention between the image embedding and the point embeddings, and between the point embeddings and the image embeddings. Its outputs are fed to the `Neck`.
- The `Neck`: predicts the output masks based on the contextualized masks produced by the `MaskDecoder`.
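
To see how these modules map onto the checkpoint, you can load it and print its top-level submodules. This is a minimal inspection sketch, assuming the checkpoint loads through `AutoModelForMaskGeneration` as in the snippets below; the exact submodule names depend on the `transformers` implementation:

```python
from transformers import AutoModelForMaskGeneration

# load the checkpoint and list the names of its top-level modules
model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-base")
for name, module in model.named_children():
    print(name, "->", type(module).__name__)
```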

# Usage

## Prompted-Mask-Generation

```python
from PIL import Image
import requests
import torch
from transformers import AutoProcessor, AutoModelForMaskGeneration

# load the RobustSAM model and processor
processor = AutoProcessor.from_pretrained("jadechoghari/robustsam-vit-base")
model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-base").to("cuda")

# load an image from a url
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

# define input points (2D localization of an object in the image)
input_points = [[[450, 600]]]  # example point
```

```python
# process the image and input points, and move the tensors to the GPU
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")

# generate masks using the model
with torch.no_grad():
    outputs = model(**inputs)

# post-process the predicted masks back to the original image resolution
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
scores = outputs.iou_scores
```
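
For instance, to keep only the best candidate mask for the point prompt, you can index with the highest predicted IoU. This is a minimal sketch, assuming the usual SAM output layout of `(batch, point_batch, num_masks, ...)` for both `iou_scores` and the post-processed masks:

```python
# select the candidate mask with the highest predicted IoU score
best_idx = scores[0, 0].argmax().item()
best_mask = masks[0][0, best_idx]  # binary mask at the original image resolution
```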

Among other arguments to generate masks, you can pass 2D locations on the approximate position of your object of interest, a bounding box wrapping the object of interest (the format should be the x, y coordinates of the top-left and bottom-right points of the bounding box), or a segmentation mask. At the time of writing, passing text as input is not supported by the official model, according to [the official repository](https://github.com/facebookresearch/segment-anything/issues/4#issuecomment-1497626844).
For more details, refer to this notebook, which shows a walkthrough of how to use the model, with a visual example!
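
As an example, a bounding box prompt can be passed instead of (or together with) points. This is a minimal sketch reusing the processor and model from above; the box values are an arbitrary `[x_min, y_min, x_max, y_max]` region chosen for illustration:

```python
# a bounding box prompt: [x_min, y_min, x_max, y_max] (top-left and bottom-right corners)
input_boxes = [[[75, 275, 1725, 850]]]  # arbitrary example box

inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)
```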

## Automatic-Mask-Generation

The model can be used for generating segmentation masks in a "zero-shot" fashion, given an input image. The model is automatically prompted with a grid of `1024` points, which are all fed to the model.

The pipeline is made for automatic mask generation. The following snippet demonstrates how easily you can run it (on any device! Simply feed the appropriate `points_per_batch` argument):
```python
from transformers import pipeline

# initialize the pipeline for mask generation
generator = pipeline("mask-generation", model="jadechoghari/robustsam-vit-base", device=0, points_per_batch=256)

image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url)
```
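
The pipeline returns a dictionary whose `"masks"` entry holds one binary mask per detected region. A minimal sketch of how you might inspect it, assuming the key names of the `transformers` mask-generation pipeline:

```python
# each mask is a binary array at the original image resolution
print(len(outputs["masks"]), "masks generated")
print(outputs["masks"][0].shape, outputs["masks"][0].dtype)
```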

Now to display the generated masks on the image:
```python
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import requests

# simple function to display a mask on a matplotlib axis
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

    # get the height and width from the mask
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

# load and display the original image
raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
plt.imshow(np.array(raw_image))
ax = plt.gca()

# loop through the masks and display each one
for mask in outputs["masks"]:
    show_mask(mask, ax=ax, random_color=True)

plt.axis("off")

# show the image with the masks
plt.show()
```

## Comparison of computational requirements
<img width="720" alt="image" src='figures/Computational requirements.PNG'>