# app.py for a Hugging Face Space, uploaded by artificialguybr.
# Last change: commit 95bc271 ("Update app.py", verified); file size 3.69 kB.
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread
# Hugging Face Hub id of the checkpoint this app loads.
MODEL = "THUDM/LongWriter-llama3.1-8b"

# HTML heading rendered at the top of the page.
TITLE = "<h1><center>LongWriter-llama3.1-8b</center></h1>"

# Greeting shown in the empty chat window before the first message.
PLACEHOLDER = "\n".join(
    (
        "",
        "<center>",
        "<p>Hi! I'm LongWriter, capable of generating 10,000+ words. How can I assist you today?</p>",
        "</center>",
        "",
    )
)

# Page styling: a centered, black, rounded "duplicate" button and centered h3 headings.
CSS = "\n".join(
    (
        "",
        ".duplicate-button {",
        "margin: auto !important;",
        "color: white !important;",
        "background: black !important;",
        "border-radius: 100vh !important;",
        "}",
        "h3 {",
        "text-align: center;",
        "}",
        "",
    )
)
# Target device for tokenized inputs: the GPU when CUDA is visible, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# trust_remote_code allows the Hub repository to supply its own tokenizer code.
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)

# Load bfloat16 weights and let device_map="auto" decide weight placement.
# nn.Module.eval() returns the module itself, so the chained call binds the
# same object the original two-step assignment did.
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
def _build_prompt(message: str, history: list, system_prompt: str) -> str:
    """Render the system prompt, earlier turns and the new message as one prompt string.

    The layout is a ``<<SYS>>`` header followed by ``[INST]...[/INST]`` turns,
    with each earlier assistant answer appended right after its ``[/INST]``.
    """
    parts = [f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
    for prompt, answer in history:
        # An f-string would turn a missing (None) answer into the literal text
        # "None"; render it as empty instead.
        parts.append(f"[INST]{prompt}[/INST]{answer or ''}")
    parts.append(f"[INST]{message}[/INST]")
    return "".join(parts)


@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.5,
    max_new_tokens: int = 32768,
    top_p: float = 1.0,
    top_k: int = 50,
):
    """Stream the model's reply to ``message``, yielding the text accumulated so far.

    Args:
        message: The new user message.
        history: Earlier turns as ``(user_message, assistant_answer)`` pairs.
        system_prompt: Text placed in the ``<<SYS>>`` header of the prompt.
        temperature: Sampling temperature. A value of ``0`` or below selects
            greedy decoding, because transformers rejects a zero temperature
            when sampling.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).

    Yields:
        The full reply generated so far, growing with each streamed chunk.
    """
    full_prompt = _build_prompt(message, history, system_prompt)
    inputs = tokenizer(full_prompt, truncation=False, return_tensors="pt").to(device)

    # Iterating the streamer raises queue.Empty if no new text arrives within 60 s.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=inputs.input_ids,
        # Pass the mask explicitly rather than letting generate() infer it.
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=int(max_new_tokens),
        num_beams=1,
        streamer=streamer,
    )
    # UI sliders deliver plain JSON numbers; transformers' warpers require a
    # float temperature and an int top_k.
    temperature = float(temperature)
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        # The Temperature slider allows 0, which would make generate() raise
        # ValueError when do_sample=True; fall back to greedy decoding.
        generate_kwargs["do_sample"] = False

    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

    # The streamer is exhausted once generate() has ended, so this join
    # returns promptly and ensures the worker thread has finished.
    thread.join()
# Chat transcript widget. It is created before any Blocks context is entered
# and handed to gr.ChatInterface below.
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)


def _build_demo():
    """Lay out the page: heading, duplicate button, then the streaming chat UI.

    Returns:
        The assembled ``gr.Blocks`` application.
    """
    with gr.Blocks(css=CSS, theme="soft") as page:
        gr.HTML(TITLE)
        gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

        # Extra controls passed to stream_chat after (message, history). Their
        # order must match its system_prompt, temperature, max_new_tokens,
        # top_p and top_k parameters.
        # NOTE(review): these are created inside the active Blocks context
        # without render=False, so Gradio may place them in the page body
        # rather than in ChatInterface's settings accordion. Confirm the
        # intended layout.
        system_prompt_box = gr.Textbox(
            label="System Prompt",
            value="You are a helpful assistant capable of generating long-form content.",
        )
        temperature_slider = gr.Slider(label="Temperature", minimum=0, maximum=1, step=0.1, value=0.5)
        max_tokens_slider = gr.Slider(
            label="Max new tokens", minimum=1024, maximum=32768, step=1024, value=32768
        )
        top_p_slider = gr.Slider(label="Top p", minimum=0.0, maximum=1.0, step=0.1, value=1.0)
        top_k_slider = gr.Slider(label="Top k", minimum=1, maximum=100, step=1, value=50)

        # One-click starter prompts; they are not pre-computed (cache_examples=False).
        starter_prompts = [
            ["Write a 5000-word comprehensive guide on machine learning for beginners."],
            ["Create a detailed 3000-word business plan for a sustainable energy startup."],
            ["Compose a 2000-word short story set in a futuristic underwater city."],
            ["Develop a 4000-word research proposal on the potential effects of climate change on global food security."],
        ]

        gr.ChatInterface(
            fn=stream_chat,
            chatbot=chatbot,
            fill_height=True,
            additional_inputs=[
                system_prompt_box,
                temperature_slider,
                max_tokens_slider,
                top_p_slider,
                top_k_slider,
            ],
            examples=starter_prompts,
            cache_examples=False,
        )
    return page


demo = _build_demo()
def _main():
    """Launch the Gradio app with launch()'s default settings."""
    demo.launch()


if __name__ == "__main__":
    _main()