Norod78 committed
Commit df0749a
1 Parent(s): dbbe955

Support CPU
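
In short: the app now picks the device at runtime and only relies on flash_attn when CUDA is available; on CPU it patches transformers' `get_imports()` so Florence-2's remote code loads without it. Below is a minimal sketch of that pattern (a paraphrase of the new app.py in the diff, not a verbatim excerpt):

```python
# Minimal sketch (not a verbatim excerpt of app.py) of the pattern this commit adds:
# detect the device at runtime and, on CPU, patch get_imports() so the Florence-2
# remote code does not try to import flash_attn.
import os
from unittest.mock import patch

import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.dynamic_module_utils import get_imports


def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Drop flash_attn from modeling_florence2.py's declared imports (CPU fallback)."""
    if not str(filename).endswith("/modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    imports.remove("flash_attn")
    return imports


device = "cuda" if torch.cuda.is_available() else "cpu"  # app.py wraps this check in @spaces.GPU
if device == "cuda":
    # On GPU the Space installs flash-attn at startup and loads the model normally.
    model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
else:
    # https://huggingface.co/microsoft/Florence-2-large-ft/discussions/4
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
```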

Files changed (1)
  1. app.py +178 -153
app.py CHANGED
@@ -1,32 +1,55 @@
-import os
-from unittest.mock import patch
-import spaces
 import gradio as gr
 from transformers import AutoProcessor, AutoModelForCausalLM
-from transformers.dynamic_module_utils import get_imports
-import torch
+import spaces
+
 import requests
-from PIL import Image, ImageDraw
-import random
-import numpy as np
+import copy
+
+from PIL import Image, ImageDraw, ImageFont
+import io
 import matplotlib.pyplot as plt
 import matplotlib.patches as patches
-import cv2
-import io
 
-def workaround_fixed_get_imports(filename: str | os.PathLike) -> list[str]:
+import random
+import numpy as np
+
+import os
+from unittest.mock import patch
+from transformers import AutoModelForCausalLM, AutoProcessor
+from transformers.dynamic_module_utils import get_imports
+
+def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
+    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
     if not str(filename).endswith("/modeling_florence2.py"):
         return get_imports(filename)
     imports = get_imports(filename)
     imports.remove("flash_attn")
     return imports
 
-with patch("transformers.dynamic_module_utils.get_imports", workaround_fixed_get_imports):
-    model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True).to("cuda").eval()
+
+@spaces.GPU
+def get_device_type():
+    import torch
+    return "cuda" if torch.cuda.is_available() else "cpu"
+
+model_id = 'microsoft/Florence-2-base-ft'
+
+import subprocess
+device = get_device_type()
+if (device == "cuda"):
+    subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+    model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
     processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
+else:
+    #https://huggingface.co/microsoft/Florence-2-large-ft/discussions/4
+    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
+        model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
+        processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
+
+DESCRIPTION = "# [Florence-2 base-ft Demo with CPU inference support](https://huggingface.co/microsoft/Florence-2-base-ft)"
 
-colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'red',
-            'lime', 'indigo', 'violet', 'aqua', 'magenta', 'coral', 'gold', 'tan', 'skyblue']
+colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
+            'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
 
 def fig_to_pil(fig):
     buf = io.BytesIO()
@@ -40,21 +63,20 @@ def run_example(task_prompt, image, text_input=None):
         prompt = task_prompt
     else:
         prompt = task_prompt + text_input
-    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
-    with torch.inference_mode():
-        generated_ids = model.generate(
-            input_ids=inputs["input_ids"],
-            pixel_values=inputs["pixel_values"],
-            max_new_tokens=1024,
-            early_stopping=False,
-            do_sample=False,
-            num_beams=3,
-        )
+    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
+    generated_ids = model.generate(
+        input_ids=inputs["input_ids"],
+        pixel_values=inputs["pixel_values"],
+        max_new_tokens=1024,
+        early_stopping=False,
+        do_sample=False,
+        num_beams=3,
+    )
     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
     parsed_answer = processor.post_process_generation(
         generated_text,
         task=task_prompt,
-        image_size=(image.size[0], image.size[1])
+        image_size=(image.width, image.height)
     )
     return parsed_answer
 
@@ -65,161 +87,164 @@ def plot_bbox(image, data):
         x1, y1, x2, y2 = bbox
         rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
         ax.add_patch(rect)
-        plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='indigo', alpha=0.5))
+        plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))
     ax.axis('off')
-    return fig_to_pil(fig)
+    return fig
 
 def draw_polygons(image, prediction, fill_mask=False):
-    fig, ax = plt.subplots()
-    ax.imshow(image)
+    draw = ImageDraw.Draw(image)
     scale = 1
     for polygons, label in zip(prediction['polygons'], prediction['labels']):
         color = random.choice(colormap)
         fill_color = random.choice(colormap) if fill_mask else None
         for _polygon in polygons:
             _polygon = np.array(_polygon).reshape(-1, 2)
-            if _polygon.shape[0] < 3:
+            if len(_polygon) < 3:
                 print('Invalid polygon:', _polygon)
                 continue
             _polygon = (_polygon * scale).reshape(-1).tolist()
-            if len(_polygon) % 2 != 0:
-                print('Invalid polygon:', _polygon)
-                continue
-            polygon_points = np.array(_polygon).reshape(-1, 2)
             if fill_mask:
-                polygon = patches.Polygon(polygon_points, edgecolor=color, facecolor=fill_color, linewidth=2)
+                draw.polygon(_polygon, outline=color, fill=fill_color)
             else:
-                polygon = patches.Polygon(polygon_points, edgecolor=color, fill=False, linewidth=2)
-            ax.add_patch(polygon)
-            plt.text(polygon_points[0, 0], polygon_points[0, 1], label, color='white', fontsize=8, bbox=dict(facecolor=color, alpha=0.5))
-    ax.axis('off')
-    return fig_to_pil(fig)
+                draw.polygon(_polygon, outline=color)
+            draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
+    return image
+
+def convert_to_od_format(data):
+    bboxes = data.get('bboxes', [])
+    labels = data.get('bboxes_labels', [])
+    od_results = {
+        'bboxes': bboxes,
+        'labels': labels
+    }
+    return od_results
 
 def draw_ocr_bboxes(image, prediction):
-    fig, ax = plt.subplots()
-    ax.imshow(image)
     scale = 1
+    draw = ImageDraw.Draw(image)
     bboxes, labels = prediction['quad_boxes'], prediction['labels']
     for box, label in zip(bboxes, labels):
         color = random.choice(colormap)
         new_box = (np.array(box) * scale).tolist()
-        polygon = patches.Polygon(new_box, edgecolor=color, fill=False, linewidth=3)
-        ax.add_patch(polygon)
-        plt.text(new_box[0], new_box[1], label, color='white', fontsize=8, bbox=dict(facecolor=color, alpha=0.5))
-    ax.axis('off')
-    return fig_to_pil(fig)
-
-
-@spaces.GPU(duration=120)
-def process_video(input_video_path, task_prompt):
-    cap = cv2.VideoCapture(input_video_path)
-    if not cap.isOpened():
-        print("Error: Can't open the video file.")
-        return
-
-    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-    fps = cap.get(cv2.CAP_PROP_FPS)
-    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-    out = cv2.VideoWriter("output_vid.mp4", fourcc, fps, (frame_width, frame_height))
-
-    while cap.isOpened():
-        ret, frame = cap.read()
-        if not ret:
-            break
-
-        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-        pil_image = Image.fromarray(frame_rgb)
-
-        result = run_example(task_prompt, pil_image)
-
-        if task_prompt == "<OD>":
-            processed_image = plot_bbox(pil_image, result['<OD>'])
-        elif task_prompt == "<DENSE_REGION_CAPTION>":
-            processed_image = plot_bbox(pil_image, result['<DENSE_REGION_CAPTION>'])
-        else:
-            processed_image = pil_image
-
-        processed_frame = cv2.cvtColor(np.array(processed_image), cv2.COLOR_RGB2BGR)
-        out.write(processed_frame)
-
-    cap.release()
-    out.release()
-    cv2.destroyAllWindows()
-    return "output_vid.mp4"
+        draw.polygon(new_box, width=3, outline=color)
+        draw.text((new_box[0]+8, new_box[1]+2),
+                  "{}".format(label),
+                  align="right",
+                  fill=color)
+    return image
+
+def process_image(image, task_prompt, text_input=None):
+    image = Image.fromarray(image) # Convert NumPy array to PIL Image
+    if task_prompt == 'Caption':
+        task_prompt = '<CAPTION>'
+        result = run_example(task_prompt, image)
+        return result, None
+    elif task_prompt == 'Detailed Caption':
+        task_prompt = '<DETAILED_CAPTION>'
+        result = run_example(task_prompt, image)
+        return result, None
+    elif task_prompt == 'More Detailed Caption':
+        task_prompt = '<MORE_DETAILED_CAPTION>'
+        result = run_example(task_prompt, image)
+        return result, None
+    elif task_prompt == 'Object Detection':
+        task_prompt = '<OD>'
+        results = run_example(task_prompt, image)
+        fig = plot_bbox(image, results['<OD>'])
+        return results, fig_to_pil(fig)
+    elif task_prompt == 'Dense Region Caption':
+        task_prompt = '<DENSE_REGION_CAPTION>'
+        results = run_example(task_prompt, image)
+        fig = plot_bbox(image, results['<DENSE_REGION_CAPTION>'])
+        return results, fig_to_pil(fig)
+    elif task_prompt == 'Region Proposal':
+        task_prompt = '<REGION_PROPOSAL>'
+        results = run_example(task_prompt, image)
+        fig = plot_bbox(image, results['<REGION_PROPOSAL>'])
+        return results, fig_to_pil(fig)
+    elif task_prompt == 'Caption to Phrase Grounding':
+        task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
+        results = run_example(task_prompt, image, text_input)
+        fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
+        return results, fig_to_pil(fig)
+    elif task_prompt == 'Referring Expression Segmentation':
+        task_prompt = '<REFERRING_EXPRESSION_SEGMENTATION>'
+        results = run_example(task_prompt, image, text_input)
+        output_image = copy.deepcopy(image)
+        output_image = draw_polygons(output_image, results['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
+        return results, output_image
+    elif task_prompt == 'Region to Segmentation':
+        task_prompt = '<REGION_TO_SEGMENTATION>'
+        results = run_example(task_prompt, image, text_input)
+        output_image = copy.deepcopy(image)
+        output_image = draw_polygons(output_image, results['<REGION_TO_SEGMENTATION>'], fill_mask=True)
+        return results, output_image
+    elif task_prompt == 'Open Vocabulary Detection':
+        task_prompt = '<OPEN_VOCABULARY_DETECTION>'
+        results = run_example(task_prompt, image, text_input)
+        bbox_results = convert_to_od_format(results['<OPEN_VOCABULARY_DETECTION>'])
+        fig = plot_bbox(image, bbox_results)
+        return results, fig_to_pil(fig)
+    elif task_prompt == 'Region to Category':
+        task_prompt = '<REGION_TO_CATEGORY>'
+        results = run_example(task_prompt, image, text_input)
+        return results, None
+    elif task_prompt == 'Region to Description':
+        task_prompt = '<REGION_TO_DESCRIPTION>'
+        results = run_example(task_prompt, image, text_input)
+        return results, None
+    elif task_prompt == 'OCR':
+        task_prompt = '<OCR>'
+        result = run_example(task_prompt, image)
+        return result, None
+    elif task_prompt == 'OCR with Region':
+        task_prompt = '<OCR_WITH_REGION>'
+        results = run_example(task_prompt, image)
+        output_image = copy.deepcopy(image)
+        output_image = draw_ocr_bboxes(output_image, results['<OCR_WITH_REGION>'])
+        return results, output_image
+    else:
+        return "", None # Return empty string and None for unknown task prompts
 
 css = """
-#output {
-    min-height: 100px;
-    overflow: auto;
-    border: 1px solid #ccc;
-}
+#output {
+    height: 500px;
+    overflow: auto;
+    border: 1px solid #ccc;
+}
 """
 
 with gr.Blocks(css=css) as demo:
-    gr.HTML("<h1><center>Microsoft Florence-2-large-ft</center></h1>")
-    with gr.Tab(label="Image"):
+    gr.Markdown(DESCRIPTION)
+    with gr.Tab(label="Florence-2 Image Captioning"):
         with gr.Row():
             with gr.Column():
-                input_img = gr.Image(label="Input Picture", type="pil")
-                task_radio = gr.Radio(
-                    ["Caption", "Detailed Caption", "More Detailed Caption", "Caption to Phrase Grounding",
-                     "Object Detection", "Dense Region Caption", "Region Proposal", "Referring Expression Segmentation",
-                     "Region to Segmentation", "Open Vocabulary Detection", "Region to Category", "Region to Description",
-                     "OCR", "OCR with Region"],
-                    label="Task", value="Caption"
-                )
-                text_input = gr.Textbox(label="Text Input (is Optional)", visible=False)
+                input_img = gr.Image(label="Input Picture")
+                task_prompt = gr.Dropdown(choices=[
+                    'Caption', 'Detailed Caption', 'More Detailed Caption', 'Object Detection',
+                    'Dense Region Caption', 'Region Proposal', 'Caption to Phrase Grounding',
+                    'Referring Expression Segmentation', 'Region to Segmentation',
+                    'Open Vocabulary Detection', 'Region to Category', 'Region to Description',
+                    'OCR', 'OCR with Region'
+                ], label="Task Prompt", value= 'Caption')
+                text_input = gr.Textbox(label="Text Input (optional)")
                 submit_btn = gr.Button(value="Submit")
             with gr.Column():
-                output_text = gr.Textbox(label="Results")
-                output_image = gr.Image(label="Image", type="pil")
+                output_text = gr.Textbox(label="Output Text")
+                output_img = gr.Image(label="Output Image")
+
+        gr.Examples(
+            examples=[
+                ["image1.jpg", 'Object Detection'],
+                ["image2.jpg", 'OCR with Region']
+            ],
+            inputs=[input_img, task_prompt],
+            outputs=[output_text, output_img],
+            fn=process_image,
+            cache_examples=True,
+            label='Try examples'
+        )
 
-    with gr.Tab(label="Video"):
-        with gr.Row():
-            with gr.Column():
-                input_video = gr.Video(label="Video")
-                video_task_radio = gr.Radio(
-                    ["Object Detection", "Dense Region Caption"],
-                    label="Video Task", value="Object Detection"
-                )
-                video_submit_btn = gr.Button(value="Process Video")
-            with gr.Column():
-                output_video = gr.Video(label="Video")
-
-    def update_text_input(task):
-        return gr.update(visible=task in ["Caption to Phrase Grounding", "Referring Expression Segmentation",
-                                          "Region to Segmentation", "Open Vocabulary Detection", "Region to Category",
-                                          "Region to Description"])
-
-    task_radio.change(fn=update_text_input, inputs=task_radio, outputs=text_input)
-
-    def process_image(image, task, text):
-        task_mapping = {
-            "Caption": ("<CAPTION>", lambda result: (result['<CAPTION>'], image)),
-            "Detailed Caption": ("<DETAILED_CAPTION>", lambda result: (result['<DETAILED_CAPTION>'], image)),
-            "More Detailed Caption": ("<MORE_DETAILED_CAPTION>", lambda result: (result['<MORE_DETAILED_CAPTION>'], image)),
-            "Caption to Phrase Grounding": ("<CAPTION_TO_PHRASE_GROUNDING>", lambda result: (str(result['<CAPTION_TO_PHRASE_GROUNDING>']), plot_bbox(image, result['<CAPTION_TO_PHRASE_GROUNDING>']))),
-            "Object Detection": ("<OD>", lambda result: (str(result['<OD>']), plot_bbox(image, result['<OD>']))),
-            "Dense Region Caption": ("<DENSE_REGION_CAPTION>", lambda result: (str(result['<DENSE_REGION_CAPTION>']), plot_bbox(image, result['<DENSE_REGION_CAPTION>']))),
-            "Region Proposal": ("<REGION_PROPOSAL>", lambda result: (str(result['<REGION_PROPOSAL>']), plot_bbox(image, result['<REGION_PROPOSAL>']))),
-            "Referring Expression Segmentation": ("<REFERRING_EXPRESSION_SEGMENTATION>", lambda result: (str(result['<REFERRING_EXPRESSION_SEGMENTATION>']), draw_polygons(image, result['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True))),
-            "Region to Segmentation": ("<REGION_TO_SEGMENTATION>", lambda result: (str(result['<REGION_TO_SEGMENTATION>']), draw_polygons(image, result['<REGION_TO_SEGMENTATION>'], fill_mask=True))),
-            "Open Vocabulary Detection": ("<OPEN_VOCABULARY_DETECTION>", lambda result: (str(convert_to_od_format(result['<OPEN_VOCABULARY_DETECTION>'])), plot_bbox(image, convert_to_od_format(result['<OPEN_VOCABULARY_DETECTION>'])))),
-            "Region to Category": ("<REGION_TO_CATEGORY>", lambda result: (result['<REGION_TO_CATEGORY>'], image)),
-            "Region to Description": ("<REGION_TO_DESCRIPTION>", lambda result: (result['<REGION_TO_DESCRIPTION>'], image)),
-            "OCR": ("<OCR>", lambda result: (result['<OCR>'], image)),
-            "OCR with Region": ("<OCR_WITH_REGION>", lambda result: (str(result['<OCR_WITH_REGION>']), draw_ocr_bboxes(image, result['<OCR_WITH_REGION>']))),
-        }
-
-        if task in task_mapping:
-            prompt, process_func = task_mapping[task]
-            result = run_example(prompt, image, text)
-            return process_func(result)
-        else:
-            return "", image
-
-    submit_btn.click(fn=process_image, inputs=[input_img, task_radio, text_input], outputs=[output_text, output_image])
-    video_submit_btn.click(fn=process_video, inputs=[input_video, video_task_radio], outputs=output_video)
-
-demo.launch()
+    submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])
+
+demo.launch(debug=True)
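
For context, a hedged usage sketch of the generation path the updated app uses. It assumes `model`, `processor`, and `device` have been loaded as in the new app.py above; the example image URL is illustrative only:

```python
# Hypothetical single-call example mirroring run_example() in the new app.py;
# assumes `model`, `processor`, and `device` are already defined as above.
import requests
from PIL import Image

# Any RGB image works; this URL is only an example.
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

# "<OD>" is one of the Florence-2 task prompts wired up in process_image().
inputs = processor(text="<OD>", images=image, return_tensors="pt").to(device)
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=1024,
    do_sample=False,
    num_beams=3,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed = processor.post_process_generation(
    generated_text, task="<OD>", image_size=(image.width, image.height)
)
print(parsed)
```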