Tonic committed on
Commit
8f1cff4
1 Parent(s): 556479b

remove duplicate code

Files changed (1)
  1. app.py +0 -100
app.py CHANGED
@@ -1,103 +1,3 @@
- # main.py
- import spaces
- import torch
- import torch.nn.functional as F
- from torch.nn import DataParallel
- from torch import Tensor
- from transformers import AutoTokenizer, AutoModel
- import threading
- import queue
- import os
- import json
- import numpy as np
- import gradio as gr
- from huggingface_hub import InferenceClient
- import openai
- from openai import OpenAI
- from globalvars import API_BASE, intention_prompt, tasks , system_message, model_name
- from dotenv import load_dotenv
- import re
- from utils import load_env_variables
- import chromadb
- from chromadb import Documents, EmbeddingFunction, Embeddings
- from chromadb.config import Settings
- from chromadb import HttpClient
- from langchain_community.document_loaders import UnstructuredFileLoader
- from utils import load_env_variables , parse_and_route
-
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:30'
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
- os.environ['CUDA_CACHE_DISABLE'] = '1'
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- hf_token, yi_token = load_env_variables()
-
- def clear_cuda_cache():
-     torch.cuda.empty_cache()
-
- client = OpenAI(
-     api_key=yi_token,
-     base_url=API_BASE
- )
-
-
- class EmbeddingGenerator:
-     def __init__(self, model_name: str, token: str, intention_client):
-         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, trust_remote_code=True)
-         self.model = AutoModel.from_pretrained(model_name, token=token, trust_remote_code=True).to(self.device)
-         self.intention_client = intention_client
-
-     def clear_cuda_cache(self):
-         torch.cuda.empty_cache()
-
-     @spaces.GPU
-     def compute_embeddings(self, input_text: str):
-         # Get the intention
-         intention_completion = self.intention_client.chat.completions.create(
-             model="yi-large",
-             messages=[
-                 {"role": "system", "content": intention_prompt},
-                 {"role": "user", "content": input_text}
-             ]
-         )
-         intention_output = intention_completion.choices[0].message['content']
-
-         # Parse and route the intention
-         parsed_task = parse_and_route(intention_output)
-         selected_task = list(parsed_task.keys())[0]
-
-         # Construct the prompt
-         try:
-             task_description = tasks[selected_task]
-         except KeyError:
-             print(f"Selected task not found: {selected_task}")
-             return f"Error: Task '{selected_task}' not found. Please select a valid task."
-
-         query_prefix = f"Instruct: {task_description}\nQuery: "
-         queries = [input_text]
-
-         # Get the embeddings
-         with torch.no_grad():
-             inputs = self.tokenizer(queries, return_tensors='pt', padding=True, truncation=True, max_length=4096).to(self.device)
-             outputs = self.model(**inputs)
-             query_embeddings = outputs.last_hidden_state.mean(dim=1)
-
-         # Normalize embeddings
-         query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
-         embeddings_list = query_embeddings.detach().cpu().numpy().tolist()
-         self.clear_cuda_cache()
-         return embeddings_list
-
-
- class MyEmbeddingFunction(EmbeddingFunction):
-     def __init__(self, embedding_generator: EmbeddingGenerator):
-         self.embedding_generator = embedding_generator
-
-     def __call__(self, input: Documents) -> Embeddings:
-         embeddings = [self.embedding_generator.compute_embeddings(doc) for doc in input]
-         embeddings = [item for sublist in embeddings for item in sublist]
-         return embeddings
  # main.py
  import os
  import uuid
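
Note: per the commit message, the block removed above appears to duplicate definitions retained further down in app.py. For context, a minimal sketch of how such a chromadb EmbeddingFunction is typically registered with a collection is shown below; the client setup, collection name, and sample documents are illustrative assumptions, not part of this commit.

    # Sketch only: assumes EmbeddingGenerator, MyEmbeddingFunction, model_name,
    # hf_token, and the OpenAI `client` from the copy retained in app.py.
    import chromadb

    embedding_generator = EmbeddingGenerator(model_name=model_name, token=hf_token, intention_client=client)
    embedding_function = MyEmbeddingFunction(embedding_generator)

    chroma_client = chromadb.Client()  # in-memory client; the app may use chromadb.HttpClient instead
    collection = chroma_client.get_or_create_collection(
        name="demo_docs",                       # hypothetical collection name
        embedding_function=embedding_function,  # chromadb calls this to embed documents and queries
    )
    collection.add(documents=["example document"], ids=["doc-1"])
    results = collection.query(query_texts=["example query"], n_results=1)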