# Llama-2-13B / app.py
import subprocess
import os
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
import gradio as gr
import time
# Install required dependencies (flash-attn, skipping its CUDA build step)
subprocess.run(
'pip install flash-attn --no-build-isolation',
env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
shell=True
)
# Load the Hugging Face token from environment variables
HF_TOKEN = os.environ.get("HUGGINGFACE_HUB_TOKEN", None)
MODEL_NAME = "EleutherAI/gpt-j-6B"
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
token=HF_TOKEN
).eval()
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
trust_remote_code=True,
use_fast=False,
token=HF_TOKEN
)
# Titles and styles for the interface
TITLE = f"<h1><center>{MODEL_NAME} Chat</center></h1>"
PLACEHOLDER = f'<h3><center>{MODEL_NAME} is an advanced model capable of generating detailed responses to complex inputs.</center></h3>'
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
"""
class StopOnTokens(StoppingCriteria):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
stop_ids = [tokenizer.eos_token_id]
return any(input_ids[0][-1] == stop_id for stop_id in stop_ids)
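# Chat handler for gr.ChatInterface: runs generation in a background thread and
# streams partial output back to the UI as OpenAI-style completion chunks.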
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
    print(f'Message: {message}')
    print(f'History: {history}')
conversation = []
for prompt, answer in history:
conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
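    # NOTE: the reconstructed conversation history is not passed to the model;
    # only the latest user message is encoded and used for generation below.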
stop = StopOnTokens()
input_ids = tokenizer.encode(message, return_tensors='pt').to(next(model.parameters()).device)
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
input_ids=input_ids,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
top_k=50,
temperature=temperature,
repetition_penalty=1.1,
stopping_criteria=StoppingCriteriaList([stop]),
)
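    # Run generation in a background thread so tokens can be read from the
    # streamer and yielded to the client as they are produced.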
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
buffer = ""
for new_token in streamer:
if new_token:
buffer += new_token
            # Emit the partial result as an OpenAI-compatible completion chunk
yield {
"choices": [
{
"text": buffer,
"index": 0,
"logprobs": None,
"finish_reason": "stop"
}
],
"id": "req-12345", # Reemplazar con un ID único si es necesario
"model": MODEL_NAME,
"created": int(time.time())
}
# Gradio interface configuration
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
with gr.Blocks(css=CSS) as demo:
gr.HTML(TITLE)
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
gr.ChatInterface(
fn=stream_chat,
chatbot=chatbot,
fill_height=True,
additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
additional_inputs=[
gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.5,
label="Temperature",
render=False,
),
gr.Slider(
minimum=1024,
maximum=32768,
step=1,
value=4096,
label="Max New Tokens",
render=False,
),
],
examples=[
["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
["Tell me a random fun fact about the Roman Empire."],
["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
],
cache_examples=False,
)
if __name__ == "__main__":
demo.launch()