import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread

MODEL_ID = "HODACHI/EZO-Common-9B-gemma-2-it"
DTYPE = torch.bfloat16

# Load the tokenizer and model once at startup; the model is placed on the GPU in bfloat16.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cuda",
    torch_dtype=DTYPE,
)


def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
):
    # Rebuild the conversation as chat messages, then append the new user turn.
    chat = []
    for user, assistant in history:
        chat.append({"role": "user", "content": user})
        chat.append({"role": "assistant", "content": assistant})
    chat.append({"role": "user", "content": message})

    # Apply the model's chat template and tokenize the resulting prompt.
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(model.device)

    # Run generation in a background thread and stream tokens back to the UI as they arrive.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        streamer=streamer,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield the accumulated partial response so the chat window updates incrementally.
    response = ""
    for new_text in streamer:
        response += new_text
        yield response


# Chat UI with sliders for the main sampling parameters.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()