jordigonzm commited on
Commit
e89bb60
1 Parent(s): 513fb7b

actualizacion

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -65,18 +65,20 @@ class StopOnTokens(StoppingCriteria):
65
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
66
  print(f'Mensaje: {message}')
67
  print(f'Historia: {history}')
68
-
69
- conversation = []
70
- for prompt, answer in history:
71
- conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
72
 
73
  stop = StopOnTokens()
74
 
 
75
  input_ids = tokenizer.encode(message, return_tensors='pt').to(next(model.parameters()).device)
 
76
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
77
 
78
  generate_kwargs = dict(
79
  input_ids=input_ids,
 
80
  streamer=streamer,
81
  max_new_tokens=max_new_tokens,
82
  do_sample=True,
@@ -84,19 +86,21 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
84
  temperature=temperature,
85
  repetition_penalty=1.1,
86
  stopping_criteria=StoppingCriteriaList([stop]),
87
- attention_mask=input_ids.ne(tokenizer.pad_token_id).long(), # Configurar máscara de atención
88
- pad_token_id=tokenizer.eos_token_id # Establecer pad_token_id al token de fin de secuencia
89
  )
90
 
 
91
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
92
  thread.start()
93
  buffer = ""
94
 
 
95
  for new_token in streamer:
96
  if new_token:
97
  buffer += new_token
98
- # Emitir el texto acumulado en un formato que Gradio espera: ["User message", "Bot response"]
99
- yield history + [[message, buffer]]
 
100
 
101
  # Configuración de la interfaz Gradio
102
  chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
 
65
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
66
  print(f'Mensaje: {message}')
67
  print(f'Historia: {history}')
68
+
69
+ # Limpieza de la historia para evitar pares con 'None'
70
+ cleaned_history = [[prompt, answer if answer is not None else ""] for prompt, answer in history]
 
71
 
72
  stop = StopOnTokens()
73
 
74
+ # Preparar los input_ids y manejar la máscara de atención
75
  input_ids = tokenizer.encode(message, return_tensors='pt').to(next(model.parameters()).device)
76
+ attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
77
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
78
 
79
  generate_kwargs = dict(
80
  input_ids=input_ids,
81
+ attention_mask=attention_mask, # Añadir máscara de atención
82
  streamer=streamer,
83
  max_new_tokens=max_new_tokens,
84
  do_sample=True,
 
86
  temperature=temperature,
87
  repetition_penalty=1.1,
88
  stopping_criteria=StoppingCriteriaList([stop]),
89
+ pad_token_id=tokenizer.eos_token_id # Establecer pad_token_id
 
90
  )
91
 
92
+ # Ejecutar la generación de tokens en un hilo separado
93
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
94
  thread.start()
95
  buffer = ""
96
 
97
+ # Procesar el streaming de tokens y formatear la respuesta para Gradio
98
  for new_token in streamer:
99
  if new_token:
100
  buffer += new_token
101
+ # Formatear la respuesta en un formato compatible con Gradio: [[Mensaje del usuario, Respuesta del bot]]
102
+ yield cleaned_history + [[message, buffer]]
103
+
104
 
105
  # Configuración de la interfaz Gradio
106
  chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)