from pathlib import Path
from typing import Optional

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel


# MetaCLIP ViT-B/32 checkpoint trained on 400M curated image-text pairs.
MODEL_NAME = "facebook/metaclip-b32-400m"

# Reuse a pre-populated Hugging Face cache (e.g. baked into the container) if it exists.
cache_path = Path("/app/cache")
if not cache_path.exists():
    cache_path = None


def get_clip_model_and_processor(model_name: str, cache_path: Optional[Path] = None):
    """Load the CLIP model and processor, optionally from a local cache directory."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cache_dir = str(cache_path) if cache_path else None
    model = CLIPModel.from_pretrained(model_name, cache_dir=cache_dir).to(device)
    processor = CLIPProcessor.from_pretrained(model_name, cache_dir=cache_dir)
    # Inference only, so switch to eval mode.
    return model.eval(), processor


def image_to_embedding(img: Optional[np.ndarray] = None, txt: Optional[str] = None) -> np.ndarray:
    """Embed the image if one is provided, otherwise the text; returns [] when both are empty."""
    if img is None and not txt:
        return []

    # No gradients are needed for inference.
    with torch.no_grad():
        if img is not None:
            inputs = CLIP_PROCESSOR(images=[Image.fromarray(img)], return_tensors="pt").to(
                CLIP_MODEL.device
            )
            embedding = CLIP_MODEL.get_image_features(**inputs)
        else:
            inputs = CLIP_PROCESSOR(text=[txt], return_tensors="pt", padding=True).to(
                CLIP_MODEL.device
            )
            embedding = CLIP_MODEL.get_text_features(**inputs)

    return embedding.cpu().numpy()

# Load the model and processor once at startup so every request reuses them.
CLIP_MODEL, CLIP_PROCESSOR = get_clip_model_and_processor(MODEL_NAME, cache_path=cache_path)
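# A quick sanity check, runnable once the module has loaded (assumption: the ViT-B/32
# projection head used by this checkpoint yields 512-dimensional embeddings):
#
#   vec = image_to_embedding(img=np.zeros((224, 224, 3), dtype=np.uint8))
#   print(vec.shape)  # expected: (1, 512)
#
#   vec = image_to_embedding(txt="a photo of a cat")
#   print(vec.shape)  # same width, so image and text embeddings are directly comparable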

demo = gr.Interface(fn=image_to_embedding, inputs=["image", "textbox"], outputs="textbox", cache_examples=True)
# Bind to all interfaces so the app is reachable from outside a container.
demo.launch(server_name="0.0.0.0")