import subprocess
import os
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
import gradio as gr

# Install flash-attn at runtime (a common Hugging Face Spaces pattern); the
# env var skips the CUDA build so the prebuilt wheel is used. The current
# environment must be forwarded, otherwise pip loses PATH and fails.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True
)

# Load the Hugging Face token from the environment
HF_TOKEN = os.environ.get("HUGGINGFACE_HUB_TOKEN", None)

MODEL_NAME = "EleutherAI/gpt-j-6B"
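# bfloat16 halves memory versus fp32, and device_map="auto" lets accelerate
# place GPT-J's ~6B parameters across the available GPU(s) and CPU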
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    token=HF_TOKEN
).eval()

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME, 
    trust_remote_code=True, 
    use_fast=False, 
    token=HF_TOKEN
)
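
# GPT-J's tokenizer ships without a padding token; falling back to EOS is a
# common convention (single-sequence generation here, so this mainly
# suppresses warnings rather than changing results)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token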

# Titles and styles for the interface
TITLE = f"<h1><center>{MODEL_NAME} Chat</center></h1>"
PLACEHOLDER = f'<h3><center>{MODEL_NAME} is an advanced model capable of generating detailed responses to complex inputs.</center></h3>'
CSS = """
.duplicate-button {
  margin: auto !important;
  color: white !important;
  background: black !important;
  border-radius: 100vh !important;
}
"""

class StopOnTokens(StoppingCriteria):
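    """Stop generation as soon as the model emits its end-of-sequence token."""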
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [tokenizer.eos_token_id]
        return any(input_ids[0][-1] == stop_id for stop_id in stop_ids)

def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
    print(f'Message: {message}')
    print(f'History: {history}')

    # GPT-J is a base causal LM with no chat template, so the conversation
    # history is flattened into a plain "User:/Assistant:" transcript
    conversation = ""
    for prompt, answer in history:
        conversation += f"User: {prompt}\nAssistant: {answer}\n"
    conversation += f"User: {message}\nAssistant:"

    stop = StopOnTokens()

    input_ids = tokenizer.encode(conversation, return_tensors='pt').to(next(model.parameters()).device)
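    # TextIteratorStreamer exposes generated tokens as a Python iterator,
    # pairing with model.generate() run on a background thread below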
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=50,
        temperature=temperature,
        repetition_penalty=1.1,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""

    # Yield the growing buffer so Gradio re-renders the partial reply as it streams
    for new_token in streamer:
        if new_token:
            buffer += new_token
            yield buffer

    thread.join()
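
# Minimal local check (hypothetical usage, bypassing the Gradio UI):
#   for partial in stream_chat("Hello!", [], temperature=0.7, max_new_tokens=32):
#       print(partial)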

# Gradio interface setup
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)

with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0.1,  # 0 would make sampling degenerate (division by zero)
                maximum=1,
                step=0.1,
                value=0.5,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=1024,  # GPT-J's context window is 2,048 tokens; leave room for the prompt
                step=1,
                value=512,
                label="Max New Tokens",
                render=False,
            ),
        ],
        examples=[
            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
            ["Tell me a random fun fact about the Roman Empire."],
            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()