Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This is a sample Python script.
|
2 |
+
|
3 |
+
# Press ⌃R to execute it or replace it with your code.
|
4 |
+
# Press Double ⇧ to search everywhere for classes, files, tool windows, actions, and settings.
|
5 |
+
import base64
import datetime
import json
import os
import time
from io import BytesIO

import cv2
import numpy as np
import requests
from PIL import Image, ImageDraw, ImageFont, ImageOps
|
15 |
+
|
16 |
+
# Local image used by the command-line helpers (query / crop_image).
main_image_path = "/Users/aaron/Documents/temp/16pic_2415206_s.png"

# The Hugging Face access token must never be committed to source control.
# It is read from the HF_API_TOKEN environment variable; when the variable is
# unset, requests are sent without an Authorization header.
API_TOKEN = os.environ.get("HF_API_TOKEN", "")
headers = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}

# Inference endpoints. Commented-out entries are alternatives that are
# currently disabled.
# API_URL = "https://api-inference.huggingface.co/models/hustvl/yolos-tiny"
API_URL = "https://api-inference.huggingface.co/models/facebook/detr-resnet-50"
# API_OBJECT_URL = "https://api-inference.huggingface.co/models/microsoft/resnet-50"
API_SEGMENTATION_URL = "https://api-inference.huggingface.co/models/facebook/detr-resnet-50-panoptic"
API_SEGMENTATION_URL_2 = "https://api-inference.huggingface.co/models/nvidia/segformer-b0-finetuned-ade-512-512"

# Directory where cropped images are written by the (disabled) CLI flow.
temp_dir = "/Users/aaron/Documents/temp/imageai/"
|
26 |
+
|
27 |
+
|
28 |
+
def query(filename):
    """Send the contents of ``filename`` to ``API_URL`` and return the parsed reply.

    The file is read in binary mode and posted as the raw request body using
    the module-level ``headers``. The response body is decoded as UTF-8 and
    parsed as JSON; whatever JSON value the service returned is handed back
    unchanged.
    """
    with open(filename, "rb") as image_file:
        payload = image_file.read()

    reply = requests.request("POST", API_URL, data=payload, headers=headers)
    reply_text = reply.content.decode("utf-8")
    return json.loads(reply_text)
|
33 |
+
|
34 |
+
|
35 |
+
def queryObjectDetection(
    filename,
    api_url="https://api-inference.huggingface.co/models/microsoft/resnet-50",
):
    """Send the contents of ``filename`` to an inference endpoint and return the parsed reply.

    The module-level ``API_OBJECT_URL`` this function used to read is
    commented out in this file, so every call failed with ``NameError``. The
    endpoint is now a parameter whose default is the URL from that
    commented-out assignment, which keeps existing one-argument calls working.

    Args:
        filename: Path of the image file to upload as the raw request body.
        api_url: Endpoint to POST to.

    Returns:
        The JSON value decoded from the UTF-8 response body.

    Raises:
        requests.exceptions.Timeout: If the service does not answer within
            the 6 second timeout.
        ValueError: If the response body is not valid JSON.
    """
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.request("POST", api_url, headers=headers, data=data, timeout=6)
    # Kept from the original: echo the response object (status) for debugging.
    print(response)
    return json.loads(response.content.decode("utf-8"))
|
41 |
+
|
42 |
+
|
43 |
+
def getImageSegmentation():
    """Run ``query`` on ``main_image_path``, print the result, and return it."""
    result = query(main_image_path)
    print(result)
    return result
|
47 |
+
|
48 |
+
|
49 |
+
def crop_image(box):
    """Return the part of the main image delimited by ``box``.

    ``box`` is a mapping with the keys ``'xmin'``, ``'ymin'``, ``'xmax'`` and
    ``'ymax'``; their values are passed, in that order, as the crop rectangle.
    The image is always loaded from ``main_image_path``.
    """
    # Open the source image first, then build the (left, upper, right, lower)
    # rectangle from the box and cut it out.
    source = Image.open(main_image_path)
    bounds = tuple(box[edge] for edge in ('xmin', 'ymin', 'xmax', 'ymax'))
    return source.crop(bounds)
