import os
import spaces
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
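# The model id and auth token come from the environment so the Space can be
# repointed at another checkpoint without code changes.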
MODEL_LIST = ["nawhgnuj/DonaldTrump-Llama-3.1-8B-Chat"]
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Fall back to the first known checkpoint when MODEL_ID is unset, so the app
# starts instead of handing from_pretrained a None model id.
MODEL = os.environ.get("MODEL_ID", MODEL_LIST[0])
TITLE = "<h1 style='color: #E53935; text-align: center;'>Donald Trump Chatbot</h1>"

PLACEHOLDER = """
<div style='text-align: center;'>
<img src='https://upload.wikimedia.org/wikipedia/commons/5/56/Donald_Trump_official_portrait.jpg' style='width: 200px; border-radius: 50%;'>
<p style='color: #E53935; font-weight: bold;'>Hi! I'm Donald Trump!</p>
<p>Let's Make America Great Again! Ask me anything.</p>
</div>
"""

CSS = """
.chatbot {
    background-color: #FFCDD2;
}
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: #B71C1C !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
    color: #E53935;
}
"""
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
)
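# `spaces` is imported above but never used, which suggests the handler was
# meant to run on a ZeroGPU Space; the decorator below requests a GPU per
# call (an assumption; drop it on a Space with a dedicated GPU).
@spaces.GPU()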
def stream_chat(message: str, history: list):
    system_prompt = "You are a Donald Trump chatbot. You only answer like Trump in style and tone."
    temperature = 0.8
    max_new_tokens = 1024
    top_p = 1.0
    top_k = 20
    repetition_penalty = 1.2
    conversation = [
        {"role": "system", "content": system_prompt}
    ]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})
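    # apply_chat_template wraps each turn in the model's role markers and,
    # with add_generation_prompt=True, appends the assistant header so the
    # model continues as the assistant.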
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        # Llama 3.1 stop tokens: <|end_of_text|>, <|eom_id|>, <|eot_id|>.
        eos_token_id=[128001, 128008, 128009],
        streamer=streamer,
    )
    # Run generation on a background thread so this generator can yield
    # partial output as the streamer produces it; generate() already runs
    # under its own no-grad context.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()
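# The chat window renders PLACEHOLDER until the first message arrives;
# its styling comes from the CSS block above.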
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER, elem_classes="chatbot")

with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
    gr.HTML(TITLE)
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        examples=[
            ["What do you think about the economy?"],
            ["How would you handle foreign policy?"],
            ["What's your stance on immigration?"],
        ],
        cache_examples=False,
    )
if __name__ == "__main__":
    demo.launch()