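"""Gradio Space: streaming chat demo for nawhgnuj/DonaldTrump-Llama-3.1-8B-Chat, loaded with 4-bit quantization and served through gr.ChatInterface."""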
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import gradio as gr
from threading import Thread
MODEL_LIST = ["nawhgnuj/DonaldTrump-Llama-3.1-8B-Chat"]
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL = os.environ.get("MODEL_ID", MODEL_LIST[0])  # fall back to the default model when MODEL_ID is unset
TITLE = "<h1 style='color: #E53935; text-align: center;'>Donald Trump Chatbot</h1>"
PLACEHOLDER = """
<div style='text-align: center;'>
<img src='https://upload.wikimedia.org/wikipedia/commons/5/56/Donald_Trump_official_portrait.jpg' style='width: 200px; border-radius: 50%;'>
<p style='color: #E53935; font-weight: bold;'>Hi! I'm Donald Trump!</p>
<p>Let's Make America Great Again! Ask me anything.</p>
</div>
"""
CSS = """
.chatbot {
background-color: #FFCDD2;
}
.duplicate-button {
margin: auto !important;
color: white !important;
background: #B71C1C !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
color: #E53935;
}
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
    token=HF_TOKEN,
)
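# spaces.GPU() requests a GPU for the duration of each call when the Space runs on ZeroGPU hardware.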
@spaces.GPU()
def stream_chat(message: str, history: list):
    """Stream a reply token by token for gr.ChatInterface."""
    system_prompt = "You are a Donald Trump chatbot. You only answer like Trump in style and tone."
    # Fixed sampling settings for every turn.
    temperature = 0.8
    max_new_tokens = 1024
    top_p = 1.0
    top_k = 20
    penalty = 1.2

    # Rebuild the full conversation (system prompt + prior turns + new message) for the chat template.
    conversation = [
        {"role": "system", "content": system_prompt}
    ]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)

    # Run generation in a background thread and stream partial text back to the UI as it is produced.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        eos_token_id=[128001, 128008, 128009],  # Llama 3.1 end-of-text / end-of-message / end-of-turn tokens
        streamer=streamer,
    )
    # model.generate() already runs without gradient tracking, so no torch.no_grad() wrapper is needed here.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
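# Build the Gradio UI: a themed chat window wired to the streaming generator above.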
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER, elem_classes="chatbot")
with gr.Blocks(css=CSS, theme=gr.themes.Default()) as demo:
    gr.HTML(TITLE)
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        examples=[
            ["What do you think about the economy?"],
            ["How would you handle foreign policy?"],
            ["What's your stance on immigration?"],
        ],
        cache_examples=False,
    )
if __name__ == "__main__":
    demo.launch()