fancyfeast commited on
Commit
3be01e3
1 Parent(s): d240551

Initial commit

Browse files
Files changed (2) hide show
  1. app.py +79 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import huggingface_hub
3
+ from PIL import Image
4
+ from pathlib import Path
5
+ import onnxruntime as rt
6
+ import numpy as np
7
+ import csv
8
+
9
+
10
+ MODEL_REPO = 'toynya/Z3D-E621-Convnext'
11
+ THRESHOLD = 0.5
12
+ DESCRIPTION = """
13
+ This is a demo of https://huggingface.co/toynya/Z3D-E621-Convnext
14
+ I am not affiliated with the model author in anyway, this is just a useful tool requested by a user.
15
+ """
16
+
17
+
18
+ def prepare_image(image: Image.Image, target_size: int):
19
+ # Pad image to square
20
+ image_shape = image.size
21
+ max_dim = max(image_shape)
22
+ pad_left = (max_dim - image_shape[0]) // 2
23
+ pad_top = (max_dim - image_shape[1]) // 2
24
+
25
+ padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
26
+ padded_image.paste(image, (pad_left, pad_top))
27
+
28
+ # Resize
29
+ if max_dim != target_size:
30
+ padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
31
+
32
+ # Convert to numpy array
33
+ image_array = np.asarray(padded_image, dtype=np.float32) / 255.0
34
+ return np.expand_dims(image_array, axis=0)
35
+
36
+
37
+ def predict(image: Image.Image):
38
+ image_array = prepare_image(image, 448)
39
+
40
+ image_array = prepare_image(image, 448)
41
+ input_name = 'input_1:0'
42
+ output_name = 'predictions_sigmoid'
43
+
44
+ result = session.run([output_name], {input_name: image_array})
45
+ result = result[0][0]
46
+
47
+ scores = {tags[i]: result[i] for i in range(len(result))}
48
+ predicted_tags = [tag for tag, score in scores.items() if score > THRESHOLD]
49
+ tag_string = ', '.join(predicted_tags)
50
+
51
+ return tag_string, scores
52
+
53
+
54
+ print("Downloading model...")
55
+ path = huggingface_hub.snapshot_download(MODEL_REPO)
56
+ print("Loading model...")
57
+ session = rt.InferenceSession(path / 'model.onnx', providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
58
+
59
+ with open(path / 'tags-selected.csv', mode='r', encoding='utf-8') as file:
60
+ csv_reader = csv.DictReader(file)
61
+ tags = [row['name'].strip() for row in csv_reader]
62
+
63
+ print("Starting server...")
64
+
65
+ gradio_app = gr.Interface(
66
+ predict,
67
+ inputs=gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'),
68
+ outputs=[
69
+ gr.Textbox(label="Tag String"),
70
+ gr.Label(label="Tag Predictions", num_top_classes=100),
71
+ ],
72
+ title="JoyTag",
73
+ description=DESCRIPTION,
74
+ allow_flagging="never",
75
+ )
76
+
77
+
78
+ if __name__ == '__main__':
79
+ gradio_app.launch()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ numpy==1.26.3