risekid committed on
Commit
91f15c9
1 Parent(s): d19acaa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +250 -0
app.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is a sample Python script.
2
+
3
+ # Press ⌃R to execute it or replace it with your code.
4
+ # Press Double ⇧ to search everywhere for classes, files, tool windows, actions, and settings.
5
+ import base64
6
+ import datetime
7
+ import json
8
+
9
+ import cv2
10
+ import requests
11
+ from PIL import Image, ImageDraw, ImageFont, ImageOps
12
+ import numpy as np
13
+ from io import BytesIO
14
+ import time
15
+
16
import os

# Sample image used by the stand-alone helpers below (a developer-machine path).
main_image_path = "/Users/aaron/Documents/temp/16pic_2415206_s.png"

# SECURITY: access tokens must not live in source control. The token is read
# from the HF_API_TOKEN environment variable instead; the token that used to be
# hard-coded here was published with the commit and should be revoked.
API_TOKEN = os.environ.get("HF_API_TOKEN", "")
# Only send an Authorization header when a token is actually configured, so a
# missing token does not turn into a malformed "Bearer " header.
headers = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}

# Alternative detection model kept for reference:
# API_URL = "https://api-inference.huggingface.co/models/hustvl/yolos-tiny"
API_URL = "https://api-inference.huggingface.co/models/facebook/detr-resnet-50"
# API_OBJECT_URL = "https://api-inference.huggingface.co/models/microsoft/resnet-50"
API_SEGMENTATION_URL = "https://api-inference.huggingface.co/models/facebook/detr-resnet-50-panoptic"
API_SEGMENTATION_URL_2 = "https://api-inference.huggingface.co/models/nvidia/segformer-b0-finetuned-ade-512-512"

# Directory where cropped images are written (a developer-machine path).
temp_dir = "/Users/aaron/Documents/temp/imageai/"
26
+
27
+
28
def query(filename):
    """Send the contents of ``filename`` to the endpoint configured in ``API_URL``.

    The file is read in binary mode and POSTed as the request body together
    with the module-level ``headers``. The response body is decoded as UTF-8
    and parsed as JSON; the parsed value is returned.
    """
    with open(filename, "rb") as image_file:
        payload = image_file.read()
    reply = requests.post(API_URL, headers=headers, data=payload)
    body_text = reply.content.decode("utf-8")
    return json.loads(body_text)
33
+
34
+
35
def queryObjectDetection(
    filename,
    api_url="https://api-inference.huggingface.co/models/microsoft/resnet-50",
):
    """POST the contents of ``filename`` to ``api_url`` and return the parsed JSON.

    The original body referenced a module constant ``API_OBJECT_URL`` that is
    only present as a commented-out line in the configuration, so every call
    raised ``NameError``. The endpoint is now a parameter whose default is the
    URL from that commented-out line.

    Args:
        filename: Path of the image file to upload.
        api_url: Inference endpoint to call.

    Returns:
        The JSON-decoded response body.

    Raises:
        requests.exceptions.RequestException: On network failure or when the
            6-second timeout expires.
        json.JSONDecodeError: If the response body is not valid JSON.
    """
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.request("POST", api_url, headers=headers, data=data, timeout=6)
    print(response)
    return json.loads(response.content.decode("utf-8"))
41
+
42
+
43
def getImageSegmentation():
    """Run ``query`` on ``main_image_path``, print the result, and return it."""
    result = query(main_image_path)
    print(result)
    return result
47
+
48
+
49
def crop_image(box):
    """Return the region of the image at ``main_image_path`` described by ``box``.

    Args:
        box: Mapping with the keys ``'xmin'``, ``'ymin'``, ``'xmax'`` and
            ``'ymax'`` giving the crop rectangle in pixel coordinates.

    Returns:
        A new PIL image containing the cropped region.

    Raises:
        KeyError: If ``box`` lacks one of the four keys (checked before the
            file is opened).
    """
    # Build the crop rectangle first so a malformed box never opens the file.
    crop_area = (box['xmin'], box['ymin'], box['xmax'], box['ymax'])

    # The context manager releases the file handle deterministically; crop()
    # loads the pixel data and returns an independent image, so the result
    # stays valid after the source image is closed.
    with Image.open(main_image_path) as image:
        return image.crop(crop_area)
