# app.py: Gradio streaming chat demo for HODACHI/EZO-Common-9B-gemma-2-it
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread
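
# Model repository on the Hugging Face Hub and the dtype used for inference.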
MODEL_ID = "HODACHI/EZO-Common-9B-gemma-2-it"
DTYPE = torch.bfloat16
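
# Load the tokenizer and model once at startup; device_map="cuda" places the
# full model on the GPU (this requires the accelerate package).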
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cuda",
    torch_dtype=DTYPE,
)
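

# Streaming chat handler: yields the partial response so Gradio's
# ChatInterface can update the last message incrementally.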
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
):
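    # Rebuild the conversation in the role/content format expected by the
    # model's chat template.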
    chat = []
    for user, assistant in history:
        chat.append({"role": "user", "content": user})
        chat.append({"role": "assistant", "content": assistant})
    chat.append({"role": "user", "content": message})
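    # The chat template already inserts the special tokens (Gemma's starts
    # with <bos>), so encode with add_special_tokens=False to avoid doubling them.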
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
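    # skip_prompt=True keeps the echoed prompt out of the stream;
    # skip_special_tokens strips control tokens from the output text.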
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        streamer=streamer,
    )
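    # Run generation on a background thread so this function can consume the
    # streamer while tokens are still being produced.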
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
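    # Accumulate the streamed chunks and yield the running total; each yield
    # replaces the message shown in the UI.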
    response = ""
    for new_text in streamer:
        response += new_text
        yield response
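

# Chat UI with sliders exposing the main sampling parameters.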
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
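
# Start the Gradio server when the script is run directly (which is how a
# Hugging Face Space typically executes app.py).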
if __name__ == "__main__":
    demo.launch()