Tonic commited on
Commit
0308f3c
1 Parent(s): d24a7c2

refactor vectordb

Browse files
Files changed (1) hide show
  1. app.py +110 -29
app.py CHANGED
@@ -98,40 +98,121 @@ class MyEmbeddingFunction(EmbeddingFunction):
98
  embeddings = [self.embedding_generator.compute_embeddings(doc) for doc in input]
99
  embeddings = [item for sublist in embeddings for item in sublist]
100
  return embeddings
101
-
102
- class DocumentLoader:
103
- def __init__(self, file_path: str, mode: str = "elements"):
104
- self.file_path = file_path
105
- self.mode = mode
106
-
107
- def load_documents(self):
108
- loader = UnstructuredFileLoader(self.file_path, mode=self.mode)
109
- docs = loader.load()
110
- return [doc.page_content for doc in docs]
111
-
112
- class ChromaManager:
113
- def __init__(self, collection_name: str, embedding_function: MyEmbeddingFunction):
114
- self.client = HttpClient(settings=Settings(allow_reset=True))
115
- self.client.reset() # resets the database
116
- self.collection = self.client.create_collection(collection_name)
117
- self.embedding_function = embedding_function
118
-
119
- def add_documents(self, documents: list):
120
- for doc in documents:
121
- self.collection.add(ids=[str(uuid.uuid1())], documents=[doc], embeddings=self.embedding_function([doc]))
122
-
123
- def query(self, query_text: str):
124
- db = Chroma(client=self.client, collection_name=self.collection.name, embedding_function=self.embedding_function)
125
- result_docs = db.similarity_search(query_text)
126
- return result_docs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
-
129
  # Initialize clients
130
  intention_client = OpenAI(api_key=yi_token, base_url=API_BASE)
131
  embedding_generator = EmbeddingGenerator(model_name=model_name, token=hf_token, intention_client=intention_client)
132
  embedding_function = MyEmbeddingFunction(embedding_generator=embedding_generator)
133
- chroma_manager = ChromaManager(collection_name="Tonic-instruct" , embedding_function=embedding_function)
134
-
135
  def respond(
136
  message,
137
  history: list[tuple[str, str]],
 
98
  embeddings = [self.embedding_generator.compute_embeddings(doc) for doc in input]
99
  embeddings = [item for sublist in embeddings for item in sublist]
100
  return embeddings
101
+ # main.py
102
+ import os
103
+ import uuid
104
+ import gradio as gr
105
+ import torch
106
+ import torch.nn.functional as F
107
+ from torch.nn import DataParallel
108
+ from torch import Tensor
109
+ from transformers import AutoTokenizer, AutoModel
110
+ from huggingface_hub import InferenceClient
111
+ from openai import OpenAI
112
+ from langchain_community.document_loaders import UnstructuredFileLoader
113
+ from chromadb import Documents, EmbeddingFunction, Embeddings
114
+ from chromadb.config import Settings
115
+ from chromadb import HttpClient
116
+ from utils import load_env_variables, parse_and_route
117
+ from globalvars import API_BASE, intention_prompt, tasks, system_message, model_name
118
+
119
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:30'
120
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
121
+ os.environ['CUDA_CACHE_DISABLE'] = '1'
122
+
123
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
124
+
125
+ ### Utils
126
+ hf_token, yi_token = load_env_variables()
127
+
128
+ def clear_cuda_cache():
129
+ torch.cuda.empty_cache()
130
+
131
+ client = OpenAI(api_key=yi_token, base_url=API_BASE)
132
+
133
+ class EmbeddingGenerator:
134
+ def __init__(self, model_name: str, token: str, intention_client):
135
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
136
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, trust_remote_code=True)
137
+ self.model = AutoModel.from_pretrained(model_name, token=token, trust_remote_code=True).to(self.device)
138
+ self.intention_client = intention_client
139
+
140
+ def clear_cuda_cache(self):
141
+ torch.cuda.empty_cache()
142
+
143
+ @spaces.GPU
144
+ def compute_embeddings(self, input_text: str):
145
+ # Get the intention
146
+ intention_completion = self.intention_client.chat.completions.create(
147
+ model="yi-large",
148
+ messages=[
149
+ {"role": "system", "content": intention_prompt},
150
+ {"role": "user", "content": input_text}
151
+ ]
152
+ )
153
+ intention_output = intention_completion.choices[0].message['content']
154
+
155
+ # Parse and route the intention
156
+ parsed_task = parse_and_route(intention_output)
157
+ selected_task = list(parsed_task.keys())[0]
158
+
159
+ # Construct the prompt
160
+ try:
161
+ task_description = tasks[selected_task]
162
+ except KeyError:
163
+ print(f"Selected task not found: {selected_task}")
164
+ return f"Error: Task '{selected_task}' not found. Please select a valid task."
165
+
166
+ query_prefix = f"Instruct: {task_description}\nQuery: "
167
+ queries = [input_text]
168
+
169
+ # Get the embeddings
170
+ with torch.no_grad():
171
+ inputs = self.tokenizer(queries, return_tensors='pt', padding=True, truncation=True, max_length=4096).to(self.device)
172
+ outputs = self.model(**inputs)
173
+ query_embeddings = outputs.last_hidden_state.mean(dim=1)
174
+
175
+ # Normalize embeddings
176
+ query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
177
+ embeddings_list = query_embeddings.detach().cpu().numpy().tolist()
178
+ self.clear_cuda_cache()
179
+ return embeddings_list
180
+
181
+ class MyEmbeddingFunction(EmbeddingFunction):
182
+ def __init__(self, embedding_generator: EmbeddingGenerator):
183
+ self.embedding_generator = embedding_generator
184
+
185
+ def __call__(self, input: Documents) -> Embeddings:
186
+ embeddings = [self.embedding_generator.compute_embeddings(doc) for doc in input]
187
+ embeddings = [item for sublist in embeddings for item in sublist]
188
+ return embeddings
189
+
190
+ def load_documents(file_path: str, mode: str = "elements"):
191
+ loader = UnstructuredFileLoader(file_path, mode=mode)
192
+ docs = loader.load()
193
+ return [doc.page_content for doc in docs]
194
+
195
+ def initialize_chroma(collection_name: str, embedding_function: MyEmbeddingFunction):
196
+ client = HttpClient(settings=Settings(allow_reset=True))
197
+ client.reset() # resets the database
198
+ collection = client.create_collection(collection_name)
199
+ return client, collection
200
+
201
+ def add_documents_to_chroma(client, collection, documents: list, embedding_function: MyEmbeddingFunction):
202
+ for doc in documents:
203
+ collection.add(ids=[str(uuid.uuid1())], documents=[doc], embeddings=embedding_function([doc]))
204
+
205
+ def query_chroma(client, collection_name: str, query_text: str, embedding_function: MyEmbeddingFunction):
206
+ db = Chroma(client=client, collection_name=collection_name, embedding_function=embedding_function)
207
+ result_docs = db.similarity_search(query_text)
208
+ return result_docs
209
 
 
210
  # Initialize clients
211
  intention_client = OpenAI(api_key=yi_token, base_url=API_BASE)
212
  embedding_generator = EmbeddingGenerator(model_name=model_name, token=hf_token, intention_client=intention_client)
213
  embedding_function = MyEmbeddingFunction(embedding_generator=embedding_generator)
214
+ chroma_client, chroma_collection = initialize_chroma(collection_name="Tonic-instruct", embedding_function=embedding_function)
215
+
216
  def respond(
217
  message,
218
  history: list[tuple[str, str]],