60
+
61
+
62
+ # # 示例
63
+ # image_path = "path/to/your/image.jpg"
64
+ # box = {'xmin': 186, 'ymin': 75, 'xmax': 252, 'ymax': 123}
65
+ #
66
+ # cropped_image = crop_image(image_path, box)
67
+ # cropped_image.show() # 显示裁剪后的图片
68
+ # cropped_image.save("path/to/save/cropped_image.jpg") # 保存裁剪后的图片
69
+
70
+ # Press the green button in the gutter to run the script.
71
+ # if __name__ == '__main__':
72
+ # data = getImageSegmentation()
73
+ # for item in data:
74
+ # box = item['box']
75
+ # cropped_image = crop_image(box)
76
+ # temp_image_path = temp_dir + str(int(datetime.datetime.now().timestamp() * 1000000)) + ".png"
77
+ # print(temp_image_path)
78
+ # cropped_image.save(temp_image_path)
79
+ # object_data = queryObjectDetection(temp_image_path)
80
+ # print(object_data)
81
+ # flag = False
82
+ # for obj in object_data:
83
+ # # 检查字典中是否包含 'error' 键
84
+ # if 'error' in obj and obj['error'] is not None:
85
+ # flag = True
86
+ # print("找到了一个包含 'error' 键的字典,且其值不为 None")
87
+ # else:
88
+ # print("字典不包含 'error' 键,或其值为 None")
89
+ # if flag:
90
+ # continue
91
+ # item['label'] = object_data[0]['label']
92
+ # print(data)
93
+ #
94
+ # ### The code below only draws the result for testing; it is unrelated to the main flow above.
95
+ # image = Image.open(main_image_path)
96
+ # draw = ImageDraw.Draw(image)
97
+ #
98
+ # # 设置边框颜色和字体
99
+ # border_color = (255, 0, 0) # 红色
100
+ # text_color = (255, 255, 255) # 白色
101
+ # font = ImageFont.truetype("Geneva.ttf", 12) # use the system Geneva font, size 12
102
+ #
103
+ # # 遍历对象列表,画边框和标签
104
+ # for obj in data:
105
+ # label = obj['label']
106
+ # box = obj['box']
107
+ # xmin, ymin, xmax, ymax = box['xmin'], box['ymin'], box['xmax'], box['ymax']
108
+ #
109
+ # # 画边框
110
+ # draw.rectangle([xmin, ymin, xmax, ymax], outline=border_color, width=2)
111
+ #
112
+ # # 画标签
113
+ # text_size = draw.textsize(label, font=font)
114
+ # draw.rectangle([xmin, ymin, xmin + text_size[0], ymin + text_size[1]], fill=border_color)
115
+ # draw.text((xmin, ymin), label, font=font, fill=text_color)
116
+ #
117
+ # image.show()
118
+
119
+
120
+ import numpy as np
121
+ from PIL import Image
122
+ import gradio as gr
123
+ import imageio
124
+
125
+
126
def send_request_to_api(img_byte_arr, max_retries=3, wait_time=60, timeout=None):
    """POST image bytes to ``API_SEGMENTATION_URL``, retrying on API errors.

    A response counts as an error when its body is not valid JSON or when it
    parses to a JSON object containing an ``"error"`` key. The previous
    substring test (``"error" in response_content``) could also match inside
    label text or base64 mask data and reject a valid result.

    Args:
        img_byte_arr: Encoded image bytes sent as the request body.
        max_retries: Maximum number of attempts.
        wait_time: Seconds to sleep between failed attempts. No sleep happens
            after the final attempt.
        timeout: Optional ``requests`` timeout in seconds; ``None`` keeps the
            original behaviour of waiting indefinitely.

    Returns:
        The JSON-decoded response of the first successful attempt.

    Raises:
        Exception: If no attempt produced a valid response.
        requests.exceptions.RequestException: Network errors are not retried
            and propagate to the caller.
    """
    for attempt in range(1, max_retries + 1):
        response = requests.request(
            "POST",
            API_SEGMENTATION_URL,
            headers=headers,
            data=img_byte_arr,
            timeout=timeout,
        )
        response_content = response.content.decode("utf-8")

        try:
            json_obj = json.loads(response_content)
        except ValueError:
            # A non-JSON body (e.g. a gateway error page) is treated as a
            # retryable failure instead of crashing the retry loop.
            failed = True
        else:
            failed = isinstance(json_obj, dict) and "error" in json_obj

        if not failed:
            return json_obj

        print(f"Error: {response_content}")
        # Waiting only makes sense when another attempt will follow.
        if attempt < max_retries:
            time.sleep(wait_time)

    raise Exception("Failed to get a valid response from the API after multiple retries.")
