onuralpszr commited on
Commit
0dbbe5d
1 Parent(s): c90b78a

feat: ✨ color and checkbox events change annotator results dynamically feature added

Browse files
Files changed (1) hide show
  1. app.py +144 -46
app.py CHANGED
@@ -4,6 +4,7 @@ from pathlib import Path
4
  import gradio as gr
5
  import numpy as np
6
  import supervision as sv
 
7
  from PIL import Image
8
  from torch import cuda, device
9
  from ultralytics import YOLO
@@ -43,6 +44,9 @@ DESC = """
43
  </div>
44
  """ # noqa: E501 title/docs
45
 
 
 
 
46
 
47
  def load_model(img, model: str | Path = "yolov8s-seg.pt"):
48
  # Load model, get results and return detections/labels
@@ -54,7 +58,6 @@ def load_model(img, model: str | Path = "yolov8s-seg.pt"):
54
  for class_id, confidence in zip(detections.class_id, detections.confidence)
55
  ]
56
 
57
- print(labels)
58
  return detections, labels
59
 
60
 
@@ -70,10 +73,11 @@ def calculate_crop_dim(a, b):
70
  return width, height
71
 
72
 
73
- def annotator(
74
  img,
75
- model,
76
- annotators,
 
77
  colorbb,
78
  colormask,
79
  colorellipse,
@@ -83,82 +87,73 @@ def annotator(
83
  colorhalo,
84
  colortri,
85
  colordot,
86
- progress=gr.Progress(track_tqdm=True),
87
- ):
88
- """
89
- Function that changes the color of annotators
90
- Args:
91
- annotators: Icon whose color needs to be changed.
92
- color: Chosen color with which to edit the input icon in Hex.
93
- img: Input image is numpy matrix in BGR.
94
- Returns:
95
- annotators: annotated image
96
- """
97
-
98
- img = img[..., ::-1].copy() # BGR to RGB using numpy
99
-
100
- detections, labels = load_model(img, model)
101
 
102
- if "Blur" in annotators:
103
  # Apply Blur
104
  blur_annotator = sv.BlurAnnotator()
105
- img = blur_annotator.annotate(img, detections=detections)
106
 
107
- if "BoundingBox" in annotators:
108
  # Draw Boundingbox
109
  box_annotator = sv.BoundingBoxAnnotator(sv.Color.from_hex(str(colorbb)))
110
- img = box_annotator.annotate(img, detections=detections)
111
 
112
- if "Mask" in annotators:
113
  # Draw Mask
114
  mask_annotator = sv.MaskAnnotator(sv.Color.from_hex(str(colormask)))
115
- img = mask_annotator.annotate(img, detections=detections)
116
 
117
- if "Ellipse" in annotators:
118
  # Draw Ellipse
119
  ellipse_annotator = sv.EllipseAnnotator(sv.Color.from_hex(str(colorellipse)))
120
- img = ellipse_annotator.annotate(img, detections=detections)
121
 
122
- if "BoxCorner" in annotators:
123
  # Draw Box corner
124
  corner_annotator = sv.BoxCornerAnnotator(sv.Color.from_hex(str(colorbc)))
125
- img = corner_annotator.annotate(img, detections=detections)
126
 
127
- if "Circle" in annotators:
128
  # Draw Circle
129
  circle_annotator = sv.CircleAnnotator(sv.Color.from_hex(str(colorcir)))
130
- img = circle_annotator.annotate(img, detections=detections)
131
 
132
- if "Label" in annotators:
133
  # Draw Label
134
  label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)
135
  label_annotator = sv.LabelAnnotator(sv.Color.from_hex(str(colorlabel)))
136
- img = label_annotator.annotate(img, detections=detections, labels=labels)
 
 
137
 
138
- if "Pixelate" in annotators:
139
  # Apply PixelateAnnotator
140
  pixelate_annotator = sv.PixelateAnnotator()
141
- img = pixelate_annotator.annotate(img, detections=detections)
142
 
143
- if "Halo" in annotators:
144
  # Draw HaloAnnotator
145
  halo_annotator = sv.HaloAnnotator(sv.Color.from_hex(str(colorhalo)))
146
- img = halo_annotator.annotate(img, detections=detections)
147
 
148
- if "HeatMap" in annotators:
149
  # Draw HeatMapAnnotator
150
  heatmap_annotator = sv.HeatMapAnnotator()
151
- img = heatmap_annotator.annotate(img, detections=detections)
152
 
153
- if "Dot" in annotators:
154
  # Dot DotAnnotator
155
  dot_annotator = sv.DotAnnotator(sv.Color.from_hex(str(colordot)))
156
- img = dot_annotator.annotate(img, detections=detections)
157
 
158
- if "Triangle" in annotators:
159
  # Draw TriangleAnnotator
160
  tri_annotator = sv.TriangleAnnotator(sv.Color.from_hex(str(colortri)))
161
- img = tri_annotator.annotate(img, detections=detections)
162
 
163
  # crop image for the largest possible square
164
  res_img = Image.fromarray(img)
