Yhhxhfh committed
Commit 3e937fb · 1 Parent(s): 09d0127

Update app.py

Files changed (1): app.py +90 -64
app.py CHANGED
Before:

@@ -1,62 +1,102 @@
  from pydantic import BaseModel
  from llama_cpp import Llama
  import os
- import gradio as gr
  from dotenv import load_dotenv
  from fastapi import FastAPI, Request
  from fastapi.responses import StreamingResponse
  import spaces
  import asyncio
  import random
- from io import BytesIO
- import requests

  app = FastAPI()
  load_dotenv()

  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

  class ModelManager:
      def __init__(self):
-         self.params = {
-             "n_ctx": 2048,
-             "n_batch": 512,
-             "n_predict": 512,
-             "repeat_penalty": 1.1,
-             "n_threads": int(os.cpu_count() * 0.75),
-             "seed": -1,
-             "stop": ["</s>"],
-             "tokens": [],
-         }
-         self.unified_model = self.load_unified_model()
-
-     def load_unified_model(self):
-         model_configs = [
-             {
-                 "repo_id": "TheBloke/Llama-2-7B-Chat-GGUF",
-                 "filename": "llama-2-7b-chat.Q4_K_M.gguf",
-             },
-         ]
-
-         models = []
-         for config in model_configs:
-             with BytesIO() as model_data:
-                 download_url = f"https://huggingface.co/{config['repo_id']}/resolve/main/{config['filename']}"
-                 response = requests.get(download_url, headers={"Authorization": f"Bearer {HUGGINGFACE_TOKEN}"}, stream=True)
-
-                 for chunk in response.iter_content(chunk_size=1024*1024):
-                     if chunk:
-                         model_data.write(chunk)
-
-                 model_data.seek(0)
-
-                 model = Llama(model_path="", model_data=model_data.read(), **self.params)
-                 models.append(model)
-
-         self.params["tokens"] = models[0].tokenize(b"Hello")
-
-         self.unified_model = models[0]
-         return self.unified_model

  model_manager = ModelManager()

@@ -65,22 +105,14 @@ class ChatRequest(BaseModel):

  @spaces.GPU()
  async def generate_streaming_response(inputs):
-     top_p = round(random.uniform(0.01, 1.00), 2)
-     top_k = random.randint(1, 100)
-     temperature = round(random.uniform(0.01, 2.00), 2)
-     max_tokens = model_manager.params["n_ctx"] - len(model_manager.unified_model.tokenize(inputs))
-
      async def stream_response():
-         response = await asyncio.to_thread(model_manager.unified_model, inputs, top_p=top_p, top_k=top_k, temperature=temperature, max_tokens=max_tokens, **model_manager.params)
-         full_text = response['choices'][0]['text']
-
-         if len(full_text) > max_tokens:
-             chunks = [full_text[i:i + max_tokens] for i in range(0, len(full_text), max_tokens)]
-             for chunk in chunks:
-                 yield chunk
-         else:
-             yield full_text
-
      return StreamingResponse(stream_response())

  async def process_message(message):
@@ -93,13 +125,7 @@ async def api_generate_multimodel(request: Request):
      message = data["message"]
      return await process_message(message)

- iface = gr.Interface(
-     fn=process_message,
-     inputs=gr.Textbox(lines=2, placeholder="Enter your message here..."),
-     outputs=gr.Markdown(stream=True),
-     title="Unified Multi-Model API",
-     description="Enter a message to get responses from a unified model."
- )

  if __name__ == "__main__":
      iface.launch()

After:

  from pydantic import BaseModel
  from llama_cpp import Llama
  import os
+ import gradio as gr  # Not suitable for production
  from dotenv import load_dotenv
  from fastapi import FastAPI, Request
  from fastapi.responses import StreamingResponse
  import spaces
  import asyncio
  import random
+ from llama_cpp.tokenizers import LlamaTokenizer
+ from peft import PeftModel, LoraConfig, get_peft_model
+ import torch
+ from multiprocessing import Process, Queue
+ from google.cloud import storage
+ import json

  app = FastAPI()
  load_dotenv()

  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
+ GOOGLE_CLOUD_BUCKET = os.getenv("GOOGLE_CLOUD_BUCKET")
+ GOOGLE_CLOUD_CREDENTIALS = os.getenv("GOOGLE_CLOUD_CREDENTIALS")
+
+ gcp_credentials = json.loads(GOOGLE_CLOUD_CREDENTIALS)
+ storage_client = storage.Client.from_service_account_info(gcp_credentials)
+ bucket = storage_client.bucket(GOOGLE_CLOUD_BUCKET)

  class ModelManager:
      def __init__(self):
+         self.params = {"n_ctx": 2048, "n_batch": 512, "n_predict": 512, "repeat_penalty": 1.1, "n_threads": 1, "seed": -1, "stop": ["</s>"], "tokens": []}
+         self.tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")  # Load tokenizer from GCS for production
+         self.request_queue = Queue()
+         self.response_queue = Queue()
+         self.model = self.load_model_from_bucket("llama-2-7b-chat/llama-2-7b-chat.Q4_K_M.gguf")
+         self.start_processing_processes()
+
+     def load_model_from_bucket(self, bucket_path):
+         blob = bucket.blob(bucket_path)
+         try:
+             model = Llama(model_path=blob.download_as_string(), **self.params)
+             return model
+         except Exception as e:
+             print(f"Error loading model: {e}")
+             return None
+
+     def save_model_to_bucket(self, model, bucket_path):
+         blob = bucket.blob(bucket_path)
+         try:
+             blob.upload_from_string(model.save_pretrained(), content_type='application/octet-stream')
+         except Exception as e:
+             print(f"Error saving model: {e}")
+
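Note on load_model_from_bucket: llama-cpp-python's Llama() expects model_path to be a path to a GGUF file on disk, while blob.download_as_string() returns the blob's raw bytes, so the constructor call above would not load a model as written. A minimal sketch of one way to bridge the two, downloading the blob to a local file first (load_gguf_from_gcs is a hypothetical helper, not part of this commit):

    import os
    import tempfile
    from llama_cpp import Llama

    def load_gguf_from_gcs(bucket, bucket_path, **llama_kwargs):
        # Hypothetical helper: fetch the GGUF blob to local disk, then load it by path.
        local_path = os.path.join(tempfile.gettempdir(), os.path.basename(bucket_path))
        if not os.path.exists(local_path):  # crude cache so restarts reuse the download
            bucket.blob(bucket_path).download_to_filename(local_path)
        # llama.cpp memory-maps the weights, so the file must stay on disk while the
        # Llama instance is alive; pass only constructor options here (n_ctx, n_threads, ...).
        return Llama(model_path=local_path, **llama_kwargs)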
+     def train_model(self):  # This function needs a complete overhaul for production use. This is a placeholder.
+         config = LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
+         base_model_path = "llama-2-7b-chat/llama-2-7b-chat.Q4_K_M.gguf"
+         try:
+             base_model = self.load_model_from_bucket(base_model_path)
+             if base_model:
+                 model = get_peft_model(base_model, config)
+                 # Placeholder training data - needs a robust data loading mechanism
+                 for batch in [{"question": ["a"], "answer": ["b"]}, {"question": ["c"], "answer": ["d"]}]:
+                     inputs = self.tokenizer(batch["question"], return_tensors="pt", padding=True, truncation=True)
+                     labels = self.tokenizer(batch["answer"], return_tensors="pt", padding=True, truncation=True)
+                     outputs = model(**inputs, labels=labels.input_ids)
+                     loss = outputs.loss
+                     loss.backward()
+                 self.save_model_to_bucket(model, "llama_finetuned/llama_finetuned.gguf")
+                 del model
+                 del base_model
+         except Exception as e:
+             print(f"Error during training: {e}")
+
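Note on train_model: peft's get_peft_model() wraps a PyTorch transformers model, so it cannot be applied to a llama_cpp.Llama object loaded from a GGUF file, and the loop above never steps an optimizer. A rough sketch of the usual LoRA setup with transformers + peft, assuming a Hugging Face-format checkpoint; the checkpoint id and the training text below are placeholders, not taken from this commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import LoraConfig, get_peft_model

    base_id = "meta-llama/Llama-2-7b-chat-hf"  # placeholder checkpoint id (gated on the Hub)
    tokenizer = AutoTokenizer.from_pretrained(base_id)
    base_model = AutoModelForCausalLM.from_pretrained(base_id)

    lora = LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"],
                      lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
    model = get_peft_model(base_model, lora)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    batch = tokenizer(["Question: a\nAnswer: b"], return_tensors="pt")
    outputs = model(**batch, labels=batch["input_ids"])  # causal-LM loss over the same tokens
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    model.save_pretrained("llama_finetuned")  # writes only the LoRA adapter weights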
+     def generate_text(self, prompt):
+         if self.model:
+             inputs = self.tokenizer(prompt, return_tensors="pt")
+             outputs = self.model.generate(**inputs, max_new_tokens=100)
+             generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+             return generated_text
+         else:
+             return "Error loading model."
+
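Note on generate_text: a llama_cpp.Llama instance is not driven like a transformers model (generate(**inputs) with tensors); it is called with a prompt string and sampling parameters, as the removed generate_streaming_response did. A standalone sketch of that completion call, with a placeholder local model path:

    from llama_cpp import Llama

    llm = Llama(model_path="llama-2-7b-chat.Q4_K_M.gguf", n_ctx=2048)  # placeholder local path

    result = llm(
        "Q: What is the capital of France? A:",
        max_tokens=100,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        stop=["</s>"],
    )
    print(result["choices"][0]["text"])  # OpenAI-style completion dict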
+     def start_processing_processes(self):
+         p = Process(target=self.process_requests)
+         p.start()
+
+     def process_requests(self):
+         while True:
+             request_data = self.request_queue.get()
+             if request_data is None:
+                 break
+             inputs, top_p, top_k, temperature, max_tokens = request_data
+             try:
+                 response = self.generate_text(inputs)
+                 self.response_queue.put(response)
+             except Exception as e:
+                 print(f"Error during inference: {e}")
+                 self.response_queue.put("Error generating text.")
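The request/response queue pattern used by process_requests, shown in isolation: the worker exits when it reads a None sentinel, and the caller blocks on the response queue. The worker function and its upper-casing payload below are stand-ins for illustration, not part of app.py:

    from multiprocessing import Process, Queue

    def worker(requests, responses):
        while True:
            item = requests.get()
            if item is None:              # sentinel: shut the worker down
                break
            responses.put(item.upper())   # stand-in for generate_text()

    if __name__ == "__main__":
        requests, responses = Queue(), Queue()
        p = Process(target=worker, args=(requests, responses), daemon=True)
        p.start()
        requests.put("hello")
        print(responses.get())            # -> "HELLO"
        requests.put(None)                # ask the worker to exit
        p.join()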
 
  model_manager = ModelManager()

  ... (unchanged lines 103-104 not shown)

  @spaces.GPU()
  async def generate_streaming_response(inputs):
+     top_p = 0.9
+     top_k = 50
+     temperature = 0.7
+     max_tokens = model_manager.params["n_ctx"] - len(model_manager.tokenizer.encode(inputs))
+     model_manager.request_queue.put((inputs, top_p, top_k, temperature, max_tokens))
+     full_text = model_manager.response_queue.get()
      async def stream_response():
+         yield full_text
      return StreamingResponse(stream_response())
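Note on stream_response: the generator above yields the already-complete text in a single chunk. llama-cpp-python can instead stream partial completions when called with stream=True, and StreamingResponse can forward them as they are produced. A standalone sketch; the model path and route are placeholders, not part of this commit:

    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse
    from llama_cpp import Llama

    app = FastAPI()
    llm = Llama(model_path="llama-2-7b-chat.Q4_K_M.gguf")  # placeholder local path

    @app.get("/stream")
    def stream(prompt: str):
        def token_stream():
            # stream=True yields OpenAI-style partial completion dicts
            for chunk in llm(prompt, max_tokens=256, stream=True):
                yield chunk["choices"][0]["text"]
        return StreamingResponse(token_stream(), media_type="text/plain")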
 
  async def process_message(message):
  ... (unchanged lines 119-124 not shown)
      message = data["message"]
      return await process_message(message)

+ iface = gr.Interface(fn=process_message, inputs=gr.Textbox(lines=2, placeholder="Enter your message here..."), outputs=gr.Markdown(stream=True), title="Unified Multi-Model API", description="Enter a message to get responses from a unified model.")  # gradio is not suitable for production

  if __name__ == "__main__":
      iface.launch()
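Both versions still launch the Gradio interface from __main__ even though the commit's own comments flag Gradio as unsuitable for production; the FastAPI app can be served directly with uvicorn instead. A sketch, assuming this file is named app.py (so the module path is "app:app") and that port 7860 is wanted, the port a Hugging Face Space expects:

    import uvicorn

    if __name__ == "__main__":
        uvicorn.run("app:app", host="0.0.0.0", port=7860)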