|
60 |
+
|
61 |
+
|
62 |
+
# # 示例
|
63 |
+
# image_path = "path/to/your/image.jpg"
|
64 |
+
# box = {'xmin': 186, 'ymin': 75, 'xmax': 252, 'ymax': 123}
|
65 |
+
#
|
66 |
+
# cropped_image = crop_image(image_path, box)
|
67 |
+
# cropped_image.show() # 显示裁剪后的图片
|
68 |
+
# cropped_image.save("path/to/save/cropped_image.jpg") # 保存裁剪后的图片
|
69 |
+
|
70 |
+
# Press the green button in the gutter to run the script.
|
71 |
+
# if __name__ == '__main__':
|
72 |
+
# data = getImageSegmentation()
|
73 |
+
# for item in data:
|
74 |
+
# box = item['box']
|
75 |
+
# cropped_image = crop_image(box)
|
76 |
+
# temp_image_path = temp_dir + str(int(datetime.datetime.now().timestamp() * 1000000)) + ".png"
|
77 |
+
# print(temp_image_path)
|
78 |
+
# cropped_image.save(temp_image_path)
|
79 |
+
# object_data = queryObjectDetection(temp_image_path)
|
80 |
+
# print(object_data)
|
81 |
+
# flag = False
|
82 |
+
# for obj in object_data:
|
83 |
+
# # 检查字典中是否包含 'error' 键
|
84 |
+
# if 'error' in obj and obj['error'] is not None:
|
85 |
+
# flag = True
|
86 |
+
# print("找到了一个包含 'error' 键的字典,且其值不为 None")
|
87 |
+
# else:
|
88 |
+
# print("字典不包含 'error' 键,或其值为 None")
|
89 |
+
# if flag:
|
90 |
+
# continue
|
91 |
+
# item['label'] = object_data[0]['label']
|
92 |
+
# print(data)
|
93 |
+
#
|
94 |
+
# ###下面就是画个图,和上面主流程无关,仅仅用于测试
|
95 |
+
# image = Image.open(main_image_path)
|
96 |
+
# draw = ImageDraw.Draw(image)
|
97 |
+
#
|
98 |
+
# # 设置边框颜色和字体
|
99 |
+
# border_color = (255, 0, 0) # 红色
|
100 |
+
# text_color = (255, 255, 255) # 白色
|
101 |
+
# font = ImageFont.truetype("Geneva.ttf", 12) # 使用 系统Geneva 字体,大小为 12
|
102 |
+
#
|
103 |
+
# # 遍历对象列表,画边框和标签
|
104 |
+
# for obj in data:
|
105 |
+
# label = obj['label']
|
106 |
+
# box = obj['box']
|
107 |
+
# xmin, ymin, xmax, ymax = box['xmin'], box['ymin'], box['xmax'], box['ymax']
|
108 |
+
#
|
109 |
+
# # 画边框
|
110 |
+
# draw.rectangle([xmin, ymin, xmax, ymax], outline=border_color, width=2)
|
111 |
+
#
|
112 |
+
# # 画标签
|
113 |
+
# text_size = draw.textsize(label, font=font)
|
114 |
+
# draw.rectangle([xmin, ymin, xmin + text_size[0], ymin + text_size[1]], fill=border_color)
|
115 |
+
# draw.text((xmin, ymin), label, font=font, fill=text_color)
|
116 |
+
#
|
117 |
+
# image.show()
|
118 |
+
|
119 |
+
|
120 |
+
import numpy as np
|
121 |
+
from PIL import Image
|
122 |
+
import gradio as gr
|
123 |
+
import imageio
|
124 |
+
|
125 |
+
|
126 |
+
def send_request_to_api(img_byte_arr, max_retries=3, wait_time=60):
    """POST image bytes to ``API_SEGMENTATION_URL``, retrying on error replies.

    A reply counts as a failure when its body is not valid JSON or when it is
    a JSON object containing an ``"error"`` key. Inspecting the parsed value
    (rather than searching the raw text for the word "error") avoids retrying
    a valid result that merely contains that substring.

    Args:
        img_byte_arr: Encoded image bytes sent as the raw request body.
        max_retries: Maximum number of requests to attempt.
        wait_time: Seconds to sleep between a failed attempt and the next one.

    Returns:
        The parsed JSON value of the first successful reply.

    Raises:
        Exception: If no attempt produced a valid reply.
    """
    for attempt in range(1, max_retries + 1):
        response = requests.request("POST", API_SEGMENTATION_URL, headers=headers, data=img_byte_arr)
        response_content = response.content.decode("utf-8")

        try:
            json_obj = json.loads(response_content)
        except ValueError:
            json_obj = None

        failed = json_obj is None or (isinstance(json_obj, dict) and "error" in json_obj)
        if not failed:
            return json_obj

        print(f"Error: {response_content}")
        # Only wait when another attempt will follow; sleeping after the final
        # failure would just delay the exception.
        if attempt < max_retries:
            time.sleep(wait_time)

    raise Exception("Failed to get a valid response from the API after multiple retries.")
