artificialguybr committed
Commit
c773bb9
1 Parent(s): e316496

Create app.py

Files changed (1)
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
+import os
+import time
+import spaces
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+import gradio as gr
+from threading import Thread
+
+MODEL = "THUDM/LongWriter-llama3.1-8b"
+
+TITLE = "<h1><center>LongWriter-llama3.1-8b</center></h1>"
+
+PLACEHOLDER = """
+<center>
+<p>Hi! I'm LongWriter, capable of generating 10,000+ words. How can I assist you today?</p>
+</center>
+"""
+
+CSS = """
+.duplicate-button {
+    margin: auto !important;
+    color: white !important;
+    background: black !important;
+    border-radius: 100vh !important;
+}
+h3 {
+    text-align: center;
+}
+"""
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
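+# Load the tokenizer and model once at startup; bfloat16 with
+# device_map="auto" lets Accelerate place the 8B weights on the GPU.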
+tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
+model = model.eval()
+
+@spaces.GPU()
+def stream_chat(
+    message: str,
+    history: list,
+    system_prompt: str,
+    temperature: float = 0.5,
+    max_new_tokens: int = 32768,
+    top_p: float = 1.0,
+    top_k: int = 50,
+):
+    print(f'message: {message}')
+    print(f'history: {history}')
+
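+    # Rebuild the conversation in the model's [INST]...[/INST] chat format,
+    # with the system prompt in <<SYS>> tags, replaying prior turns before
+    # appending the new user message.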
+    full_prompt = f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
+    for prompt, answer in history:
+        full_prompt += f"[INST]{prompt}[/INST]{answer}"
+    full_prompt += f"[INST]{message}[/INST]"
+
+    inputs = tokenizer(full_prompt, truncation=False, return_tensors="pt").to(device)
+    context_length = inputs.input_ids.shape[-1]
+
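+    # Stream decoded text back as it is generated, skipping the prompt and
+    # special tokens; times out after 60s without a new token.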
+    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
+    generate_kwargs = dict(
+        inputs=inputs.input_ids,
+        attention_mask=inputs.attention_mask,
+        max_new_tokens=max_new_tokens,
+        do_sample=temperature > 0,  # fall back to greedy decoding at temperature 0 to avoid a sampling error
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        streamer=streamer,
+    )
+
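+    # Run generate() on a background thread so this generator can drain the
+    # streamer and yield partial text to the UI as tokens arrive.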
+    thread = Thread(target=model.generate, kwargs=generate_kwargs)
+    thread.start()
+
+    buffer = ""
+    for new_text in streamer:
+        buffer += new_text
+        yield buffer
+
+chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
+
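+# Gradio UI: wire stream_chat into a ChatInterface, with the sampling
+# parameters exposed in a collapsible accordion.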
+with gr.Blocks(css=CSS, theme="soft") as demo:
+    gr.HTML(TITLE)
+    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
+    gr.ChatInterface(
+        fn=stream_chat,
+        chatbot=chatbot,
+        fill_height=True,
+        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
+        additional_inputs=[
+            gr.Textbox(
+                value="You are a helpful assistant capable of generating long-form content.",
+                label="System Prompt",
+            ),
+            gr.Slider(
+                minimum=0,
+                maximum=1,
+                step=0.1,
+                value=0.5,
+                label="Temperature",
+            ),
+            gr.Slider(
+                minimum=1024,
+                maximum=32768,
+                step=1024,
+                value=32768,
+                label="Max new tokens",
+            ),
+            gr.Slider(
+                minimum=0.0,
+                maximum=1.0,
+                step=0.1,
+                value=1.0,
+                label="Top p",
+            ),
+            gr.Slider(
+                minimum=1,
+                maximum=100,
+                step=1,
+                value=50,
+                label="Top k",
+            ),
+        ],
+        examples=[
+            ["Write a 5000-word comprehensive guide on machine learning for beginners."],
+            ["Create a detailed 3000-word business plan for a sustainable energy startup."],
+            ["Compose a 2000-word short story set in a futuristic underwater city."],
+            ["Develop a 4000-word research proposal on the potential effects of climate change on global food security."],
+        ],
+        cache_examples=False,
+    )
+
+if __name__ == "__main__":
+    demo.launch()