# S23-CLIP / app.py
# Acknowledgments:
# This project is inspired by:
# 1. https://github.com/haltakov/natural-language-image-search by Vladimir Haltakov
# 2. OpenAI's CLIP

# Import all the necessary libraries
import torch
import requests
import numpy as np
import pandas as pd
import gradio as gr
from io import BytesIO
from PIL import Image as PILIMAGE
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
examples = [["fruits.jpg"],
            ["flowers.jpg"],
            ["sa_10039.jpg"],
            ["dogs.jpg"],
            ["sa_8776.jpg"],
            ["sa_561.jpg"],
            ["sa_11025.jpg"],
            ["sa_1309.jpg"],
            ["sa_192.jpg"],
            ["sa_862.jpg"],
            ["sa_414.jpg"]]
# Load the pretrained CLIP model, processor and tokenizer
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
# Load data
photos = pd.read_csv("./photos.tsv000", sep='\t', header=0)
photo_features = np.load("./features.npy")
photo_ids = pd.read_csv("./photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])
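# Note: features.npy is assumed to hold one precomputed, L2-normalized CLIP embedding
# per photo, aligned row-for-row with photo_ids.csv, so a dot product against a
# normalized query feature directly yields cosine similarities.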
def encode_text(text):
    with torch.no_grad():
        # Tokenize and encode the description with CLIP's text encoder
        inputs = tokenizer([text], padding=True, return_tensors="pt").to(device)
        text_encoded = model.get_text_features(**inputs)
        # Normalize so the dot product with the photo features is a cosine similarity
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
    return text_encoded.cpu().numpy()
def encode_image(image):
    image = PILIMAGE.fromarray(image.astype('uint8'), 'RGB')
    with torch.no_grad():
        # Preprocess the image and encode it with CLIP's image encoder
        photo_preprocessed = processor(text=None, images=image, return_tensors="pt", padding=True)["pixel_values"]
        search_photo_feature = model.get_image_features(photo_preprocessed.to(device))
        # Normalize so the dot product with the photo features is a cosine similarity
        search_photo_feature /= search_photo_feature.norm(dim=-1, keepdim=True)
        image_encoded = search_photo_feature.cpu().numpy()
    return image_encoded
T2I = "Text2Image"
I2I = "Image2Image"
def similarity(feature, photo_features):
    # Both sides are normalized, so the dot product is the cosine similarity
    similarities = list((feature @ photo_features.T).squeeze(0))
    return similarities
def find_best_matches(image, mode, text):
    # Compute the cosine similarity between the query and each photo
    print("Mode now", mode)
    if mode == T2I:
        # Encode the text input
        text_features = encode_text(text)
        similarities = similarity(text_features, photo_features)
    else:
        # Encode the image input
        image_features = encode_image(image)
        similarities = similarity(image_features, photo_features)
    # Sort the photos by their similarity score
    best_photos = sorted(zip(similarities, range(photo_features.shape[0])), key=lambda x: x[0], reverse=True)
    matched_images = []
    for i in range(3):
        # Retrieve the photo ID of the i-th best match
        idx = best_photos[i][1]
        photo_id = photo_ids[idx]
        # Get all metadata for this photo
        photo_data = photos[photos["photo_id"] == photo_id].iloc[0]
        # Download the image from Unsplash at 640 px width
        response = requests.get(photo_data["photo_image_url"] + "?w=640")
        img = PILIMAGE.open(BytesIO(response.content))
        matched_images.append(img)
    return matched_images
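# Example (hypothetical) sanity check outside the Gradio UI: calling
#   find_best_matches(None, "Text2Image", "two dogs playing in the snow")
# should return a list of three PIL images downloaded from Unsplash.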
gr.Interface(
    fn=find_best_matches,
    inputs=[
        gr.Image(label="Image to search"),
        gr.Radio([T2I, I2I]),
        gr.Textbox(lines=1, label="Text query", placeholder="Introduce the search text..."),
    ],
    outputs=[gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", scale=2)],
    examples=examples,
    theme="grass",
    title="CLIP Image Search",
    description="This application displays the TOP THREE images from the Unsplash dataset that best match the search query provided by the user. The query can be given in one of two modes: as text or as an image.",
).launch()
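# To run this app locally (assuming photos.tsv000, features.npy, photo_ids.csv and the
# example images are present in the working directory):
#   pip install torch transformers gradio pandas pillow requests
#   python app.py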