# Llama-2-13B / app.py
import subprocess
import os
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
import gradio as gr
import spaces  # Hugging Face Spaces helper; provides the @spaces.GPU decorator used below
# Decorator that makes the function run on the GPU under ZeroGPU
@spaces.GPU
def greet(n):
    zero = torch.tensor(0, device='cuda:0')  # Allocate the tensor on the GPU
    print(zero.device)  # Confirms it lives on 'cuda:0'
    return f"Hello {zero.item() + n} Tensor"
# Install flash-attn without recompiling its CUDA kernels
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},  # merge into the current env so pip stays on PATH
    shell=True
)
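# Note: transformers only routes attention through flash-attn when the model is
# loaded with attn_implementation="flash_attention_2". The load below uses the
# default attention, so this install is effectively a no-op for this script.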
# Read the access token from the environment
HF_TOKEN = os.environ.get("HUGGINGFACE_HUB_TOKEN", None)
# Load the model and tokenizer
MODEL_NAME = "EleutherAI/gpt-j-6B"
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
use_auth_token=HF_TOKEN
).eval()
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
trust_remote_code=True,
use_fast=False,
use_auth_token=HF_TOKEN
)
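# GPT-J's tokenizer ships without a pad token; reusing the EOS token as padding
# is the usual workaround so that padding-aware code paths below don't fail.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token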
# Run greet once as an initial smoke test
print(greet(5))  # Example: replace 5 with any value to exercise the function
# Interface title and styles
TITLE = f"<h1><center>{MODEL_NAME} Chat</center></h1>"
PLACEHOLDER = f'<h3><center>{MODEL_NAME} is an advanced model capable of generating detailed responses to complex inputs.</center></h3>'
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
"""
# Custom stopping criterion: stop as soon as the last generated token is EOS
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [tokenizer.eos_token_id]
        return any(input_ids[0][-1] == stop_id for stop_id in stop_ids)
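# The original code encoded only the latest message, so earlier turns never
# reached the model. The helper below is a minimal sketch of one way to fold
# the history into the prompt; the plain "User:/Assistant:" layout is an
# assumption, since GPT-J has no official chat template.
def build_prompt(message: str, history: list) -> str:
    turns = [f"User: {user_msg}\nAssistant: {bot_msg}" for user_msg, bot_msg in history]
    turns.append(f"User: {message}\nAssistant:")
    return "\n".join(turns)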
# Generate chatbot responses with token streaming
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
    print(f'Message: {message}')
    print(f'History: {history}')
    # Clean the history so no pair carries a 'None' answer
    cleaned_history = [[prompt, answer if answer is not None else ""] for prompt, answer in history]
    stop = StopOnTokens()
    # Fold the conversation into a single prompt and prepare input_ids
    prompt = build_prompt(message, cleaned_history)
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(next(model.parameters()).device)
    # The prompt is one unpadded sequence, so attend to every position
    # (the original input_ids.ne(tokenizer.pad_token_id) crashes when no pad token is set)
    attention_mask = torch.ones_like(input_ids)
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
input_ids=input_ids,
        attention_mask=attention_mask,  # Pass the attention mask explicitly
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
top_k=50,
temperature=temperature,
repetition_penalty=1.1,
stopping_criteria=StoppingCriteriaList([stop]),
        pad_token_id=tokenizer.eos_token_id  # Use EOS as the padding token
)
    # Run token generation in a separate thread
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
buffer = ""
    # Consume the token stream and accumulate the reply
for new_token in streamer:
if new_token:
buffer += new_token
            # Yield in the format Gradio expects: [[user message, bot reply]]
yield cleaned_history + [[message, buffer]]
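# Quick sanity check outside Gradio (hypothetical usage): exhaust the generator
# and print the final transcript entry.
#     for partial in stream_chat("Hello!", [], temperature=0.5, max_new_tokens=32):
#         last = partial
#     print(last[-1][1])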
# Gradio interface configuration
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
with gr.Blocks(css=CSS) as demo:
gr.HTML(TITLE)
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
gr.ChatInterface(
fn=stream_chat,
chatbot=chatbot,
fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
additional_inputs=[
gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.5,
label="Temperature",
render=False,
),
            gr.Slider(
                minimum=128,
                maximum=2048,
                step=1,
                value=1024,
                label="Max New Tokens",  # GPT-J-6B's context window is 2048 tokens
                render=False,
            ),
],
examples=[
["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
["Tell me a random fun fact about the Roman Empire."],
["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
],
cache_examples=False,
)
if __name__ == "__main__":
demo.launch()