When working with visual data, fast and accurate image retrieval is key. This blog post is a step-by-step guide to building an image-based search engine with open-source tools: CLIP to embed the images, Hugging Face datasets to store them, and Faiss to search them.
By the end, you'll be able to build a customizable search engine that takes an image as input and returns the most similar images from your dataset. Let's dive in!
Embedding the data
Embedding is the core of an image-based search engine: it turns each image into a vector of numbers (an embedding) in a high-dimensional space, where similar images end up close to each other. Comparing images then reduces to comparing vectors, which is fast and easy to scale.
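To make the idea of comparing embeddings concrete, here is a minimal sketch that scores pairs of vectors with cosine similarity, a common way to compare embeddings. The three vectors are made-up toy values, not real CLIP outputs (CLIP ViT-L/14 image embeddings have 768 dimensions), and the helper name cosine_similarity is only for illustration.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between two embedding vectors (1.0 means same direction)."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# Toy 4-dimensional "embeddings" (hypothetical values, for illustration only).
cat_photo = np.array([0.9, 0.1, 0.0, 0.2])
cat_drawing = np.array([0.8, 0.2, 0.1, 0.3])
car_photo = np.array([0.0, 0.9, 0.8, 0.1])

print(cosine_similarity(cat_photo, cat_drawing))  # high: similar content
print(cosine_similarity(cat_photo, car_photo))    # low: different content
```

A search engine does the same thing at scale: it embeds the query image and returns the stored vectors that score best against it.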
Let's start by installing our libraries. loadimg is a Python library I developed to load images and convert them between formats with ease; you can skip it if you prefer. If you are interested in it and want to contribute to its development, you can check out my GitHub repository.
Next, let's load our model. I'm using CLIP here, but any similar image-embedding model will work.
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification  # or you can use CLIPProcessor, CLIPModel
device = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when one is available
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
model = AutoModelForZeroShotImageClassification.from_pretrained("openai/clip-vit-large-patch14").to(device)
Now for the most important part: we embed every image in our dataset and store the resulting vectors in a new column called embeddings.
Use a GPU or another accelerator for this step if you can; embedding a whole dataset on a CPU is slow.
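One way to implement this step is with Dataset.map in batched mode. The sketch below assumes the images live in a column named image (the retrieval code later reads retrieved_examples["image"]). The hypothetical make_embed_fn wraps any function that maps a list of images to an array of vectors, and returns a function in the shape Dataset.map(batched=True) expects: it receives a batch as a dict of lists and returns a new embeddings column. The embed_images function shown in the docstring is also an assumption, not the author's code.

```python
from typing import Callable, Dict, List, Sequence

import numpy as np


def make_embed_fn(
    embed_images: Callable[[Sequence], np.ndarray],
) -> Callable[[Dict[str, list]], Dict[str, List[List[float]]]]:
    """Wrap an image-embedding callable so it can be passed to Dataset.map(batched=True).

    `embed_images` (hypothetical) takes a list of images and returns an array of
    shape (batch_size, dim). With the CLIP model loaded above it could look like:

        def embed_images(images):
            inputs = processor(images=images, return_tensors="pt").to(device)
            with torch.no_grad():
                return model.get_image_features(**inputs).cpu().numpy()
    """

    def embed_batch(batch: Dict[str, list]) -> Dict[str, List[List[float]]]:
        # With batched=True, each column arrives as a list of values.
        vectors = np.asarray(embed_images(batch["image"]), dtype=np.float32)
        # Returning a new key adds an "embeddings" column to the dataset.
        return {"embeddings": vectors.tolist()}

    return embed_batch


# Usage with Hugging Face datasets (assumes an "image" column):
# dataset = dataset.map(make_embed_fn(embed_images), batched=True, batch_size=32)
```

Keeping the model call behind a plain callable makes it easy to swap CLIP for another embedding model without touching the batching code.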
Once the embeddings are computed, store them somewhere persistent, whether locally, on the Hugging Face Hub, or in a vector database such as Pinecone or Chroma, so you don't have to embed the dataset again.
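If you keep things local, a minimal sketch of the caching idea is to save the embedding matrix to a .npy file once and reload it on later runs. The helper load_or_compute_embeddings and its compute argument are hypothetical; with Hugging Face datasets you could instead call dataset.save_to_disk or dataset.push_to_hub on the embedded dataset.

```python
import os
from typing import Callable

import numpy as np


def load_or_compute_embeddings(path: str, compute: Callable[[], np.ndarray]) -> np.ndarray:
    """Load cached embeddings from `path` if they exist; otherwise compute, save and return them."""
    # np.save appends ".npy" when it is missing, so normalise the path up front.
    if not path.endswith(".npy"):
        path += ".npy"
    if os.path.exists(path):
        return np.load(path)
    embeddings = np.asarray(compute(), dtype=np.float32)
    np.save(path, embeddings)
    return embeddings


# Hypothetical usage: compute is any zero-argument function returning an (n_images, dim) array.
# embeddings = load_or_compute_embeddings("pokemon_embeddings.npy", compute=my_embedding_function)
```

The expensive compute function only runs on the first call; every later call just reads the file.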
Once your dataset is embedded and stored, we can move on to the retrieval logic.
First, load the same model that was used in the previous section: queries must be embedded with the same model as the dataset, otherwise the vectors are not comparable.
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification  # or you can use CLIPProcessor, CLIPModel
device = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when one is available
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
model = AutoModelForZeroShotImageClassification.from_pretrained("openai/clip-vit-large-patch14").to(device)
We also need to load our embedded dataset:
from datasets import load_dataset
dataset = load_dataset("not-lain/embedded-pokemon", split="train")
Next, add a Faiss index to the embeddings column so it can be used for similarity search.
Faiss is a library for efficient similarity search and clustering of dense vectors, which makes it a good fit for large-scale image retrieval. Indexing the embeddings column with Faiss gives you fast nearest-neighbor search over the embeddings, so retrieving the images most similar to a query stays quick even when the dataset contains a large number of images.
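For intuition about what the index computes, here is a brute-force NumPy sketch of exact nearest-neighbor search by squared L2 distance, the kind of search a flat Faiss index (IndexFlatL2) performs. As far as I know, add_faiss_index builds such a flat L2 index by default when no string_factory is given; Faiss's advantage is doing this much faster, and with other index types approximately, over millions of vectors. The function name nearest_l2 is hypothetical.

```python
import numpy as np


def nearest_l2(embeddings: np.ndarray, query: np.ndarray, k: int = 4):
    """Exact k-nearest-neighbour search by squared L2 distance.

    Returns (scores, indices) with the closest vector first; lower scores mean more similar.
    """
    # Squared Euclidean distance from the query to every stored vector (broadcast over rows).
    distances = np.sum((embeddings - query) ** 2, axis=1)
    # Ascending sort: the smallest distance (most similar item) comes first.
    indices = np.argsort(distances)[:k]
    return distances[indices], indices
```

This costs one pass over every stored vector per query, which is exactly the work a dedicated index is designed to speed up.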
dataset = dataset.add_faiss_index("embeddings")
To retrieve the most similar images, we first embed the new image with the same model, then ask the index for the dataset entries closest to that embedding.
import numpy as np
import torch
def search(query, k: int = 4):
    """Embed a new image (e.g. a PIL image) and return the k most similar entries from the dataset."""
    pixel_values = processor(images=query, return_tensors="pt")["pixel_values"]  # preprocess the new image
    pixel_values = pixel_values.to(device)
    with torch.no_grad():  # inference only, no gradients needed
        img_emb = model.get_image_features(pixel_values)[0]  # [0] because it's a single image
    img_emb = img_emb.cpu().numpy()  # convert to NumPy because the datasets library does not support torch vectors
    scores, retrieved_examples = dataset.get_nearest_examples(  # retrieve results
        "embeddings", img_emb,  # compare our new embedded image with the dataset embeddings
        k=k,  # get only the top k results
    )
    return retrieved_examples
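One design choice worth considering: CLIP was trained to compare embeddings with cosine similarity, while an L2 index ranks by Euclidean distance, and the search function above uses raw, unnormalized embeddings. If you L2-normalize both the stored embeddings and the query embedding, the two rankings coincide, because for unit vectors a and b, ||a - b||^2 = 2 - 2 cos(a, b). The sketch below (helper name l2_normalize is hypothetical) shows the normalization; you would apply it to the vectors before adding them to the dataset and to img_emb before calling get_nearest_examples.

```python
import numpy as np


def l2_normalize(vectors: np.ndarray) -> np.ndarray:
    """Scale each vector (each row, for a 2-D array) to unit length."""
    vectors = np.asarray(vectors, dtype=np.float32)
    norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
    # Guard against division by zero for all-zero vectors.
    return vectors / np.maximum(norms, 1e-12)


# For unit vectors a and b: ||a - b||^2 = 2 - 2 * cos(a, b), so sorting by
# ascending L2 distance gives the same order as sorting by descending cosine similarity.
# Hypothetical usage inside search(), with the stored embeddings normalized the same way:
# img_emb = l2_normalize(img_emb)
```

Normalization is optional; without it, images whose embeddings have unusually large norms can rank differently than their cosine similarity would suggest.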
Let's test our search function, starting by loading an image:
from loadimg import load_img
image = load_img("https://img.pokemondb.net/artwork/large/charmander.jpg")
image
Then retrieve the most similar entries.
They are sorted by decreasing similarity (increasing distance), so the first entry is the one most similar to our input image.
retrieved_examples = search(image)
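get_nearest_examples returns the top-k results as a dictionary of columns: each key (such as "text" or "image") maps to a list of k values in ranked order, which is why the plotting code indexes retrieved_examples["text"][index]. If you prefer one dict per result, the hypothetical to_rows helper below regroups the columns into rows.

```python
from typing import Any, Dict, List


def to_rows(retrieved_examples: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Turn a dict of columns (column -> list of k values) into a list of k row dicts."""
    columns = list(retrieved_examples)
    # zip(*values) walks the k results in order, most similar first.
    return [dict(zip(columns, values)) for values in zip(*retrieved_examples.values())]


# Hypothetical usage:
# for row in to_rows(retrieved_examples):
#     print(row["text"])
```

Note that each row also carries the embeddings column, since it is part of the dataset.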
Let's visualize our results:
import matplotlib.pyplot as plt
f, axarr = plt.subplots(2, 2)
for index in range(4):
    i, j = index // 2, index % 2  # row and column in the 2x2 grid
    axarr[i, j].set_title(retrieved_examples["text"][index])
    axarr[i, j].imshow(retrieved_examples["image"][index])
    axarr[i, j].axis("off")
plt.show()
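The plotting loop above hard-codes a 2x2 grid for four results. If you change k in search, a small helper like the hypothetical grid_positions below can compute the grid shape and each result's cell. Passing squeeze=False to plt.subplots keeps axarr two-dimensional even when there is only one row, so axarr[i, j] indexing keeps working.

```python
import math
from typing import List, Tuple


def grid_positions(k: int, ncols: int = 2) -> Tuple[Tuple[int, int], List[Tuple[int, int]]]:
    """Return the (nrows, ncols) grid shape and the (row, col) cell for each of k results."""
    nrows = math.ceil(k / ncols)
    # divmod gives the same (index // ncols, index % ncols) pair used in the loop above.
    cells = [divmod(index, ncols) for index in range(k)]
    return (nrows, ncols), cells


# Hypothetical usage with matplotlib:
# (nrows, ncols), cells = grid_positions(len(retrieved_examples["text"]))
# f, axarr = plt.subplots(nrows, ncols, squeeze=False)
# for index, (i, j) in enumerate(cells):
#     axarr[i, j].set_title(retrieved_examples["text"][index])
```

Any leftover cells in the last row stay empty; you can hide them with axis("off") as well.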