File size: 1,535 Bytes
1352a28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import cohere
from annoy import AnnoyIndex
import numpy as np
import dotenv
import os
import pandas as pd

dotenv.load_dotenv()

model_name = "embed-english-v3.0"
api_key = os.environ['COHERE_API_KEY']
input_type_embed = "search_document"

# Cohere's Embed endpoint accepts at most 96 texts per request, so the topics
# are embedded in batches of this size instead of in a single call.
_EMBED_BATCH_SIZE = 96

# Set up the cohere client
co = cohere.Client(api_key)

# Get the dataset of topics
topics = pd.read_csv("aicovers_topics.csv")
topic_texts = topics['topic_cleaned'].tolist()
if not topic_texts:
    # Without at least one embedding there is no vector length to size the
    # Annoy index with; fail with a clear message instead of an IndexError.
    raise ValueError("aicovers_topics.csv contains no topics to index")

# Get the embeddings batch by batch, preserving row order so that Annoy item i
# corresponds to row i of ``topics`` (topic_from_caption relies on this).
list_embeds = []
for start in range(0, len(topic_texts), _EMBED_BATCH_SIZE):
    batch = topic_texts[start:start + _EMBED_BATCH_SIZE]
    list_embeds.extend(
        co.embed(texts=batch, model=model_name, input_type=input_type_embed).embeddings
    )

# Create the search index; Annoy needs the vector length, taken from the first
# embedding rather than by copying every embedding into a numpy array.
search_index = AnnoyIndex(len(list_embeds[0]), metric='angular')

# Add vectors to the search index, using the row position as the item id.
for item_id, vector in enumerate(list_embeds):
    search_index.add_item(item_id, vector)
search_index.build(10)  # 10 trees
search_index.save('test.ann')


def topic_from_caption(caption):
    """
    Return the indexed topic that is semantically closest to an image caption.

    The caption is embedded with the same Cohere model used for the topics,
    but with the ``"search_query"`` input type. Its single nearest neighbour is
    then looked up in the module-level Annoy index built from ``topics``.

    Args:
    - caption (str): The image caption generated by MS Azure.

    Returns:
    - str: The ``topic_cleaned`` value of the nearest topic.

    Raises:
    - LookupError: If the search index returns no neighbour (e.g. it is empty).
    """
    input_type_query = "search_query"
    # Embed the caption; the response holds one vector per input text.
    caption_embed = co.embed(texts=[caption], model=model_name, input_type=input_type_query).embeddings
    # Annoy item ids were assigned as row positions of ``topics`` when the index
    # was built, so they can be used directly with ``iloc``. Distances are not
    # needed, so they are not requested.
    nearest_ids = search_index.get_nns_by_vector(caption_embed[0], n=1)
    if not nearest_ids:
        raise LookupError("search index returned no topic for the caption")
    # Read the cell directly instead of rendering a one-row Series with
    # ``to_string``, whose padding and empty-result output are not a topic.
    return str(topics['topic_cleaned'].iloc[nearest_ids[0]])