|
143 |
+
|
144 |
+
def _load_label_font(size):
    """Return the Geneva TrueType font at ``size``, or Pillow's built-in font if it cannot be loaded."""
    try:
        return ImageFont.truetype("Geneva.ttf", size)
    except OSError:
        # Geneva ships with macOS only; fall back so other hosts do not crash.
        return ImageFont.load_default()


def _measure_text(draw, text, font):
    """Return ``(width, height)`` of ``text`` when drawn by ``draw`` with ``font``."""
    try:
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    except (AttributeError, ValueError):
        # Older Pillow releases: textbbox is missing or rejects this font
        # type, while the legacy textsize API is still available there.
        return draw.textsize(text, font=font)
    return right - left, bottom - top


def getSegmentationMaskImage(input_img, blur_kernel_size=21):
    """Build one annotated image per segment returned by the segmentation API.

    The input image is shrunk in place (``thumbnail``) to fit a width of 600,
    PNG-encoded and sent through ``send_request_to_api``. For every returned
    item whose label does not start with ``"LABEL"``, the result image keeps
    the masked region sharp over a Gaussian-blurred copy of the picture,
    prints the label in red at the top-left corner, and draws a red line from
    the label to the top centre of the mask's bounding box.

    Args:
        input_img: PIL image to analyse. It is resized in place.
        blur_kernel_size: Side length of the Gaussian kernel passed to
            ``cv2.GaussianBlur`` (OpenCV requires an odd, positive value).

    Returns:
        A list of RGBA PIL images, one per retained segment (possibly empty).
    """
    # Shrink the input so the upload stays small; thumbnail keeps the aspect
    # ratio and never enlarges.
    target_width = 600
    aspect_ratio = float(input_img.height) / float(input_img.width)
    target_height = int(target_width * aspect_ratio)
    input_img.thumbnail((target_width, target_height))

    img_byte_arr = BytesIO()
    input_img.save(img_byte_arr, format='PNG')
    img_byte_arr = img_byte_arr.getvalue()
    json_obj = send_request_to_api(img_byte_arr)
    print(json_obj)

    # Work on an RGBA copy of the (resized) input.
    original_image = input_img.copy()
    if original_image.mode != 'RGBA':
        original_image = original_image.convert('RGBA')

    # The blurred background and the font do not depend on the segment, so
    # they are prepared once instead of once per loop iteration.
    original_image_cv2 = cv2.cvtColor(np.array(original_image.convert('RGB')), cv2.COLOR_RGB2BGR)
    blurred_image_cv2 = cv2.GaussianBlur(original_image_cv2, (blur_kernel_size, blur_kernel_size), 0)
    blurred_image = Image.fromarray(cv2.cvtColor(blurred_image_cv2, cv2.COLOR_BGR2RGB)).convert(original_image.mode)
    font = _load_label_font(20)
    text_position = (10, 20)

    output_images = []
    for item in json_obj:
        label = item['label']

        # Skip items whose label starts with "LABEL".
        if label.startswith("LABEL"):
            continue

        # The mask arrives as a Base64-encoded image.
        mask_image = Image.open(BytesIO(base64.b64decode(item['mask'])))
        # Image.composite only accepts masks in mode "1", "L" or "RGBA".
        if mask_image.mode not in ("1", "L", "RGBA"):
            mask_image = mask_image.convert("L")

        # Sharp original where the mask is set, blurred image elsewhere.
        process_image = Image.composite(original_image, blurred_image, mask_image)

        draw = ImageDraw.Draw(process_image)
        draw.text(text_position, label, font=font, fill=(255, 0, 0))

        # getbbox() returns None for a completely empty mask; in that case
        # there is nothing to point at, so only the label is drawn.
        mask_bbox = mask_image.getbbox()
        if mask_bbox is not None:
            mask_top_center_x = (mask_bbox[0] + mask_bbox[2]) // 2
            mask_top_center_y = mask_bbox[1]

            text_width, text_height = _measure_text(draw, label, font)
            text_bottom_center_x = text_position[0] + text_width // 2
            text_bottom_center_y = text_position[1] + text_height

            # Red line from the bottom centre of the label to the top centre
            # of the mask's bounding box.
            draw.line([(text_bottom_center_x, text_bottom_center_y), (mask_top_center_x, mask_top_center_y)],
                      fill=(255, 0, 0), width=2)

        output_images.append(process_image)
    return output_images
