scr930 commited on
Commit
1848536
1 Parent(s): 6b1798b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -1,26 +1,39 @@
1
  import gradio as gr
2
  from transformers import CLIPProcessor, CLIPModel
3
  from PIL import Image
 
4
 
5
  # Load the model and processor
6
  model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
7
  processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
8
 
9
  def classify_image(image):
10
- # Preprocess the image
11
- inputs = processor(images=image, return_tensors="pt")
 
 
 
 
12
  # Perform the inference
13
  outputs = model(**inputs)
 
14
  # Postprocess the outputs
15
  logits_per_image = outputs.logits_per_image # this is the image-text similarity score
16
  probs = logits_per_image.softmax(dim=1) # we can use softmax to get probabilities
17
- return probs
 
 
 
 
 
 
 
18
 
19
  # Define Gradio interface
20
  iface = gr.Interface(
21
  fn=classify_image,
22
  inputs=gr.Image(type="pil"),
23
- outputs="text",
24
  title="Geolocal StreetCLIP Classification",
25
  description="Upload an image to classify using Geolocal StreetCLIP"
26
  )
 
1
  import gradio as gr
2
  from transformers import CLIPProcessor, CLIPModel
3
  from PIL import Image
4
+ import torch
5
 
6
  # Load the model and processor
7
  model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
8
  processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
9
 
10
  def classify_image(image):
11
+ # Example labels for classification
12
+ labels = ["a photo of a cat", "a photo of a dog", "a photo of a car", "a photo of a tree"]
13
+
14
+ # Preprocess the image and text
15
+ inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
16
+
17
  # Perform the inference
18
  outputs = model(**inputs)
19
+
20
  # Postprocess the outputs
21
  logits_per_image = outputs.logits_per_image # this is the image-text similarity score
22
  probs = logits_per_image.softmax(dim=1) # we can use softmax to get probabilities
23
+
24
+ # Convert the probabilities to a list
25
+ probs_list = probs.tolist()[0]
26
+
27
+ # Create a dictionary of labels and probabilities
28
+ result = {label: prob for label, prob in zip(labels, probs_list)}
29
+
30
+ return result
31
 
32
  # Define Gradio interface
33
  iface = gr.Interface(
34
  fn=classify_image,
35
  inputs=gr.Image(type="pil"),
36
+ outputs="label",
37
  title="Geolocal StreetCLIP Classification",
38
  description="Upload an image to classify using Geolocal StreetCLIP"
39
  )