fancyfeast commited on
Commit
e08e861
1 Parent(s): 128e51e

More results and adds a tag string output

Browse files
Files changed (1) hide show
  1. app.py +15 -2
app.py CHANGED
@@ -9,6 +9,10 @@ import torchvision.transforms.functional as TVF
9
 
10
 
11
  MODEL_REPO = "fancyfeast/joytag"
 
 
 
 
12
 
13
 
14
  def prepare_image(image: Image.Image, target_size: int) -> torch.Tensor:
@@ -45,7 +49,11 @@ def predict(image: Image.Image):
45
  preds = model(batch)
46
  tag_preds = preds['tags'].sigmoid().cpu()
47
 
48
- return {top_tags[i]: tag_preds[0][i] for i in range(len(top_tags))}
 
 
 
 
49
 
50
 
51
  print("Downloading model...")
@@ -62,8 +70,13 @@ print("Starting server...")
62
  gradio_app = gr.Interface(
63
  predict,
64
  inputs=gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'),
65
- outputs=[gr.Label(label="Result", num_top_classes=5)],
 
 
 
66
  title="JoyTag",
 
 
67
  )
68
 
69
 
 
9
 
10
 
11
  MODEL_REPO = "fancyfeast/joytag"
12
+ THRESHOLD = 0.4
13
+ DESCRIPTION = """
14
+ Demo for the JoyTag model: https://huggingface.co/fancyfeast/joytag
15
+ """
16
 
17
 
18
  def prepare_image(image: Image.Image, target_size: int) -> torch.Tensor:
 
49
  preds = model(batch)
50
  tag_preds = preds['tags'].sigmoid().cpu()
51
 
52
+ scores = {top_tags[i]: tag_preds[0][i] for i in range(len(top_tags))}
53
+ predicted_tags = [tag for tag, score in scores.items() if score > THRESHOLD]
54
+ tag_string = ', '.join(predicted_tags)
55
+
56
+ return tag_string, scores
57
 
58
 
59
  print("Downloading model...")
 
70
  gradio_app = gr.Interface(
71
  predict,
72
  inputs=gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'),
73
+ outputs=[
74
+ gr.Textbox(label="Tag String"),
75
+ gr.Label(label="Tag Predictions", num_top_classes=100),
76
+ ],
77
  title="JoyTag",
78
+ description=DESCRIPTION,
79
+ allow_flagging="never",
80
  )
81
 
82