|
210 |
+
|
211 |
+
def sepia(input_img):
    """Gradio callback: return all per-segment images stacked vertically.

    Despite its name, this function applies no sepia tone. It converts the
    incoming array to a PIL image, runs ``getSegmentationMaskImage`` on it and
    stacks the resulting images top to bottom.

    Args:
        input_img: Image as a NumPy array. Floating-point arrays whose maximum
            is at most 1.0 are rescaled to 8-bit values first.

    Returns:
        A NumPy array containing the stacked result images. When no segment
        image was produced, the input image (as left by
        ``getSegmentationMaskImage``) is returned instead.
    """
    # Any float dtype in the 0..1 range (not just float32) is rescaled to
    # uint8 before building the PIL image.
    if np.issubdtype(input_img.dtype, np.floating) and np.max(input_img) <= 1.0:
        input_img = (input_img * 255).astype(np.uint8)

    input_img = Image.fromarray(input_img)
    output_images = getSegmentationMaskImage(input_img)

    # np.vstack raises ValueError on an empty list, which happens when the
    # API returns no usable segments; show the input picture in that case.
    if not output_images:
        return np.array(input_img)

    # Stack all result images vertically.
    stacked_image = np.vstack([np.array(img) for img in output_images])
    return stacked_image
|
222 |
+
|
223 |
+
|
224 |
+
def imageDemo():
    """Build and launch the Gradio interface around ``sepia``.

    The legacy ``gr.outputs.Image`` class and the ``shape`` argument of
    ``gr.Image`` are no longer available in current Gradio releases, so plain
    ``gr.Image`` components are used for both input and output. The output
    keeps its label and NumPy type.
    """
    demo = gr.Interface(
        sepia,
        gr.Image(),
        gr.Image(label="Processed Images", type="numpy"),
        title='Image Processing Demo',
    )
    demo.launch(share=True)


if __name__ == '__main__':
    imageDemo()
|
232 |
+
|
233 |
+
|
234 |
+
|
235 |
+
#######---------gif输出方式
|
236 |
+
# def sepia(input_img):
|
237 |
+
# input_img = Image.fromarray((input_img * 255).astype(np.uint8))
|
238 |
+
#
|
239 |
+
# output_images = getSegmentationMaskImage(input_img)
|
240 |
+
#
|
241 |
+
# # 生成GIF动画
|
242 |
+
# buffered = BytesIO()
|
243 |
+
# output_images[0].save(buffered, format='GIF', save_all=True, append_images=output_images[1:], duration=3000, loop=0)
|
244 |
+
# gif_str = base64.b64encode(buffered.getvalue()).decode()
|
245 |
+
# return f'<img src="data:image/gif;base64,{gif_str}" width="400" />'
|
246 |
+
#
|
247 |
+
#
|
248 |
+
# def imageDemo():
|
249 |
+
# demo = gr.Interface(sepia, gr.Image(shape=(200, 200)), gr.outputs.HTML(label="Processed Animation"), title='Sepia Filter Demo')
|
250 |
+
# demo.launch()
|