taufiqdp's picture
Update app.py
e9f724a verified
raw
history blame contribute delete
No virus
3.44 kB
import os
import subprocess
import sys
from threading import Thread

import gradio as gr
import spaces
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Authenticate with the Hugging Face Hub only when a token is configured:
# `login(None)` falls back to an interactive prompt, which cannot be answered
# in a headless deployment.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Best-effort install of flash-attn, skipping the CUDA kernel build.
# The child environment must *extend* os.environ: passing a dict containing
# only FLASH_ATTENTION_SKIP_CUDA_BUILD would drop PATH, HOME, virtualenv
# variables, etc. Invoking pip through sys.executable guarantees the package
# is installed into the interpreter that is running this app.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,  # unchanged best-effort semantics; a missing flash-attn surfaces at model load
)

model_id = "microsoft/Phi-3-mini-128k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# bfloat16 weights, placed automatically across available devices, using the
# flash-attention-2 kernels installed above.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
)
def _build_conversation(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
) -> list[dict[str, str]]:
    """Turn the Gradio (user, assistant) history into chat-template messages.

    The system prompt (if non-empty) comes first, then every past exchange as
    a user/assistant pair, and finally the new user ``message``.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})
    return conversation


@spaces.GPU()
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
):
    """Stream the model's reply to ``message`` for a Gradio chat UI.

    Args:
        message: The new user turn.
        chat_history: Previous (user, assistant) exchanges.
        system_prompt: Optional system message; omitted when empty.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept when sampling.
        repetition_penalty: Penalty applied to already-generated tokens.

    Yields:
        The reply text accumulated so far, after each decoded chunk.

    Raises:
        Whatever ``model.generate`` raised in the worker thread, re-raised
        once the stream has been closed.
    """
    conversation = _build_conversation(message, chat_history, system_prompt)

    # Keep the BatchEncoding and index it by key; unpacking `.values()`
    # positionally would silently depend on the order of its keys.
    inputs = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        # Defensive casts: UI sliders may deliver floats (TODO confirm for the
        # Gradio version in use), and transformers requires an int top_k.
        max_new_tokens=int(max_new_tokens),
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
        top_p=top_p,
    )

    errors = []

    def _run_generation():
        # If generate() raises, the streamer never receives its end-of-stream
        # signal and the loop below would block forever; close it explicitly
        # and hand the exception back to the consuming generator.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    worker = Thread(target=_run_generation)
    worker.start()

    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        yield "".join(outputs)

    worker.join()
    if errors:
        raise errors[0]
# Chat UI. The `additional_inputs` are passed to `generate` after
# (message, chat_history), so their order must match its parameters:
# system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty.
demo = gr.ChatInterface(
    fn=generate,
    # The emoji was previously mojibake ("πŸš€"): the UTF-8 bytes of 🚀 decoded
    # with a legacy codepage.
    title="🚀 Phi-3 mini 128k instruct",
    description="",
    additional_inputs=[
        gr.Textbox(
            label="System prompt",
            lines=5,
            value="You are a helpful digital assistant.",
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=2048,
            step=1,
            value=1024,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Can you provide ways to eat combinations of bananas and dragonfruits?"],
        ["Write a story about a dragon fruit that flies into outer space!"],
        ["I am going to Bali, what should I see"],
    ],
)
demo.queue().launch()