143
+
144
def _load_label_font(size):
    """Return the Geneva TrueType font at ``size``, or Pillow's default font if it cannot be loaded."""
    try:
        return ImageFont.truetype("Geneva.ttf", size)
    except OSError:
        # Geneva ships with macOS only; fall back so other hosts still work.
        return ImageFont.load_default()


def _text_bottom_center(draw, position, text, font):
    """Return the (x, y) point at the bottom centre of ``text`` drawn at ``position``."""
    try:
        left, _top, right, bottom = draw.textbbox(position, text, font=font)
    except (AttributeError, ValueError):
        # Fallback for Pillow versions where textbbox is missing or rejects
        # this font; those versions still provide the legacy textsize().
        width, height = draw.textsize(text, font=font)
        return position[0] + width // 2, position[1] + height
    return int((left + right) // 2), int(bottom)


def getSegmentationMaskImage(input_img, blur_kernel_size=21):
    """Build one highlighted image per segment returned by the segmentation API.

    The input is shrunk in place to at most 600 pixels wide, sent to
    ``send_request_to_api`` as PNG, and for every returned item whose label
    does not start with ``"LABEL"`` an image is produced in which the masked
    region is sharp and the rest is Gaussian-blurred. The label is drawn in
    red at the top-left corner with a red line pointing to the top centre of
    the mask's bounding box.

    Args:
        input_img: PIL image to process. It is resized in place by
            ``thumbnail``.
        blur_kernel_size: Gaussian kernel size. Even or non-positive values
            are rounded up to the next odd positive integer, because
            ``cv2.GaussianBlur`` rejects them when sigma is 0.

    Returns:
        A list of RGBA PIL images, one per kept segment (possibly empty).
    """
    # Shrink the input so the upload stays small; thumbnail keeps the aspect ratio.
    target_width = 600
    aspect_ratio = float(input_img.height) / float(input_img.width)
    target_height = max(1, int(target_width * aspect_ratio))
    input_img.thumbnail((target_width, target_height))

    img_byte_arr = BytesIO()
    input_img.save(img_byte_arr, format='PNG')
    json_obj = send_request_to_api(img_byte_arr.getvalue())
    print(json_obj)

    # Work on an RGBA copy of the (resized) input.
    original_image = input_img.copy()
    if original_image.mode != 'RGBA':
        original_image = original_image.convert('RGBA')

    # The blurred background does not depend on the segment, so build it once
    # instead of once per loop iteration.
    kernel = max(1, int(blur_kernel_size)) | 1
    original_image_cv2 = cv2.cvtColor(np.array(original_image.convert('RGB')), cv2.COLOR_RGB2BGR)
    blurred_image_cv2 = cv2.GaussianBlur(original_image_cv2, (kernel, kernel), 0)
    blurred_image = Image.fromarray(
        cv2.cvtColor(blurred_image_cv2, cv2.COLOR_BGR2RGB)
    ).convert(original_image.mode)

    font = _load_label_font(20)
    text_position = (10, 20)
    red = (255, 0, 0)

    output_images = []
    for item in json_obj:
        label = item['label']

        # Skip items whose label starts with "LABEL".
        if label.startswith("LABEL"):
            continue

        # Decode the base64 mask into an image.
        mask_image = Image.open(BytesIO(base64.b64decode(item['mask'])))
        # Image.composite only accepts masks in mode "1", "L" or "RGBA".
        if mask_image.mode not in ("1", "L", "RGBA"):
            mask_image = mask_image.convert("L")

        # Keep the original pixels where the mask is set, blurred pixels elsewhere.
        process_image = Image.composite(original_image, blurred_image, mask_image)

        draw = ImageDraw.Draw(process_image)
        draw.text(text_position, label, font=font, fill=red)

        # getbbox() returns None for an all-zero mask; in that case there is
        # nothing to point at, so only the label is drawn.
        mask_bbox = mask_image.getbbox()
        if mask_bbox is not None:
            mask_top_center = ((mask_bbox[0] + mask_bbox[2]) // 2, mask_bbox[1])
            text_anchor = _text_bottom_center(draw, text_position, label, font)
            draw.line([text_anchor, mask_top_center], fill=red, width=2)

        output_images.append(process_image)
    return output_images
210
+
211
def sepia(input_img):
    """Gradio callback: segment ``input_img`` and return all results stacked vertically.

    Args:
        input_img: Image as a NumPy array. Floating-point arrays are converted
            to ``uint8``: values are scaled by 255 when the maximum is at most
            1.0, then clipped to 0..255. Previously only float32 was handled
            and other float dtypes failed in ``Image.fromarray``.

    Returns:
        A NumPy array with every image from ``getSegmentationMaskImage``
        stacked top to bottom. If no segment was produced, the (resized)
        input image is returned instead of letting ``np.vstack`` raise on an
        empty list.
    """
    if np.issubdtype(input_img.dtype, np.floating):
        if np.max(input_img) <= 1.0:
            input_img = input_img * 255
        input_img = np.clip(input_img, 0, 255).astype(np.uint8)

    pil_img = Image.fromarray(input_img)
    output_images = getSegmentationMaskImage(pil_img)

    if not output_images:
        return np.array(pil_img)

    # Stack all result images into one tall image.
    return np.vstack([np.array(img) for img in output_images])
222
+
223
+
224
def imageDemo():
    """Build the Gradio interface around ``sepia`` and launch it with a public share link.

    The output component uses ``gr.Image`` rather than the legacy
    ``gr.outputs.Image`` namespace, and the input drops the explicit
    ``shape=None`` argument; both legacy forms are deprecated in Gradio 3 and
    absent from Gradio 4, while the replacements behave the same in 3.x.
    """
    demo = gr.Interface(
        fn=sepia,
        inputs=gr.Image(),
        outputs=gr.Image(label="Processed Images", type="numpy"),
        title='Image Processing Demo',
    )
    demo.launch(share=True)
228
+
229
+
230
# Script entry point: start the demo only when this file is run directly.
if __name__ == "__main__":
    imageDemo()
232
+
233
+
234
+
235
+ #######--------- Alternative: GIF output
236
+ # def sepia(input_img):
237
+ # input_img = Image.fromarray((input_img * 255).astype(np.uint8))
238
+ #
239
+ # output_images = getSegmentationMaskImage(input_img)
240
+ #
241
+ # # 生成GIF动画
242
+ # buffered = BytesIO()
243
+ # output_images[0].save(buffered, format='GIF', save_all=True, append_images=output_images[1:], duration=3000, loop=0)
244
+ # gif_str = base64.b64encode(buffered.getvalue()).decode()
245
+ # return f'<img src="data:image/gif;base64,{gif_str}" width="400" />'
246
+ #
247
+ #
248
+ # def imageDemo():
249
+ # demo = gr.Interface(sepia, gr.Image(shape=(200, 200)), gr.outputs.HTML(label="Processed Animation"), title='Sepia Filter Demo')
250
+ # demo.launch()