File size: 4,044 Bytes
c773bb9
 
 
 
 
 
 
 
78c29ca
c773bb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48b916e
c773bb9
 
95bc271
 
 
c773bb9
 
48b916e
 
 
c773bb9
 
 
 
 
 
48b916e
 
c773bb9
48b916e
c773bb9
 
 
 
 
 
 
 
 
 
 
 
 
95bc271
c773bb9
 
 
 
 
 
 
 
 
 
 
 
 
 
48b916e
c773bb9
 
 
 
48b916e
c773bb9
 
 
 
 
 
 
48b916e
c773bb9
 
 
 
 
 
 
48b916e
c773bb9
 
 
 
 
 
 
48b916e
c773bb9
 
 
 
 
 
 
48b916e
c773bb9
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread

# Hugging Face Hub id of the checkpoint served by this app.
MODEL = "jwang2373/ChronoGemma"

# NOTE(review): the heading (and the placeholder text below) advertise
# LongWriter-llama3.1-8b while MODEL points at a different checkpoint —
# confirm which label is intended.
TITLE = "<h1><center>LongWriter-llama3.1-8b</center></h1>"

# HTML shown in the empty chat window before the first message.
PLACEHOLDER = """
<center>
<p>Hi! I'm LongWriter, capable of generating 10,000+ words. How can I assist you today?</p>
</center>
"""

# Page-level styling: a centred, pill-shaped duplicate button and centred h3s.
CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
"""

# Target for tokenized inputs: the GPU whenever torch reports one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load once at import time. device_map="auto" delegates weight placement to
# transformers/accelerate; eval() turns off training-only layers such as dropout.
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
).eval()

def _build_prompt(message: str, history: list, system_prompt: str) -> str:
    """Assemble the system prompt, earlier turns and the new message into one string.

    Layout: ``<<SYS>>\\n{system}\\n<</SYS>>\\n\\n``, then ``[INST]user[/INST]reply``
    for each earlier turn, then ``[INST]message[/INST]`` awaiting the reply.

    NOTE(review): this is a Llama-2 style template while MODEL names a Gemma
    checkpoint; confirm it matches the model's training format (the
    tokenizer's ``apply_chat_template`` may be the safer choice).
    """
    parts = [f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
    # Assumes Gradio's "tuples" history format: one (user, assistant) pair per turn.
    for prompt, answer in history:
        # A turn without a reply may carry None; render it as empty text
        # rather than the literal string "None".
        reply = "" if answer is None else answer
        parts.append(f"[INST]{prompt}[/INST]{reply}")
    parts.append(f"[INST]{message}[/INST]")
    return "".join(parts)


@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.5,
    max_new_tokens: int = 32768,
    top_p: float = 1.0,
    top_k: int = 50,
):
    """Stream a model reply to *message*, yielding the accumulated text so far.

    Args:
        message: The new user message.
        history: Earlier turns as (user, assistant) pairs, as supplied by
            ``gr.ChatInterface``.
        system_prompt: Text placed in the ``<<SYS>>`` section of the prompt.
        temperature: Sampling temperature. Values <= 0 switch to greedy
            decoding, because transformers rejects a non-positive temperature
            when ``do_sample=True``.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold (unused when decoding greedily).
        top_k: Top-k sampling cutoff (unused when decoding greedily).

    Yields:
        The full reply generated so far, growing with each streamed chunk.
    """
    print(f'message: {message}')
    print(f'history: {history}')

    full_prompt = _build_prompt(message, history, system_prompt)
    inputs = tokenizer(full_prompt, truncation=False, return_tensors="pt").to(device)

    # skip_prompt drops the echoed prompt; with a timeout the streamer raises
    # if no new text arrives within 60 s (e.g. generation died in the thread).
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    # The UI temperature slider reaches 0; sampling with temperature 0 raises
    # inside generate(), so treat it as a request for greedy decoding.
    do_sample = temperature > 0
    generate_kwargs = dict(
        inputs=inputs.input_ids,
        # Pass the tokenizer's mask explicitly instead of letting generate() infer it.
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        num_beams=1,
        streamer=streamer,
    )
    if do_sample:
        generate_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

    # generate() blocks until finished, so run it on a worker thread and
    # consume the streamer here to yield partial output as it arrives.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

# Transcript widget, built outside the layout and handed to ChatInterface.
chatbot = gr.Chatbot(
    height=600,
    placeholder=PLACEHOLDER,
)

with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML(TITLE)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_classes="duplicate-button",
    )

    # Collapsed panel that ChatInterface fills with the extra controls below.
    settings_panel = gr.Accordion(label="⚙️ Parameters", open=False, render=False)

    # Forwarded to stream_chat after (message, history) in exactly this order:
    # system_prompt, temperature, max_new_tokens, top_p, top_k.
    # render=False defers placement to ChatInterface.
    system_box = gr.Textbox(
        value="You are a helpful assistant capable of generating long-form content.",
        label="System Prompt",
        render=False,
    )
    temperature_slider = gr.Slider(
        minimum=0, maximum=1, step=0.1, value=0.5,
        label="Temperature", render=False,
    )
    max_tokens_slider = gr.Slider(
        minimum=1024, maximum=32768, step=1024, value=32768,
        label="Max new tokens", render=False,
    )
    top_p_slider = gr.Slider(
        minimum=0.0, maximum=1.0, step=0.1, value=1.0,
        label="Top p", render=False,
    )
    top_k_slider = gr.Slider(
        minimum=1, maximum=100, step=1, value=50,
        label="Top k", render=False,
    )

    sample_requests = [
        ["Write a 5000-word comprehensive guide on machine learning for beginners."],
        ["Create a detailed 3000-word business plan for a sustainable energy startup."],
        ["Compose a 2000-word short story set in a futuristic underwater city."],
        ["Develop a 4000-word research proposal on the potential effects of climate change on global food security."],
    ]

    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=settings_panel,
        additional_inputs=[
            system_box,
            temperature_slider,
            max_tokens_slider,
            top_p_slider,
            top_k_slider,
        ],
        examples=sample_requests,
        cache_examples=False,
    )

# Serve the UI only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()