Yhhxhfh committed
Commit e2294f9
Parent: bcc2214

Update app.py

Files changed (1):
  app.py +37 -14
app.py CHANGED
@@ -26,13 +26,25 @@ gcp_credentials = json.loads(GOOGLE_CLOUD_CREDENTIALS)
 storage_client = storage.Client.from_service_account_info(gcp_credentials)
 bucket = storage_client.bucket(GOOGLE_CLOUD_BUCKET)
 
+MODEL_NAMES = {
+    "starcoder": "starcoder2-3b-q2_k.gguf",
+    "gemma_2b_it": "gemma-2-2b-it-q2_k.gguf",
+    "llama_3_2_1b": "Llama-3.2-1B.Q2_K.gguf",
+    "gemma_2b_imat": "gemma-2-2b-iq1_s-imat.gguf",
+    "phi_3_mini": "phi-3-mini-128k-instruct-iq2_xxs-imat.gguf",
+    "qwen2_0_5b": "qwen2-0.5b-iq1_s-imat.gguf",
+    "gemma_9b_it": "gemma-2-9b-it-q2_k.gguf",
+    "gpt2_xl": "gpt2-xl-q2_k.gguf",
+}
+
 class ModelManager:
     def __init__(self):
         self.params = {"n_ctx": 2048, "n_batch": 512, "n_predict": 512, "repeat_penalty": 1.1, "n_threads": 1, "seed": -1, "stop": ["</s>"], "tokens": []}
-        # self.tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf") #Load tokenizer from GCS for production
+        # self.tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf") # Load from GCS for production
         self.request_queue = Queue()
         self.response_queue = Queue()
-        self.model = self.load_model_from_bucket("llama-2-7b-chat/llama-2-7b-chat.Q4_K_M.gguf")
+        self.models = {}  # Dictionary to hold multiple models
+        self.load_models()
         self.start_processing_processes()
 
     def load_model_from_bucket(self, bucket_path):
@@ -44,6 +56,12 @@ class ModelManager:
             print(f"Error loading model: {e}")
             return None
 
+    def load_models(self):
+        for name, path in MODEL_NAMES.items():
+            model = self.load_model_from_bucket(path)
+            if model:
+                self.models[name] = model
+
     def save_model_to_bucket(self, model, bucket_path):
         blob = bucket.blob(bucket_path)
         try:
@@ -72,14 +90,15 @@ class ModelManager:
             print(f"Error during training: {e}")
 
 
-    def generate_text(self, prompt):
-        if self.model:
+    def generate_text(self, prompt, model_name):
+        if model_name in self.models:
+            model = self.models[model_name]
             inputs = self.tokenizer(prompt, return_tensors="pt")
-            outputs = self.model.generate(**inputs, max_new_tokens=100)
+            outputs = model.generate(**inputs, max_new_tokens=100)
             generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
             return generated_text
         else:
-            return "Error loading model."
+            return "Error: Model not found."
 
     def start_processing_processes(self):
         p = Process(target=self.process_requests)
@@ -90,9 +109,9 @@ class ModelManager:
             request_data = self.request_queue.get()
             if request_data is None:
                 break
-            inputs, top_p, top_k, temperature, max_tokens = request_data
+            inputs, model_name, top_p, top_k, temperature, max_tokens = request_data
             try:
-                response = self.generate_text(inputs)
+                response = self.generate_text(inputs, model_name)
                 self.response_queue.put(response)
             except Exception as e:
                 print(f"Error during inference: {e}")
@@ -102,30 +121,34 @@ model_manager = ModelManager()
 
 class ChatRequest(BaseModel):
     message: str
+    model_name: str
 
 @spaces.GPU()
-async def generate_streaming_response(inputs):
+async def generate_streaming_response(inputs, model_name):
     top_p = 0.9
     top_k = 50
     temperature = 0.7
     max_tokens = model_manager.params["n_ctx"] - len(model_manager.tokenizer.encode(inputs))
-    model_manager.request_queue.put((inputs, top_p, top_k, temperature, max_tokens))
+    model_manager.request_queue.put((inputs, model_name, top_p, top_k, temperature, max_tokens))
    full_text = model_manager.response_queue.get()
     async def stream_response():
         yield full_text
     return StreamingResponse(stream_response())
 
-async def process_message(message):
+async def process_message(message, model_name):
     inputs = message.strip()
-    return await generate_streaming_response(inputs)
+    return await generate_streaming_response(inputs, model_name)
 
 @app.post("/generate_multimodel")
 async def api_generate_multimodel(request: Request):
     data = await request.json()
     message = data["message"]
-    return await process_message(message)
+    model_name = data.get("model_name", list(MODEL_NAMES.keys())[0])
+    if model_name not in MODEL_NAMES:
+        return {"error": "Invalid model name"}
+    return await process_message(message, model_name)
 
-iface = gr.Interface(fn=process_message, inputs=gr.Textbox(lines=2, placeholder="Enter your message here..."), outputs=gr.Markdown(stream=True), title="Unified Multi-Model API", description="Enter a message to get responses from a unified model.") #gradio is not suitable for production
+iface = gr.Interface(fn=process_message, inputs=[gr.Textbox(lines=2, placeholder="Enter your message here..."), gr.Dropdown(list(MODEL_NAMES.keys()), label="Select Model")], outputs=gr.Markdown(stream=True), title="Unified Multi-Model API", description="Enter a message to get responses from a unified model.") #gradio is not suitable for production
 
 if __name__ == "__main__":
     iface.launch()
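
The diff does not show the body of load_model_from_bucket, only its error path. The sketch below is one plausible shape for it, assuming the .gguf files are downloaded from the GCS bucket to local disk and opened with llama-cpp-python; it is illustrative only, not the committed implementation, and names such as local_path are hypothetical.

import os
import tempfile
from llama_cpp import Llama  # assumption: GGUF files are loaded via llama-cpp-python

def load_model_from_bucket(self, bucket_path):
    # Illustrative sketch, not the committed code.
    try:
        blob = bucket.blob(bucket_path)
        # Hypothetical temp location for the downloaded weights.
        local_path = os.path.join(tempfile.gettempdir(), os.path.basename(bucket_path))
        blob.download_to_filename(local_path)  # google-cloud-storage download
        return Llama(
            model_path=local_path,
            n_ctx=self.params["n_ctx"],
            n_batch=self.params["n_batch"],
            n_threads=self.params["n_threads"],
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        return None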
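
For reference, a minimal client call against the updated /generate_multimodel endpoint might look like the sketch below. The base URL is an assumption (localhost on Gradio's default port 7860; adjust for the real deployment). Per the handler in this commit, "model_name" must be one of the MODEL_NAMES keys, omitting it falls back to the first key, and an unknown name returns {"error": "Invalid model name"}.

import requests

# Hypothetical base URL for illustration only; replace with the actual deployment address.
BASE_URL = "http://localhost:7860"

payload = {
    "message": "Write a haiku about object storage.",
    "model_name": "gemma_2b_it",  # any key from MODEL_NAMES; omit to use the first key
}

# POST to the endpoint added in this commit and read the (streamed) body as text.
resp = requests.post(f"{BASE_URL}/generate_multimodel", json=payload)
print(resp.status_code)
print(resp.text)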