from pydantic import BaseModel
from llama_cpp import Llama
import os
import tempfile  # GGUF blobs are staged on local disk before llama_cpp can open them
import gradio as gr # Not suitable for production
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import spaces
import asyncio
import random
#from llama_cpp.tokenizers import LlamaTokenizer
from peft import PeftModel, LoraConfig, get_peft_model
import torch
from multiprocessing import Process, Queue
from google.cloud import storage
import json
app = FastAPI()
load_dotenv()
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
GOOGLE_CLOUD_BUCKET = os.getenv("GOOGLE_CLOUD_BUCKET")
GOOGLE_CLOUD_CREDENTIALS = os.getenv("GOOGLE_CLOUD_CREDENTIALS")
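# GOOGLE_CLOUD_CREDENTIALS is assumed to hold the service-account JSON itself (not a path
# to a key file); it is parsed here and passed to the storage client directly.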
gcp_credentials = json.loads(GOOGLE_CLOUD_CREDENTIALS)
storage_client = storage.Client.from_service_account_info(gcp_credentials)
bucket = storage_client.bucket(GOOGLE_CLOUD_BUCKET)
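
# Friendly model keys mapped to the GGUF object names expected in the GCS bucket; all of
# these are small, heavily quantized builds (Q2_K / IQ1_S / IQ2_XXS) to keep memory low.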
MODEL_NAMES = {
    "starcoder": "starcoder2-3b-q2_k.gguf",
    "gemma_2b_it": "gemma-2-2b-it-q2_k.gguf",
    "llama_3_2_1b": "Llama-3.2-1B.Q2_K.gguf",
    "gemma_2b_imat": "gemma-2-2b-iq1_s-imat.gguf",
    "phi_3_mini": "phi-3-mini-128k-instruct-iq2_xxs-imat.gguf",
    "qwen2_0_5b": "qwen2-0.5b-iq1_s-imat.gguf",
    "gemma_9b_it": "gemma-2-9b-it-q2_k.gguf",
    "gpt2_xl": "gpt2-xl-q2_k.gguf",
}

class ModelManager:
    def __init__(self):
        # Parameters passed to the Llama() constructor when a model is loaded.
        self.params = {"n_ctx": 2048, "n_batch": 512, "n_threads": 1, "seed": -1}
        # Generation-time defaults; these are call arguments, not constructor arguments.
        self.gen_params = {"max_tokens": 512, "repeat_penalty": 1.1, "stop": ["</s>"]}
        # self.tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")  # Load from GCS for production
        self.request_queue = Queue()
        self.response_queue = Queue()
        self.models = {}  # Dictionary to hold multiple models
        self.load_models()
        self.start_processing_processes()

    def load_model_from_bucket(self, bucket_path):
        # llama_cpp needs a file on disk, so stage the GGUF blob in a temporary file first.
        blob = bucket.blob(bucket_path)
        try:
            with tempfile.NamedTemporaryFile(suffix=".gguf", delete=False) as tmp:
                blob.download_to_file(tmp)
                local_path = tmp.name
            model = Llama(model_path=local_path, **self.params)
            return model
        except Exception as e:
            print(f"Error loading model {bucket_path}: {e}")
            return None

    def load_models(self):
        for name, path in MODEL_NAMES.items():
            model = self.load_model_from_bucket(path)
            if model is not None:
                self.models[name] = model

    def save_model_to_bucket(self, model, bucket_path):
        # save_pretrained() writes files to a directory and returns None, so serialize to a
        # temporary directory first and upload the resulting files individually.
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                model.save_pretrained(tmp_dir)
                for file_name in os.listdir(tmp_dir):
                    file_blob = bucket.blob(f"{bucket_path}/{file_name}")
                    file_blob.upload_from_filename(os.path.join(tmp_dir, file_name))
        except Exception as e:
            print(f"Error saving model: {e}")

    def train_model(self):  # Placeholder: this needs a complete overhaul for production use.
        # NOTE: LoRA fine-tuning requires a transformers model and tokenizer; a GGUF model
        # loaded through llama_cpp (as load_model_from_bucket returns) cannot be wrapped by
        # get_peft_model, and self.tokenizer is currently commented out in __init__.
        config = LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
        base_model_path = "llama-2-7b-chat/llama-2-7b-chat.Q4_K_M.gguf"
        try:
            base_model = self.load_model_from_bucket(base_model_path)
            if base_model is not None:
                model = get_peft_model(base_model, config)
                # Placeholder training data - needs a robust data loading mechanism.
                for batch in [{"question": ["a"], "answer": ["b"]}, {"question": ["c"], "answer": ["d"]}]:
                    inputs = self.tokenizer(batch["question"], return_tensors="pt", padding=True, truncation=True)
                    labels = self.tokenizer(batch["answer"], return_tensors="pt", padding=True, truncation=True)
                    outputs = model(**inputs, labels=labels.input_ids)
                    loss = outputs.loss
                    loss.backward()
                    # No optimizer step here - the loop is intentionally a skeleton.
                self.save_model_to_bucket(model, "llama_finetuned/llama_finetuned.gguf")
                del model
                del base_model
        except Exception as e:
            print(f"Error during training: {e}")

    def generate_text(self, prompt, model_name, **overrides):
        if model_name in self.models:
            model = self.models[model_name]
            # llama_cpp models are called directly with the prompt; no external tokenizer
            # is needed. Per-request overrides are merged on top of the defaults.
            params = {**self.gen_params, **overrides}
            output = model(prompt, **params)
            return output["choices"][0]["text"]
        else:
            return "Error: Model not found."
    def start_processing_processes(self):
        # Relies on fork-based multiprocessing (the Linux default) so the worker process
        # inherits the already-loaded models and the shared queues.
        p = Process(target=self.process_requests)
        p.daemon = True
        p.start()

    def process_requests(self):
        while True:
            request_data = self.request_queue.get()
            if request_data is None:
                break
            inputs, model_name, top_p, top_k, temperature, max_tokens = request_data
            try:
                response = self.generate_text(inputs, model_name, top_p=top_p, top_k=top_k, temperature=temperature, max_tokens=max_tokens)
                self.response_queue.put(response)
            except Exception as e:
                print(f"Error during inference: {e}")
                self.response_queue.put("Error generating text.")
model_manager = ModelManager()

class ChatRequest(BaseModel):
    message: str
    model_name: str

@spaces.GPU()
async def generate_streaming_response(inputs, model_name):
    top_p = 0.9
    top_k = 50
    temperature = 0.7
    # Budget the response using the model's own tokenizer rather than an external one.
    model = model_manager.models.get(model_name)
    prompt_tokens = len(model.tokenize(inputs.encode("utf-8"))) if model else 0
    max_tokens = max(model_manager.params["n_ctx"] - prompt_tokens, 1)
    model_manager.request_queue.put((inputs, model_name, top_p, top_k, temperature, max_tokens))
    full_text = model_manager.response_queue.get()

    async def stream_response():
        yield full_text

    return StreamingResponse(stream_response())

async def process_message(message, model_name):
    inputs = message.strip()
    return await generate_streaming_response(inputs, model_name)

@app.post("/generate_multimodel")
async def api_generate_multimodel(request: Request):
    data = await request.json()
    message = data["message"]
    model_name = data.get("model_name", list(MODEL_NAMES.keys())[0])
    if model_name not in MODEL_NAMES:
        return {"error": "Invalid model name"}
    return await process_message(message, model_name)
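
# The callback below is a small adapter added for the Gradio UI (an assumption, not part of
# the original API surface): Gradio expects a plain string back, while process_message
# returns a FastAPI StreamingResponse, so this wrapper pushes the request through the same
# queues and returns the raw text. Sampling values mirror generate_streaming_response.
async def gradio_process_message(message, model_name):
    inputs = message.strip()
    model_manager.request_queue.put((inputs, model_name, 0.9, 50, 0.7, model_manager.gen_params["max_tokens"]))
    return model_manager.response_queue.get()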

iface = gr.Interface(
    fn=gradio_process_message,
    inputs=[gr.Textbox(lines=2, placeholder="Enter your message here..."), gr.Dropdown(list(MODEL_NAMES.keys()), label="Select Model")],
    outputs=gr.Markdown(),
    title="Unified Multi-Model API",
    description="Enter a message to get responses from a unified model.",
)  # Gradio is not suitable for production

if __name__ == "__main__":
    iface.launch()
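
# To expose only the FastAPI endpoints instead of the Gradio UI, the module can be served
# with uvicorn, e.g. `uvicorn app:app --host 0.0.0.0 --port 7860` (port is illustrative).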