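"""Gradio Space comparing zero-shot image classification with EVA-CLIP-8B and OpenAI CLIP.

Both models score the same image against a comma-separated list of candidate labels,
so their predictions can be compared side by side.
"""
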
from transformers import CLIPImageProcessor, pipeline, CLIPTokenizer, AutoModel
import torch
import gradio as gr
import spaces


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# EVA-CLIP-8B ships custom modeling code, hence trust_remote_code=True.
model_name_or_path = "BAAI/EVA-CLIP-8B"

# Image preprocessing reuses the OpenAI CLIP ViT-L/14 image processor.
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")

model = AutoModel.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to(device).eval()

# The text tokenizer comes from the EVA-CLIP repository itself.
tokenizer = CLIPTokenizer.from_pretrained(model_name_or_path)

# Baseline: OpenAI CLIP via the zero-shot image classification pipeline.
clip_checkpoint = "openai/clip-vit-base-patch16"
clip_detector = pipeline(task="zero-shot-image-classification", model=clip_checkpoint, device=device)


def infer_evaclip(image, captions):
  # Split the comma-separated label string and strip stray whitespace.
  captions = [caption.strip() for caption in captions.split(",")]
  input_ids = tokenizer(captions, return_tensors="pt", padding=True).input_ids.to(device)
  input_pixels = processor(images=image, return_tensors="pt").pixel_values.to(device)

  with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16):
    image_features = model.encode_image(input_pixels)
    text_features = model.encode_text(input_ids)
    # L2-normalize the embeddings so the dot product below is a cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

  # Temperature-scaled cosine similarities, softmaxed over the candidate labels.
  label_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
  # Convert to plain Python floats for the Gradio Label component.
  label_probs = label_probs[0].float().cpu().tolist()
  return dict(zip(captions, label_probs))

def clip_inference(image, labels):
  # The pipeline handles preprocessing, scoring, and softmax internally.
  candidate_labels = [label.strip() for label in labels.split(",")]
  clip_out = clip_detector(image, candidate_labels=candidate_labels)
  return {out["label"]: float(out["score"]) for out in clip_out}

@spaces.GPU
def infer(image, labels):
  # Run both models on the same image and label list so the outputs are comparable.
  clip_out = clip_inference(image, labels)
  evaclip_out = infer_evaclip(image, labels)
  return clip_out, evaclip_out


with gr.Blocks() as demo:
  gr.Markdown("# EVA-CLIP vs CLIP 💥")
  gr.Markdown("[EVA-CLIP](https://huggingface.co/BAAI/EVA-CLIP-8B) is CLIP scaled to the moon! 🔥")
  gr.Markdown("It is a state-of-the-art zero-shot image classification model that also outperforms its predecessors on text-image retrieval and linear probing.")
  gr.Markdown("In this demo, compare EVA-CLIP outputs to CLIP outputs ✨")
  with gr.Row():
    with gr.Column():
      image_input = gr.Image(type="pil")
      text_input = gr.Textbox(label="Enter comma-separated labels")
      run_button = gr.Button("Run")

    with gr.Column():
      clip_output = gr.Label(label="CLIP Output", num_top_classes=3)
      evaclip_output = gr.Label(label="EVA-CLIP Output", num_top_classes=3)
  examples = [["./cat.png", "cat on a table, cat on a tree"]]
  gr.Examples(
      examples=examples,
      inputs=[image_input, text_input],
      outputs=[clip_output, evaclip_output],
      fn=infer,
      cache_examples=True,
  )
  run_button.click(
      fn=infer,
      inputs=[image_input, text_input],
      outputs=[clip_output, evaclip_output],
  )

demo.launch()