jordigonzm committed
Commit 524455a
1 Parent(s): 3fbcfce

Update app.py

Files changed (1)
  1. app.py +124 -47
app.py CHANGED
@@ -1,55 +1,132 @@
 
  import os
- from transformers import pipeline
  import torch
- import time
-
- # Load the token from the environment variable
- token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
- model_name = "google/gemma-2-27b-it"
-
- try:
-     generator = pipeline(
-         "text-generation",
-         model=model_name,
-         device=0 if torch.cuda.is_available() else -1,
-         use_auth_token=token
  )
-     print("Model loaded successfully.")
- except Exception as e:
-     print(f"Error loading the model: {e}")
-     exit(1)
-
- # Function to process the input and generate the response
- def generate_response(text):
-     output = generator(text, max_length=512, num_return_sequences=1)
-     response = {
-         "choices": [
-             {
-                 "text": output[0]['generated_text'],
-                 "index": 0,
-                 "logprobs": None,
-                 "finish_reason": "stop"
  }
-         ],
-         "id": "req-12345",  # Replace with a unique ID
-         "model": model_name,
-         "created": int(time.time())
-     }
-     return response

  # Gradio interface setup
- import gradio as gr

- iface = gr.Interface(
-     fn=generate_response,
-     inputs="text",
-     outputs="json",
-     title="OpenAI-compatible API",
-     description="Enter text to get a response from the Meta-Llama-3-8B-Instruct model."
- )

- # Launch the interface
- try:
-     iface.launch()
- except Exception as e:
-     print(f"Error launching the interface: {e}")
 
+ import subprocess
  import os
  import torch
+ import time  # needed for the "created" timestamp in stream_chat
+ from threading import Thread
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
+ import gradio as gr
+
+ # Install required dependencies
+ subprocess.run(
+     'pip install flash-attn --no-build-isolation',
+     env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
+     shell=True
+ )
+
+ # Load the token from the environment variables
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+ MODEL_NAME = "google/gemma-2-27b-it"
+
+ # Titles and styles for the interface
+ TITLE = "<h1><center>Gemma Model Chat</center></h1>"
+ PLACEHOLDER = f'<h3><center>{MODEL_NAME} is an advanced model capable of generating detailed responses to complex inputs.</center></h3>'
+ CSS = """
+ .duplicate-button {
+     margin: auto !important;
+     color: white !important;
+     background: black !important;
+     border-radius: 100vh !important;
+ }
+ """
+
+ # Load the model and the tokenizer
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     trust_remote_code=True,
+     token=HF_TOKEN
+ ).eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True, use_fast=False)
+
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         stop_ids = [tokenizer.eos_token_id]
+         return any(input_ids[0][-1] == stop_id for stop_id in stop_ids)
+
+ def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
+     print(f'Message: {message}')
+     print(f'History: {history}')
+
+     conversation = []
+     for prompt, answer in history:
+         conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
+     # Include the current message so the model sees the whole conversation
+     conversation.append({"role": "user", "content": message})
+
+     stop = StopOnTokens()
+
+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors='pt').to(next(model.parameters()).device)
+     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_k=50,
+         temperature=temperature,
+         repetition_penalty=1.1,
+         stopping_criteria=StoppingCriteriaList([stop]),
  )
+
+     thread = Thread(target=model.generate, kwargs=generate_kwargs)
+     thread.start()
+     buffer = ""
+
+     for new_token in streamer:
+         if new_token:
+             buffer += new_token
+         # Emit the result in an OpenAI-compatible format
+         yield {
+             "choices": [
+                 {
+                     "text": buffer,
+                     "index": 0,
+                     "logprobs": None,
+                     "finish_reason": "stop"
+                 }
+             ],
+             "id": "req-12345",  # Replace with a unique ID if necessary
+             "model": MODEL_NAME,
+             "created": int(time.time())
  }

  # Gradio interface setup
+ chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)

+ with gr.Blocks(css=CSS) as demo:
+     gr.HTML(TITLE)
+     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
+     gr.ChatInterface(
+         fn=stream_chat,
+         chatbot=chatbot,
+         fill_height=True,
+         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+         additional_inputs=[
+             gr.Slider(
+                 minimum=0,
+                 maximum=1,
+                 step=0.1,
+                 value=0.5,
+                 label="Temperature",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=1024,
+                 maximum=32768,
+                 step=1,
+                 value=4096,
+                 label="Max New Tokens",
+                 render=False,
+             ),
+         ],
+         examples=[
+             ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
+             ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
+             ["Tell me a random fun fact about the Roman Empire."],
+             ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
+         ],
+         cache_examples=False,
+     )

+ if __name__ == "__main__":
+     demo.launch()
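
Note: the streamed chunks above reuse the hardcoded placeholder "id": "req-12345". If unique request ids are wanted, a uuid4-based helper is one option. Below is a minimal sketch assuming the same chunk shape that stream_chat emits; the make_chunk helper name is hypothetical and not part of this commit.

import time
import uuid

MODEL_NAME = "google/gemma-2-27b-it"

def make_chunk(text, finish_reason=None):
    # Hypothetical helper: one OpenAI-style completion chunk with a unique id.
    return {
        "choices": [
            {
                "text": text,
                "index": 0,
                "logprobs": None,
                "finish_reason": finish_reason,
            }
        ],
        "id": f"req-{uuid.uuid4()}",  # collision-resistant per-request id
        "model": MODEL_NAME,
        "created": int(time.time()),
    }

if __name__ == "__main__":
    print(make_chunk("Hello", finish_reason="stop"))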
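
Once the Space is running, the chat endpoint can also be exercised programmatically. A hedged sketch using gradio_client follows; the Space id is a placeholder and "/chat" is the endpoint name gr.ChatInterface typically registers, so both should be confirmed with client.view_api().

from gradio_client import Client

client = Client("jordigonzm/your-space-name")  # placeholder Space id
result = client.predict(
    "Tell me a random fun fact about the Roman Empire.",  # message
    0.5,   # Temperature (additional input)
    4096,  # Max New Tokens (additional input)
    api_name="/chat",
)
print(result)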