Spaces:
Sleeping
Sleeping
Add mask prompt
Browse files- app.py +54 -2
- requirements.txt +1 -0
app.py
CHANGED
@@ -3,8 +3,10 @@ import numpy as np
|
|
3 |
import gradio as gr
|
4 |
import matplotlib.pyplot as plt
|
5 |
from PIL import Image
|
|
|
6 |
from transformers import SamModel, SamProcessor
|
7 |
|
|
|
8 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
9 |
processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
|
10 |
model = SamModel.from_pretrained('hmdliu/sidewalks-seg')
|
@@ -62,6 +64,34 @@ def segment_image_with_guidance(image, threshold, offset, x_min, y_min, x_max, y
|
|
62 |
regions = [(guidance_mask, 'Guidance'), (pred_mask, 'Sidewalks')]
|
63 |
return (image['background'], regions), Image.open('prob.png')
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
with gr.Blocks() as demo:
|
66 |
with gr.Tab('Baseline'):
|
67 |
with gr.Row():
|
@@ -78,10 +108,10 @@ with gr.Blocks() as demo:
|
|
78 |
t1_pred = gr.AnnotatedImage(color_map={'Sidewalks': '#0000FF'}, label='Prediction')
|
79 |
with gr.Column():
|
80 |
t1_prob_map = gr.Image(type='pil', label='Probability Map')
|
81 |
-
with gr.Tab('Mask Guidance'):
|
82 |
with gr.Row():
|
83 |
with gr.Column():
|
84 |
-
t2_input = gr.ImageEditor(type='pil', crop_size='
|
85 |
brush=gr.Brush(default_size='5', color_mode='fixed'),
|
86 |
sources=['upload'], transforms=[])
|
87 |
with gr.Row():
|
@@ -96,6 +126,23 @@ with gr.Blocks() as demo:
|
|
96 |
t2_pred = gr.AnnotatedImage(color_map={'Guidance': '#FF0000', 'Sidewalks': '#0000FF'}, label='Prediction')
|
97 |
with gr.Column():
|
98 |
t2_prob_map = gr.Image(type='pil', label='Probability Map')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
t1_segment.click(
|
100 |
segment_image,
|
101 |
inputs=[t1_input, t1_slider, t1_x_min, t1_y_min, t1_x_max, t1_y_max],
|
@@ -106,4 +153,9 @@ with gr.Blocks() as demo:
|
|
106 |
inputs=[t2_input, t2_thresh, t2_offset, t2_x_min, t2_y_min, t2_x_max, t2_y_max],
|
107 |
outputs=[t2_pred, t2_prob_map]
|
108 |
)
|
|
|
|
|
|
|
|
|
|
|
109 |
demo.launch(debug=True, show_error=True)
|
|
|
3 |
import gradio as gr
|
4 |
import matplotlib.pyplot as plt
|
5 |
from PIL import Image
|
6 |
+
from torchvision.transforms import ToTensor
|
7 |
from transformers import SamModel, SamProcessor
|
8 |
|
9 |
+
to_tensor = ToTensor()
|
10 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
11 |
processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
|
12 |
model = SamModel.from_pretrained('hmdliu/sidewalks-seg')
|
|
|
64 |
regions = [(guidance_mask, 'Guidance'), (pred_mask, 'Sidewalks')]
|
65 |
return (image['background'], regions), Image.open('prob.png')
|
66 |
|
67 |
+
def segment_image_with_prompt(image, threshold, x_min, y_min, x_max, y_max):
    """Segment sidewalks using a box prompt plus a brush-drawn mask prompt.

    Args:
        image: Image-editor payload. ``image['background']`` is the uploaded
            PIL image and ``image['layers'][0]`` is the layer whose non-zero
            pixels are used as the mask prompt.
        threshold: Probability cut-off; pixels whose sigmoid score is above
            it are labelled as sidewalk.
        x_min: Left edge of the box prompt.
        y_min: Top edge of the box prompt.
        x_max: Right edge of the box prompt.
        y_max: Bottom edge of the box prompt.

    Returns:
        A pair ``((background, regions), prob_image)`` where ``regions`` is
        ``[(prompt_mask, 'Prompt'), (pred_mask, 'Sidewalks')]`` and
        ``prob_image`` is a PIL image of the probability map rendered with
        the ``jet`` colormap.
    """
    # Local import keeps this block self-contained.
    from io import BytesIO

    # tolerate TIFF image input: round-trip through PNG as before, but in
    # memory, so concurrent requests cannot overwrite a shared 'image.png'.
    png_buffer = BytesIO()
    image['background'].save(png_buffer, format='PNG')
    png_buffer.seek(0)
    img = Image.open(png_buffer).convert('RGB')

    # init input data
    # A pixel belongs to the prompt when its largest channel value is non-zero.
    mask = (np.max(np.array(image['layers'][0]), axis=2) != 0)
    mask_prompt = to_tensor(mask).float()
    box_prompt = [[[x_min, y_min, x_max, y_max]]]
    inputs = processor(img, input_boxes=box_prompt,
                       input_masks=mask_prompt, return_tensors='pt')

    # make prediction
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(pixel_values=inputs['pixel_values'].to(device),
                        input_boxes=inputs['input_boxes'].to(device),
                        input_masks=mask_prompt.to(device),
                        multimask_output=False)
    prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu()
    pred_mask = (prob_map > threshold).float().numpy()

    # visualize results
    # Render into a buffer instead of a shared 'prob.png', and always close
    # the figure so a failed render does not leak it.
    prob_buffer = BytesIO()
    fig = plt.figure(figsize=(8, 8))
    try:
        plt.imshow(prob_map.numpy(), cmap='jet', interpolation='nearest')
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(prob_buffer, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    prob_buffer.seek(0)
    prob_image = Image.open(prob_buffer)
    # Force decoding now so the returned image does not depend on lazy reads.
    prob_image.load()

    # post-processing
    regions = [(mask, 'Prompt'), (pred_mask, 'Sidewalks')]
    return (image['background'], regions), prob_image
|
94 |
+
|
95 |
with gr.Blocks() as demo:
|
96 |
with gr.Tab('Baseline'):
|
97 |
with gr.Row():
|
|
|
108 |
t1_pred = gr.AnnotatedImage(color_map={'Sidewalks': '#0000FF'}, label='Prediction')
|
109 |
with gr.Column():
|
110 |
t1_prob_map = gr.Image(type='pil', label='Probability Map')
|
111 |
+
with gr.Tab('Mask Guidance (Best)'):
|
112 |
with gr.Row():
|
113 |
with gr.Column():
|
114 |
+
t2_input = gr.ImageEditor(type='pil', crop_size='1:1', label='Input Image',
|
115 |
brush=gr.Brush(default_size='5', color_mode='fixed'),
|
116 |
sources=['upload'], transforms=[])
|
117 |
with gr.Row():
|
|
|
126 |
t2_pred = gr.AnnotatedImage(color_map={'Guidance': '#FF0000', 'Sidewalks': '#0000FF'}, label='Prediction')
|
127 |
with gr.Column():
|
128 |
t2_prob_map = gr.Image(type='pil', label='Probability Map')
|
129 |
+
with gr.Tab('Mask Prompt'):
|
130 |
+
with gr.Row():
|
131 |
+
with gr.Column():
|
132 |
+
t3_input = gr.ImageEditor(type='pil', crop_size='1:1', label='Input Image',
|
133 |
+
brush=gr.Brush(default_size='5', color_mode='fixed'),
|
134 |
+
sources=['upload'], transforms=[])
|
135 |
+
with gr.Row():
|
136 |
+
t3_x_min = gr.Textbox(value=0, label='x_min')
|
137 |
+
t3_y_min = gr.Textbox(value=0, label='y_min')
|
138 |
+
t3_x_max = gr.Textbox(value=256, label='x_max')
|
139 |
+
t3_y_max = gr.Textbox(value=256, label='y_max')
|
140 |
+
t3_thresh = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Prediction Threshold')
|
141 |
+
t3_segment = gr.Button('Segment')
|
142 |
+
with gr.Column():
|
143 |
+
t3_pred = gr.AnnotatedImage(color_map={'Prompt': '#FF0000', 'Sidewalks': '#0000FF'}, label='Prediction')
|
144 |
+
with gr.Column():
|
145 |
+
t3_prob_map = gr.Image(type='pil', label='Probability Map')
|
146 |
t1_segment.click(
|
147 |
segment_image,
|
148 |
inputs=[t1_input, t1_slider, t1_x_min, t1_y_min, t1_x_max, t1_y_max],
|
|
|
153 |
inputs=[t2_input, t2_thresh, t2_offset, t2_x_min, t2_y_min, t2_x_max, t2_y_max],
|
154 |
outputs=[t2_pred, t2_prob_map]
|
155 |
)
|
156 |
+
t3_segment.click(
|
157 |
+
segment_image_with_prompt,
|
158 |
+
inputs=[t3_input, t3_thresh, t3_x_min, t3_y_min, t3_x_max, t3_y_max],
|
159 |
+
outputs=[t3_pred, t3_prob_map]
|
160 |
+
)
|
161 |
demo.launch(debug=True, show_error=True)
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
torch
|
|
|
2 |
matplotlib
|
3 |
transformers
|
|
|
1 |
torch
|
2 |
+
torchvision
|
3 |
matplotlib
|
4 |
transformers
|