# image_captioner / vector_search.py
# Scraped file-page header: uploaded by Sverd ("upload from local pc"),
# commit 1352a28 (verified), 1.54 kB, security scan: "No virus".
import cohere
from annoy import AnnoyIndex
import numpy as np
import dotenv
import os
import pandas as pd
# Module setup: embed every topic in aicovers_topics.csv with Cohere and build an
# Annoy index over the embeddings. All of this runs at import time: it loads a
# .env file, reads the COHERE_API_KEY environment variable, calls the Cohere
# embed API and writes the built index to 'test.ann'.
dotenv.load_dotenv()
model_name = "embed-english-v3.0"
api_key = os.environ['COHERE_API_KEY']
input_type_embed = "search_document"
# Cohere's embed endpoint accepts at most 96 texts per request, so the topic
# list is embedded in batches of this size instead of in one call.
embed_batch_size = 96
# Set up the cohere client
co = cohere.Client(api_key)
# Get the dataset of topics
topics = pd.read_csv("aicovers_topics.csv")
topic_texts = list(topics['topic_cleaned'])
# Get the embeddings batch by batch. Order is preserved, so list_embeds[i]
# belongs to row i of `topics`; topic_from_caption relies on this when it maps
# an Annoy item id back to a row with .iloc.
list_embeds = []
for start in range(0, len(topic_texts), embed_batch_size):
    batch = topic_texts[start:start + embed_batch_size]
    list_embeds.extend(
        co.embed(texts=batch, model=model_name, input_type=input_type_embed).embeddings
    )
if not list_embeds:
    raise ValueError("aicovers_topics.csv contains no topics to index")
# Create the search index; its dimension is the length of one embedding vector.
search_index = AnnoyIndex(len(list_embeds[0]), metric='angular')
# Add vectors to the search index, using the row position as the Annoy item id.
for i, embed in enumerate(list_embeds):
    search_index.add_item(i, embed)
search_index.build(10)  # 10 trees
search_index.save('test.ann')
def topic_from_caption(caption):
    """
    Return the topic from the uploaded topic list that is semantically closest to a caption.

    The caption is embedded with Cohere (as a search query) and looked up in the
    module-level Annoy index built over ``topics['topic_cleaned']``.

    Args:
    - caption (str): The image caption generated by MS Azure.

    Returns:
    - str: The ``topic_cleaned`` value of the nearest topic.

    Raises:
    - LookupError: If the search index returns no neighbour (e.g. it is empty).
    """
    input_type_query = "search_query"
    # The topics were indexed with input_type "search_document"; the caption is
    # embedded as the matching "search_query" side of the retrieval pair.
    caption_embed = co.embed(texts=[caption], model=model_name, input_type=input_type_query).embeddings[0]
    # Only the id of the single nearest item is needed, so distances are not requested.
    nearest_ids = search_index.get_nns_by_vector(caption_embed, n=1)
    if not nearest_ids:
        raise LookupError("search index returned no topic for the caption")
    # Annoy item ids were assigned as row positions of `topics`, so a positional
    # lookup returns the stored topic text directly. Series.to_string() is a
    # display renderer (it can escape or pad text) and is not used for this.
    return str(topics['topic_cleaned'].iloc[nearest_ids[0]])