File size: 3,323 Bytes
d18942e
 
 
 
 
 
 
 
61ced1b
d18942e
 
 
61ced1b
d18942e
 
61ced1b
 
 
 
 
d18942e
 
 
61ced1b
 
 
d18942e
 
 
61ced1b
d18942e
 
 
 
61ced1b
d18942e
 
 
61ced1b
d18942e
 
 
 
 
238ce74
d18942e
 
 
 
 
 
 
 
 
 
 
 
 
61ced1b
 
 
 
 
 
d18942e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61ced1b
 
 
 
 
d18942e
 
 
 
 
 
 
 
 
 
 
 
 
61ced1b
d18942e
61ced1b
d18942e
 
 
 
 
 
61ced1b
 
 
d18942e
 
 
 
 
61ced1b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import gradio as gr
from threading import Thread

# Model ids this app can serve; the first entry is used as the default.
MODEL_LIST = ["nawhgnuj/DonaldTrump-Llama-3.1-8B-Chat"]
# Optional Hugging Face access token taken from the environment (None if unset).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Model id to load. MODEL_ID overrides the default; when it is unset or empty
# we fall back to the first known model instead of handing None to the loader.
MODEL = os.environ.get("MODEL_ID") or MODEL_LIST[0]

# HTML heading rendered at the top of the page via gr.HTML(TITLE).
TITLE = "<h1 style='color: #E53935; text-align: center;'>Donald Trump Chatbot</h1>"

# HTML passed as the `placeholder` of the gr.Chatbot component.
PLACEHOLDER = """
<div style='text-align: center;'>
<img src='https://upload.wikimedia.org/wikipedia/commons/5/56/Donald_Trump_official_portrait.jpg' style='width: 200px; border-radius: 50%;'>
<p style='color: #E53935; font-weight: bold;'>Hi! I'm Donald Trump!</p>
<p>Let's Make America Great Again! Ask me anything.</p>
</div>
"""

# Stylesheet passed to gr.Blocks(css=...). The `.chatbot` rule matches the
# elem_classes given to the gr.Chatbot component; `.duplicate-button` and `h3`
# are not referenced by any component created in this file.
CSS = """
.chatbot {
    background-color: #FFCDD2;
}
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: #B71C1C !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
    color: #E53935;
}
"""

# Name of the preferred device class. Actual weight placement is delegated to
# device_map="auto" in the model load below.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# 4-bit NF4 quantization with double quantization, computing in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Tokenizer and model are loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL)

_model_load_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
    "quantization_config": quantization_config,
}
model = AutoModelForCausalLM.from_pretrained(MODEL, **_model_load_kwargs)

@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
):
    """Stream a Trump-styled reply to ``message``.

    The conversation (system prompt, prior turns, new message) is rendered
    with the tokenizer's chat template, generation runs in a background
    thread, and decoded text is read back through a ``TextIteratorStreamer``.

    Args:
        message: The user's latest chat message.
        history: Earlier turns; each item is unpacked as a
            ``(user_text, assistant_text)`` pair.

    Yields:
        The reply accumulated so far, once for every chunk the streamer emits.
    """
    system_prompt = "You are a Donald Trump chatbot. You only answer like Trump in style and tone."
    temperature = 0.8
    max_new_tokens = 1024
    top_p = 1.0
    top_k = 20
    penalty = 1.2

    conversation = [
        {"role": "system", "content": system_prompt}
    ]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])

    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # skip_prompt keeps the echoed input out of the stream; the 60 s timeout
    # bounds how long the consumer loop below waits for the next chunk.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        # The penalty was configured above but previously never reached
        # generate(), so it had no effect on sampling.
        repetition_penalty=penalty,
        # NOTE(review): presumably the Llama 3.1 end-of-text / end-of-message /
        # end-of-turn token ids -- confirm against the tokenizer.
        eos_token_id=[128001, 128008, 128009],
        streamer=streamer,
    )

    # No torch.no_grad() wrapper here: grad mode is per-thread in PyTorch, so a
    # context entered in this thread around Thread() creation never applied to
    # the generate() call running in the worker thread.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

    # The streamer is exhausted once generation has finished, so this returns
    # promptly and avoids leaving an un-joined worker thread behind.
    thread.join()

# Chat display component; created up front and handed to the ChatInterface below.
chatbot = gr.Chatbot(
    elem_classes="chatbot",
    placeholder=PLACEHOLDER,
    height=600,
)

# Page layout: the HTML title followed by the streaming chat interface.
with gr.Blocks(theme=gr.themes.Default(), css=CSS) as demo:
    gr.HTML(TITLE)
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        examples=[
            ["What do you think about the economy?"],
            ["How would you handle foreign policy?"],
            ["What's your stance on immigration?"],
        ],
        cache_examples=False,
        fill_height=True,
    )

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()