@@ -178,6 +173,54 @@ def annotator(
178
  return crop_img[..., ::-1].copy() # BGR to RGB using numpy
179
 
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  purple_theme = theme = gr.themes.Soft(primary_hue=gr.themes.colors.purple).set(
182
  button_primary_background_fill="*primary_600",
183
  button_primary_background_fill_hover="*primary_700",
@@ -204,7 +247,7 @@ with gr.Blocks(theme=purple_theme) as app:
204
  label="Select Model:",
205
  )
206
 
207
- annotators = gr.CheckboxGroup(
208
  choices=[
209
  "BoundingBox",
210
  "Mask",
@@ -224,7 +267,7 @@ with gr.Blocks(theme=purple_theme) as app:
224
  )
225
 
226
  gr.Markdown("## Color Picker 🎨")
227
- with gr.Row(variant="compact"):
228
  with gr.Column():
229
  colorbb = gr.ColorPicker(value="#A351FB", label="BoundingBox")
230
  colormask = gr.ColorPicker(value="#A351FB", label="Mask")
@@ -252,7 +295,7 @@ with gr.Blocks(theme=purple_theme) as app:
252
  inputs=[
253
  image_input,
254
  models,
255
- annotators,
256
  colorbb,
257
  colormask,
258
  colorellipse,
@@ -281,8 +324,63 @@ with gr.Blocks(theme=purple_theme) as app:
281
  cache_examples=False,
282
  )
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
  if __name__ == "__main__":
286
  print("Starting app...")
287
  print("Dark theme is available at: http://localhost:7860/?__theme=dark")
 
288
  app.launch(debug=False)
 
4
  import gradio as gr
5
  import numpy as np
6
  import supervision as sv
7
+ from gradio import ColorPicker
8
  from PIL import Image
9
  from torch import cuda, device
10
  from ultralytics import YOLO
 
44
  </div>
45
  """ # noqa: E501 title/docs
46
 
47
+ last_detections = sv.Detections.empty()
48
+ last_labels: list[str] = []
49
+
50
 
51
  def load_model(img, model: str | Path = "yolov8s-seg.pt"):
52
  # Load model, get results and return detections/labels
 
58
  for class_id, confidence in zip(detections.class_id, detections.confidence)
59
  ]
60
 
 
61
  return detections, labels
62
 
63
 
 
73
  return width, height
74
 
75
 
76
+ def annotators(
77
  img,
78
+ last_detections,
79
+ annotators_list,
80
+ last_labels,
81
  colorbb,
82
  colormask,
83
  colorellipse,
 
87
  colorhalo,
88
  colortri,
89
  colordot,
90
+ ) -> np.ndarray:
91
+ if last_detections == sv.Detections.empty():
92
+ gr.Warning("Detection is empty please add image and annotate first")
93
+ return np.zeros()
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ if "Blur" in annotators_list:
96
  # Apply Blur
97
  blur_annotator = sv.BlurAnnotator()
98
+ img = blur_annotator.annotate(img, detections=last_detections)
99
 
100
+ if "BoundingBox" in annotators_list:
101
  # Draw Boundingbox
102
  box_annotator = sv.BoundingBoxAnnotator(sv.Color.from_hex(str(colorbb)))
103
+ img = box_annotator.annotate(img, detections=last_detections)
104
 
105
+ if "Mask" in annotators_list:
106
  # Draw Mask
107
  mask_annotator = sv.MaskAnnotator(sv.Color.from_hex(str(colormask)))
108
+ img = mask_annotator.annotate(img, detections=last_detections)
109
 
110
+ if "Ellipse" in annotators_list:
111
  # Draw Ellipse
112
  ellipse_annotator = sv.EllipseAnnotator(sv.Color.from_hex(str(colorellipse)))
113
+ img = ellipse_annotator.annotate(img, detections=last_detections)
114
 
115
+ if "BoxCorner" in annotators_list:
116
  # Draw Box corner
117
  corner_annotator = sv.BoxCornerAnnotator(sv.Color.from_hex(str(colorbc)))
118
+ img = corner_annotator.annotate(img, detections=last_detections)
119
 
120
+ if "Circle" in annotators_list:
121
  # Draw Circle
122
  circle_annotator = sv.CircleAnnotator(sv.Color.from_hex(str(colorcir)))
123
+ img = circle_annotator.annotate(img, detections=last_detections)
124
 
125
+ if "Label" in annotators_list:
126
  # Draw Label
127
  label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)
128
  label_annotator = sv.LabelAnnotator(sv.Color.from_hex(str(colorlabel)))
129
+ img = label_annotator.annotate(
130
+ img, detections=last_detections, labels=last_labels
131
+ )
132
 
133
+ if "Pixelate" in annotators_list:
134
  # Apply PixelateAnnotator
135
  pixelate_annotator = sv.PixelateAnnotator()
136
+ img = pixelate_annotator.annotate(img, detections=last_detections)
137
 
138
+ if "Halo" in annotators_list:
139
  # Draw HaloAnnotator
140
  halo_annotator = sv.HaloAnnotator(sv.Color.from_hex(str(colorhalo)))
141
+ img = halo_annotator.annotate(img, detections=last_detections)
142
 
143
+ if "HeatMap" in annotators_list:
144
  # Draw HeatMapAnnotator
145
  heatmap_annotator = sv.HeatMapAnnotator()
146
+ img = heatmap_annotator.annotate(img, detections=last_detections)
147
 
148
+ if "Dot" in annotators_list:
149
  # Dot DotAnnotator
150
  dot_annotator = sv.DotAnnotator(sv.Color.from_hex(str(colordot)))
151
+ img = dot_annotator.annotate(img, detections=last_detections)
152
 
153
+ if "Triangle" in annotators_list:
154
  # Draw TriangleAnnotator
155
  tri_annotator = sv.TriangleAnnotator(sv.Color.from_hex(str(colortri)))
156
+ img = tri_annotator.annotate(img, detections=last_detections)
157
 
158
  # crop image for the largest possible square
159
  res_img = Image.fromarray(img)
 
173
  return crop_img[..., ::-1].copy() # BGR to RGB using numpy
174
 
175
 
176
+ def annotator(
177
+ img,
178
+ model,
179
+ annotators_list,
180
+ colorbb,
181
+ colormask,
182
+ colorellipse,
183
+ colorbc,
184
+ colorcir,
185
+ colorlabel,
186
+ colorhalo,
187
+ colortri,
188
+ colordot,
189
+ progress=gr.Progress(track_tqdm=True),
190
+ ) -> np.ndarray:
191
+ """
192
+ Function that changes the color of annotators
193
+ Args:
194
+ annotators: Icon whose color needs to be changed.
195
+ color: Chosen color with which to edit the input icon in Hex.
196
+ img: Input image is numpy matrix in BGR.
197
+ Returns:
198
+ annotators: annotated image
199
+ """
200
+
201
+ img = img[..., ::-1].copy() # BGR to RGB using numpy
202
+
203
+ detections, labels = load_model(img, model)
204
+ last_detections = detections
205
+ last_labels = labels
206
+
207
+ return annotators(
208
+ img,
209
+ last_detections,
210
+ annotators_list,
211
+ last_labels,
212
+ colorbb,
213
+ colormask,
214
+ colorellipse,
215
+ colorbc,
216
+ colorcir,
217
+ colorlabel,
218
+ colorhalo,
219
+ colortri,
220
+ colordot,
221
+ )
222
+
223
+
224
  purple_theme = theme = gr.themes.Soft(primary_hue=gr.themes.colors.purple).set(
225
  button_primary_background_fill="*primary_600",
226
  button_primary_background_fill_hover="*primary_700",
 
247
  label="Select Model:",
248
  )
249
 
250
+ annotators_list = gr.CheckboxGroup(
251
  choices=[
252
  "BoundingBox",
253
  "Mask",
 
267
  )
268
 
269
  gr.Markdown("## Color Picker 🎨")
270
+ with gr.Row(variant="panel"):
271
  with gr.Column():
272
  colorbb = gr.ColorPicker(value="#A351FB", label="BoundingBox")
273
  colormask = gr.ColorPicker(value="#A351FB", label="Mask")
 
295
  inputs=[
296
  image_input,
297
  models,
298
+ annotators_list,
299
  colorbb,
300
  colormask,
301
  colorellipse,
 
324
  cache_examples=False,
325
  )
326
 
327
+ annotators_list.change(
328
+ fn=annotator,
329
+ inputs=[
330
+ image_input,
331
+ models,
332
+ annotators_list,
333
+ colorbb,
334
+ colormask,
335
+ colorellipse,
336
+ colorbc,
337
+ colorcir,
338
+ colorlabel,
339
+ colorhalo,
340
+ colortri,
341
+ colordot,
342
+ ],
343
+ outputs=image_output,
344
+ )
345
+
346
+ def change_color(color: ColorPicker):
347
+ color.change(
348
+ fn=annotator,
349
+ inputs=[
350
+ image_input,
351
+ models,
352
+ annotators_list,
353
+ colorbb,
354
+ colormask,
355
+ colorellipse,
356
+ colorbc,
357
+ colorcir,
358
+ colorlabel,
359
+ colorhalo,
360
+ colortri,
361
+ colordot,
362
+ ],
363
+ outputs=image_output,
364
+ )
365
+
366
+ colors = [
367
+ colorbb,
368
+ colormask,
369
+ colorellipse,
370
+ colorbc,
371
+ colorcir,
372
+ colorlabel,
373
+ colorhalo,
374
+ colortri,
375
+ colordot,
376
+ ]
377
+
378
+ for color in colors:
379
+ change_color(color)
380
+
381
 
382
  if __name__ == "__main__":
383
  print("Starting app...")
384
  print("Dark theme is available at: http://localhost:7860/?__theme=dark")
385
+ # app.launch(debug=False, server_name="0.0.0.0") # for local network
386
  app.launch(debug=False)