import os
import torch
import spaces
import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig

MODEL_LIST = ["nawhgnuj/DonaldTrump-Llama-3.1-8B-Chat"]
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL = os.environ.get("MODEL_ID", MODEL_LIST[0])  # fall back to the default checkpoint when MODEL_ID is unset
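
# Static UI assets: page title, placeholder card shown in the empty chat, and theme CSS.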
TITLE = "<h1 style='color: #E53935; text-align: center;'>Donald Trump Chatbot</h1>"
PLACEHOLDER = """
<div style='text-align: center;'>
<img src='https://upload.wikimedia.org/wikipedia/commons/5/56/Donald_Trump_official_portrait.jpg' style='width: 200px; border-radius: 50%;'>
<p style='color: #E53935; font-weight: bold;'>Hi! I'm Donald Trump!</p>
<p>Let's Make America Great Again! Ask me anything.</p>
</div>
"""
CSS = """
.chatbot {
background-color: #FFCDD2;
}
.duplicate-button {
margin: auto !important;
color: white !important;
background: #B71C1C !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
color: #E53935;
}
"""
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
    token=HF_TOKEN,
)
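
# On Hugging Face Spaces, @spaces.GPU allocates a ZeroGPU device for the duration of each call.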
@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
):
    system_prompt = "You are a Donald Trump chatbot. You only answer like Trump in style and tone."
    # Fixed sampling hyperparameters; none are exposed in the UI.
    temperature = 0.8
    max_new_tokens = 1024
    top_p = 1.0
    top_k = 20
    repetition_penalty = 1.2
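    # Rebuild the chat-template conversation from Gradio's (user, assistant) tuple history.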
    conversation = [
        {"role": "system", "content": system_prompt}
    ]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
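    # TextIteratorStreamer yields decoded tokens as generate() produces them; skip_prompt drops the echoed input.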
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        eos_token_id=[128001, 128008, 128009],  # Llama 3.1 stop tokens: <|end_of_text|>, <|eom_id|>, <|eot_id|>
        streamer=streamer,
    )
    # Run generation on a worker thread; generate() already disables gradients internally,
    # so no torch.no_grad() wrapper is needed here.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
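
# Gradio UI: a themed Chatbot wired into a ChatInterface with example prompts.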
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER, elem_classes="chatbot")

with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
    gr.HTML(TITLE)
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        examples=[
            ["What do you think about the economy?"],
            ["How would you handle foreign policy?"],
            ["What's your stance on immigration?"],
        ],
        cache_examples=False,
    )
if __name__ == "__main__":
    demo.launch()