jordigonzm committed on
Commit
26fc4d9
1 Parent(s): fc439e1
Files changed (1)
  1. app.py +125 -127
app.py CHANGED
@@ -1,151 +1,149 @@
-import subprocess
 import os
-import torch
 from threading import Thread
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
 import gradio as gr
-import time
 import spaces

-# Decorator to ensure the function runs on the GPU
-@spaces.GPU
-def greet(n):
-    zero = torch.tensor(0, device='cuda:0')  # Initialize the tensor on the GPU
-    print(zero.device)  # Confirm that it is on 'cuda:0'
-    return f"Hello {zero.item() + n} Tensor"
-
-# Install required dependencies without recompiling CUDA
-subprocess.run(
-    'pip install flash-attn --no-build-isolation',
-    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
-    shell=True
-)

-# Load the token from the environment variables
-HF_TOKEN = os.environ.get("HUGGINGFACE_HUB_TOKEN", None)
-
-# Load the model and the tokenizer
-MODEL_NAME = "EleutherAI/gpt-j-6B"
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME,
-    torch_dtype=torch.bfloat16,
-    device_map="auto",
-    trust_remote_code=True,
-    use_auth_token=HF_TOKEN
-).eval()
-
-tokenizer = AutoTokenizer.from_pretrained(
-    MODEL_NAME,
-    trust_remote_code=True,
-    use_fast=False,
-    use_auth_token=HF_TOKEN
-)

-# Run the greet function as an initial test
-print(greet(5))  # Example: change 5 to the desired value to test the function
-
-# Titles and styles for the interface
-TITLE = "<h1><center>Gemma Model Chat</center></h1>"
-PLACEHOLDER = f'<h3><center>{MODEL_NAME} is an advanced model capable of generating detailed responses from complex inputs.</center></h3>'
-CSS = """
-.duplicate-button {
-    margin: auto !important;
-    color: white !important;
-    background: black !important;
-    border-radius: 100vh !important;
-}
 """

-# Custom stopping-criteria definition
-class StopOnTokens(StoppingCriteria):
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        stop_ids = [tokenizer.eos_token_id]
-        return any(input_ids[0][-1] == stop_id for stop_id in stop_ids)

-# Function that generates chatbot responses with streaming
-def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
-    print(f'Message: {message}')
-    print(f'History: {history}')

-    # Clean up the history to avoid pairs containing 'None'
-    cleaned_history = [[prompt, answer if answer is not None else ""] for prompt, answer in history]

-    stop = StopOnTokens()

-    # Check pad_token_id and assign it if it is None
-    if tokenizer.pad_token_id is None:
-        tokenizer.pad_token_id = tokenizer.eos_token_id

-    # Prepare the input_ids and handle the attention mask
-    input_ids = tokenizer.encode(message, return_tensors='pt').to(next(model.parameters()).device)
-    attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

     generate_kwargs = dict(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
-        top_k=50,
         temperature=temperature,
-        repetition_penalty=1.1,
-        stopping_criteria=StoppingCriteriaList([stop]),
-        pad_token_id=tokenizer.pad_token_id
     )

-    # Run token generation in a separate thread
-    thread = Thread(target=model.generate, kwargs=generate_kwargs)
-    thread.start()
-    buffer = ""
-
-    # Process the token stream and format the response for Gradio
-    for new_token in streamer:
-        if new_token:
-            buffer += new_token
-            # Make sure we are only working with plain text
-            buffer = buffer.strip()  # Remove unnecessary whitespace
-        # Emit the accumulated text in a Gradio-compatible format: [[user message, bot response]]
-        yield cleaned_history + [[message, buffer]]
-
-
-# Gradio interface configuration
-chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
-
-with gr.Blocks(css=CSS) as demo:
-    gr.HTML(TITLE)
-    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
-    gr.ChatInterface(
-        fn=stream_chat,
-        chatbot=chatbot,
-        fill_height=True,
-        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
-        additional_inputs=[
-            gr.Slider(
-                minimum=0,
-                maximum=1,
-                step=0.1,
-                value=0.5,
-                label="Temperature",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=1024,
-                maximum=32768,
-                step=1,
-                value=4096,
-                label="Max New Tokens",
-                render=False,
-            ),
-        ],
-        examples=[
-            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
-            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
-            ["Tell me a random fun fact about the Roman Empire."],
-            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
-        ],
-        cache_examples=False,
-    )

 if __name__ == "__main__":
-    demo.launch()
 
 
 import os
 from threading import Thread
+from typing import Iterator
+
 import gradio as gr
 import spaces
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

+DESCRIPTION = """\
+# Llama-2 13B Chat
+
+This Space demonstrates model [Llama-2-13b-chat](https://huggingface.co/meta-llama/Llama-2-13b-chat) by Meta, a Llama 2 model with 13B parameters fine-tuned for chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
+
+🔎 For more details about the Llama 2 family of models and how to use them with `transformers`, take a look [at our blog post](https://huggingface.co/blog/llama2).
+
+🔨 Looking for an even more powerful model? Check out the large [**70B** model demo](https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI).
+🐇 For a smaller model that you can run on many GPUs, check our [7B model demo](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat).
+
 """

+LICENSE = """
+<p/>
+
+---
+As a derivative work of [Llama-2-13b-chat](https://huggingface.co/meta-llama/Llama-2-13b-chat) by Meta,
+this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/USE_POLICY.md).
+"""

+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+
+if torch.cuda.is_available():
+    model_id = "meta-llama/Llama-2-13b-chat-hf"
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.use_default_system_prompt = False
+
+
+@spaces.GPU
+def generate(
+    message: str,
+    chat_history: list[tuple[str, str]],
+    system_prompt: str,
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+    conversation = []
+    if system_prompt:
+        conversation.append({"role": "system", "content": system_prompt})
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
+        {"input_ids": input_ids},
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
         temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
     )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
+
+
+chat_interface = gr.ChatInterface(
+    fn=generate,
+    additional_inputs=[
+        gr.Textbox(label="System prompt", lines=6),
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.2,
+        ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["Hello there! How are you doing?"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+        ["Explain the plot of Cinderella in a sentence."],
+        ["How many hours does it take a man to eat a Helicopter?"],
+        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+    ],
+    cache_examples=False,
+)

+with gr.Blocks(css="style.css", fill_height=True) as demo:
+    gr.Markdown(DESCRIPTION)
+    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+    chat_interface.render()
+    gr.Markdown(LICENSE)

 if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
+
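
A note on the new generate() function: it collects the system prompt, the chat history, and the new message into a list of role/content dictionaries and lets tokenizer.apply_chat_template render that conversation into Llama-2 prompt tokens before streaming. The sketch below shows just that step; the openly available TinyLlama/TinyLlama-1.1B-Chat-v1.0 tokenizer is an assumption used as a stand-in, since the Space itself loads the gated meta-llama/Llama-2-13b-chat-hf checkpoint.

from transformers import AutoTokenizer

# Stand-in tokenizer that ships a chat template (illustration only).
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello there! How are you doing?"},
]

# Render the conversation into model-ready token ids, as generate() does before streaming.
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")

print(input_ids.shape)                 # torch.Size([1, <prompt length>])
print(tokenizer.decode(input_ids[0]))  # the fully formatted prompt string

When the rendered prompt grows past MAX_INPUT_TOKEN_LENGTH, the committed code keeps only the most recent tokens (input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]) and surfaces a gr.Warning to the user.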