import torch
import requests
import numpy as np
import pandas as pd
import gradio as gr
from io import BytesIO

from PIL import Image as PILIMAGE

from transformers import CLIPProcessor, CLIPModel
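
# Run inference on the GPU when one is available.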
device = "cuda" if torch.cuda.is_available() else "cpu"
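
# Load the pretrained CLIP model and its input processor.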
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
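
# Unsplash photo metadata plus precomputed CLIP features. The rows of the
# feature matrix are assumed to be L2-normalized and aligned with photo_ids.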
photos = pd.read_csv("./photos.tsv000", sep="\t", header=0)
photo_features = np.load("./features.npy")
photo_ids = pd.read_csv("./photo_ids.csv")
photo_ids = list(photo_ids["photo_id"])
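
# A minimal sketch of how "./features.npy" could be produced offline with the
# same model, assuming the dataset photos are available as PIL images. The
# helper is illustrative only and is not called anywhere in this app.
def precompute_photo_features(images, batch_size=32):
    """Encode a list of PIL images into normalized CLIP features."""
    features = []
    with torch.no_grad():
        for i in range(0, len(images), batch_size):
            pixel_values = processor(images=images[i:i + batch_size], return_tensors="pt")["pixel_values"].to(device)
            batch_features = model.get_image_features(pixel_values)
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
            features.append(batch_features.cpu().numpy())
    return np.concatenate(features)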

def encode_text(text):
    """Encode a text query into a normalized CLIP embedding."""
    with torch.no_grad():
        inputs = processor(text=[text], images=None, return_tensors="pt", padding=True).to(device)
        text_encoded = model.get_text_features(**inputs)
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
    return text_encoded.cpu().numpy()

def encode_image(image):
    """Encode a query image into a normalized CLIP embedding."""
    image = PILIMAGE.fromarray(image.astype("uint8"), "RGB")
    with torch.no_grad():
        photo_preprocessed = processor(text=None, images=image, return_tensors="pt", padding=True)["pixel_values"]
        search_photo_feature = model.get_image_features(photo_preprocessed.to(device))
        search_photo_feature /= search_photo_feature.norm(dim=-1, keepdim=True)
    return search_photo_feature.cpu().numpy()
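
# Labels for the two search modes exposed in the UI.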
T2I = "Text2Image"
I2I = "Image2Image"
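
# Cosine similarity of one query feature against every photo feature
# (all features are normalized, so a plain dot product suffices).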
def similarity(feature, photo_features):
    similarities = list((feature @ photo_features.T).squeeze(0))
    return similarities

def find_best_matches(image, mode, text):
    print("Mode:", mode)

    # Encode the query with the encoder matching the selected mode.
    if mode == T2I:
        query_features = encode_text(text)
    else:
        query_features = encode_image(image)
    similarities = similarity(query_features, photo_features)

    # Rank every photo by similarity to the query, best first.
    best_photos = sorted(zip(similarities, range(photo_features.shape[0])), key=lambda x: x[0], reverse=True)

    # Fetch the top three matches from Unsplash at 640px width.
    matched_images = []
    for i in range(3):
        idx = best_photos[i][1]
        photo_id = photo_ids[idx]
        photo_data = photos[photos["photo_id"] == photo_id].iloc[0]

        response = requests.get(photo_data["photo_image_url"] + "?w=640", timeout=10)
        img = PILIMAGE.open(BytesIO(response.content))
        matched_images.append(img)
    return matched_images
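
# Build and launch the Gradio interface: an image, a mode selector, and a
# text query in; a gallery of the best-matching photos out.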
gr.Interface(
    fn=find_best_matches,
    inputs=[
        gr.Image(label="Image to search"),
        gr.Radio([T2I, I2I], label="Search mode", value=T2I),
        gr.Textbox(lines=1, label="Text query", placeholder="Introduce the search text..."),
    ],
    outputs=[
        gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=2),
    ],
    title="CLIP Image Search",
    description="This application displays the top three images from the Unsplash dataset that best match the search query. The query can be provided in one of two modes: text or image.",
).queue().launch()