AIBoy1993 commited on
Commit
9bc47e3
1 Parent(s): 78bc5df

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -0
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import sys
4
+ import numpy as np
5
+ import gradio as gr
6
+ from PIL import Image
7
+ import matplotlib.pyplot as plt
8
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
9
+
10
+
11
+ models = {
12
+ 'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
13
+ 'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
14
+ 'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
15
+ }
16
+
17
+ def inference(device, model_type, input_img, points_per_side, pred_iou_thresh, stability_score_thresh, min_mask_region_area,
18
+ stability_score_offset, box_nms_thresh, crop_n_layers, crop_nms_thresh):
19
+ sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
20
+ mask_generator = SamAutomaticMaskGenerator(
21
+ sam,
22
+ points_per_side=points_per_side,
23
+ pred_iou_thresh=pred_iou_thresh,
24
+ stability_score_thresh=stability_score_thresh,
25
+ stability_score_offset=stability_score_offset,
26
+ box_nms_thresh=box_nms_thresh,
27
+ crop_n_layers=crop_n_layers,
28
+ crop_nms_thresh=crop_nms_thresh,
29
+ crop_overlap_ratio=512 / 1500,
30
+ crop_n_points_downscale_factor=1,
31
+ point_grids=None,
32
+ min_mask_region_area=min_mask_region_area,
33
+ output_mode='binary_mask'
34
+ )
35
+
36
+ masks = mask_generator.generate(input_img)
37
+ sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
38
+
39
+ mask_all = np.ones((input_img.shape[0], input_img.shape[1], 3))
40
+ for ann in sorted_anns:
41
+ m = ann['segmentation']
42
+ color_mask = np.random.random((1, 3)).tolist()[0]
43
+ for i in range(3):
44
+ mask_all[m==True, i] = color_mask[i]
45
+ result = input_img / 255 * 0.3 + mask_all * 0.7
46
+
47
+ return result, mask_all
48
+
49
+
50
+
51
+ with gr.Blocks() as demo:
52
+ with gr.Row():
53
+ gr.Markdown(
54
+ '''# Segment Anything!🚀
55
+ 分割一切!CV的GPT-3时刻!
56
+ [**官方网址**](https://segment-anything.com/)
57
+ '''
58
+ )
59
+ with gr.Row():
60
+ # 选择模型类型
61
+ model_type = gr.Dropdown(["vit_b", "vit_l", "vit_h"], value='vit_b', label="选择模型")
62
+ # 选择device
63
+ device = gr.Dropdown(["cpu", "cuda"], value='cuda', label="选择你的硬件")
64
+
65
+ # 参数
66
+ with gr.Accordion(label='参数调整', open=False):
67
+ with gr.Row():
68
+ points_per_side = gr.Number(value=32, label="points_per_side", precision=0,
69
+ info='''The number of points to be sampled along one side of the image. The total
70
+ number of points is points_per_side**2.''')
71
+ pred_iou_thresh = gr.Slider(value=0.88, minimum=0, maximum=1.0, step=0.01, label="pred_iou_thresh",
72
+ info='''A filtering threshold in [0,1], using the model's predicted mask quality.''')
73
+ stability_score_thresh = gr.Slider(value=0.95, minimum=0, maximum=1.0, step=0.01, label="stability_score_thresh",
74
+ info='''A filtering threshold in [0,1], using the stability of the mask under
75
+ changes to the cutoff used to binarize the model's mask predictions.''')
76
+ min_mask_region_area = gr.Number(value=0, label="min_mask_region_area", precision=0,
77
+ info='''If >0, postprocessing will be applied to remove disconnected regions
78
+ and holes in masks with area smaller than min_mask_region_area.''')
79
+ with gr.Row():
80
+ stability_score_offset = gr.Number(value=1, label="stability_score_offset",
81
+ info='''The amount to shift the cutoff when calculated the stability score.''')
82
+ box_nms_thresh = gr.Slider(value=0.7, minimum=0, maximum=1.0, step=0.01, label="box_nms_thresh",
83
+ info='''The box IoU cutoff used by non-maximal ression to filter duplicate masks.''')
84
+ crop_n_layers = gr.Number(value=0, label="crop_n_layers", precision=0,
85
+ info='''If >0, mask prediction will be run again on crops of the image.
86
+ Sets the number of layers to run, where each layer has 2**i_layer number of image crops.''')
87
+ crop_nms_thresh = gr.Slider(value=0.7, minimum=0, maximum=1.0, step=0.01, label="crop_nms_thresh",
88
+ info='''The box IoU cutoff used by non-maximal suppression to filter duplicate
89
+ masks between different crops.''')
90
+
91
+ # 显示图片
92
+ with gr.Row().style(equal_height=True):
93
+ with gr.Column():
94
+ input_image = gr.Image(type="numpy")
95
+ with gr.Row():
96
+ button = gr.Button("Auto!")
97
+ with gr.Tab(label='原图+mask'):
98
+ image_output = gr.Image(type='numpy')
99
+ with gr.Tab(label='Mask'):
100
+ mask_output = gr.Image(type='numpy')
101
+
102
+ gr.Examples(
103
+ examples=[os.path.join(os.path.dirname(__file__), "./images/53960-scaled.jpg"),
104
+ os.path.join(os.path.dirname(__file__), "./images/2388455-scaled.jpg"),
105
+ os.path.join(os.path.dirname(__file__), "./images/1.jpg"),
106
+ os.path.join(os.path.dirname(__file__), "./images/2.jpg"),
107
+ os.path.join(os.path.dirname(__file__), "./images/3.jpg"),
108
+ os.path.join(os.path.dirname(__file__), "./images/4.jpg"),
109
+ os.path.join(os.path.dirname(__file__), "./images/5.jpg"),
110
+ os.path.join(os.path.dirname(__file__), "./images/6.jpg"),
111
+ os.path.join(os.path.dirname(__file__), "./images/7.jpg"),
112
+ os.path.join(os.path.dirname(__file__), "./images/8.jpg"),
113
+ ],
114
+ inputs=input_image,
115
+ outputs=image_output,
116
+ )
117
+
118
+
119
+ # 按钮交互
120
+ button.click(inference, inputs=[device, model_type, input_image, points_per_side, pred_iou_thresh,
121
+ stability_score_thresh, min_mask_region_area, stability_score_offset, box_nms_thresh,
122
+ crop_n_layers, crop_nms_thresh],
123
+ outputs=[image_output, mask_output])
124
+
125
+
126
+
127
+ demo.launch(debug=True)
128
+
129
+
130
+