"""Map image captions to the closest topic listed in ``aicovers_topics.csv``.

On import this module embeds every topic with Cohere, builds an Annoy
nearest-neighbour index over those embeddings and saves it to ``test.ann``.
``topic_from_caption`` then embeds a caption and returns the nearest topic.
"""
import cohere
from annoy import AnnoyIndex
import numpy as np
import dotenv
import os
import pandas as pd

dotenv.load_dotenv()

model_name = "embed-english-v3.0"
api_key = os.environ['COHERE_API_KEY']
input_type_embed = "search_document"

# Cohere's embed endpoint accepts at most 96 texts per request, so a topic
# list longer than that has to be sent in several batches.
EMBED_BATCH_SIZE = 96

# Set up the cohere client
co = cohere.Client(api_key)


def _embed_texts(texts, input_type, batch_size=EMBED_BATCH_SIZE):
    """Embed ``texts`` with Cohere, sending at most ``batch_size`` texts per request.

    Args:
    - texts (list[str]): Texts to embed, in order.
    - input_type (str): Cohere input type, e.g. "search_document" or "search_query".
    - batch_size (int): Maximum number of texts sent in a single request.

    Returns:
    - list: One embedding per input text, in the same order as ``texts``.
    """
    embeddings = []
    for start in range(0, len(texts), batch_size):
        response = co.embed(
            texts=texts[start:start + batch_size],
            model=model_name,
            input_type=input_type,
        )
        embeddings.extend(response.embeddings)
    return embeddings


# Get the dataset of topics
topics = pd.read_csv("aicovers_topics.csv")

# Embed every topic; the row position i in ``topics`` becomes Annoy item id i,
# which is what lets topic_from_caption map a neighbour id back to a row.
list_embeds = _embed_texts(topics['topic_cleaned'].tolist(), input_type_embed)
if not list_embeds:
    raise ValueError("aicovers_topics.csv contains no topics to index")

# Create the search index, sized to the embedding dimension
search_index = AnnoyIndex(len(list_embeds[0]), metric='angular')

# Add vectors to the search index
for item_id, embedding in enumerate(list_embeds):
    search_index.add_item(item_id, embedding)

search_index.build(10)  # 10 trees
search_index.save('test.ann')


def topic_from_caption(caption):
    """
    Returns the topic from the uploaded list that is semantically closest to the input caption.

    Args:
    - caption (str): The image caption generated by MS Azure.

    Returns:
    - str: The ``topic_cleaned`` value of the nearest topic in the search index.
    """
    input_type_query = "search_query"
    # Queries use the "search_query" input type; the indexed topics were
    # embedded as "search_document".
    caption_embed = co.embed(texts=[caption], model=model_name, input_type=input_type_query).embeddings
    # Annoy ids are the row positions assigned when the index was built.
    nearest_ids = search_index.get_nns_by_vector(caption_embed[0], n=1)
    # Read the cell directly instead of going through Series.to_string(),
    # which applies display formatting (alignment whitespace, depending on
    # the pandas version) that is not part of the topic text.
    return str(topics['topic_cleaned'].iloc[nearest_ids[0]])