robgonsalves committed
Commit e733601
1 Parent(s): aa3546b

add menu for Similarity Type

Files changed (1)
  1. app.py +13 -14
app.py CHANGED
@@ -7,11 +7,7 @@ from transformers import CLIPProcessor, CLIPModel
 model = CLIPModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
 processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
 
-def calculate_similarity(image, text_prompt):
-    # Ensure text_prompt is a string
-    if not isinstance(text_prompt, str):
-        text_prompt = str(text_prompt)
-
+def calculate_similarity(image, text_prompt, similarity_type):
     # Process inputs
     inputs = processor(images=image, text=text_prompt, return_tensors="pt", padding=True)
 
@@ -22,25 +18,28 @@ def calculate_similarity(image, text_prompt):
     image_features = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
     text_features = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
     cosine_similarity = torch.nn.functional.cosine_similarity(image_features, text_features)
-
-    # Adjusting the similarity score
-    adjusted_similarity = cosine_similarity.item() * 3 * 100
-    clipped_similarity = min(adjusted_similarity, 99.99)
-    formatted_similarity = f"According to OpenCLIP, the image and the text prompt are {clipped_similarity:.2f}% similar."
 
-    return formatted_similarity
+    # Adjusting the similarity score based on the dropdown selection
+    if similarity_type == "General Similarity (3x scaled)":
+        adjusted_similarity = cosine_similarity.item() * 3 * 100
+        result_text = f"According to OpenCLIP, the image and the text prompt have a general similarity of {min(adjusted_similarity, 99.99):.2f}%."
+    else:  # Cosine Similarity (raw)
+        result_text = f"According to OpenCLIP, the image and the text prompt have a cosine similarity of {cosine_similarity.item() * 100:.2f}%."
+
+    return result_text
 
 # Set up Gradio interface
 iface = gr.Interface(
     fn=calculate_similarity,
     inputs=[
         gr.Image(type="pil", label="Upload Image", height=512),
-        gr.Textbox(label="Text Prompt")
+        gr.Textbox(label="Text Prompt"),
+        gr.Dropdown(label="Similarity Type", choices=["General Similarity (3x scaled)", "Cosine Similarity (raw)"], value="General Similarity (3x scaled)")
     ],
     outputs=gr.Text(),
     allow_flagging="never",
-    title="OpenClip Cosine Similarity Calculator",
-    description="Provide a text prompt and upload an image to calculate the cosine similarity."
+    title="OpenClip Similarity Calculator",
+    description="Upload an image and provide a text prompt to calculate the similarity."
 )
 
 # Launch the interface with a public link for sharing online
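
For context, a minimal usage sketch (not part of the commit) showing how the updated calculate_similarity could be called directly with each "Similarity Type" choice; the image path and prompt below are placeholders. Per the committed formula, a raw cosine similarity of, say, 0.28 would be reported as min(0.28 * 3 * 100, 99.99) = 84.00% under "General Similarity (3x scaled)", and as 28.00% under "Cosine Similarity (raw)".

# Hypothetical direct calls to the updated function (assumes calculate_similarity
# from app.py is defined in the current session; paths and prompts are placeholders).
from PIL import Image

image = Image.open("example.jpg")           # placeholder image path
prompt = "a photo of a golden retriever"    # placeholder text prompt

print(calculate_similarity(image, prompt, "General Similarity (3x scaled)"))
print(calculate_similarity(image, prompt, "Cosine Similarity (raw)"))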