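# Gradio chat demo for UW-SBEL-ChronoPhi-4b-it, a Phi-3-based model fine-tuned
# as a PyChrono digital-twin assistant. Responses are streamed token-by-token
# from a background generation thread.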
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, AutoConfig
import gradio as gr
from threading import Thread
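# Fine-tuned model repository on the Hugging Face Hub, plus static UI text.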
MODEL = "jwang2373/UW-SBEL-ChronoPhi-4b-it"
TITLE = "<h1><center>UW-SBEL-ChronoPhi-4b</center></h1>"
PLACEHOLDER = """
<center>
<p>Hi! I'm a PyChrono Digital Twin expert. How can I assist you today?</p>
</center>
"""
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the fine-tuned model configuration; the base Phi-3 config is printed
# alongside it as a quick architecture sanity check.
config = AutoConfig.from_pretrained(MODEL)
base_config = AutoConfig.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
print(base_config)
print(config)
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
model = model.eval()
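# On Hugging Face ZeroGPU Spaces, @spaces.GPU attaches a GPU for the duration
# of each call to the decorated function.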
@spaces.GPU()
def stream_chat(
message: str,
history: list,
system_prompt: str,
temperature: float = 0.1,
max_new_tokens: int = 32768,
top_p: float = 1.0,
top_k: int = 50,
):
print(f'message: {message}')
print(f'history: {history}')
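    # Build a Llama-2-style prompt: the system message sits in <<SYS>> tags and
    # each conversation turn is wrapped in [INST] ... [/INST] blocks.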
full_prompt = f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
for prompt, answer in history:
full_prompt += f"[INST]{prompt}[/INST]{answer}"
full_prompt += f"[INST]{message}[/INST]"
    inputs = tokenizer(full_prompt, truncation=False, return_tensors="pt").to(device)
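    # Stream decoded tokens as they are produced; skip_prompt hides the echoed input.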
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,  # avoid the missing-mask warning
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,  # fall back to greedy decoding at temperature 0
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        streamer=streamer,
    )
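    # Run generation on a worker thread so this generator can yield partial text.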
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
buffer = ""
for new_text in streamer:
buffer += new_text
yield buffer
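# Chat UI: a Chatbot component fed by stream_chat through gr.ChatInterface.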
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
with gr.Blocks(css=CSS, theme="soft") as demo:
gr.HTML(TITLE)
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
gr.ChatInterface(
fn=stream_chat,
chatbot=chatbot,
fill_height=True,
additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
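        # These must be listed in the same order as stream_chat's parameters:
        # system_prompt, temperature, max_new_tokens, top_p, top_k.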
additional_inputs=[
gr.Textbox(
value="You are a PyChrono expert.",
label="System Prompt",
render=False,
),
gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.5,
label="Temperature",
render=False,
),
gr.Slider(
minimum=1024,
maximum=4096,
step=1024,
value=4096,
label="Max new tokens",
render=False,
),
gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=1.0,
label="Top p",
render=False,
),
gr.Slider(
minimum=1,
maximum=100,
step=1,
value=100,
label="Top k",
render=False,
),
],
examples=[
["Run a PyChrono simulation of a sedan driving on a flat surface with a detailed vehicle dynamics model."],
["Run a real-time simulation of an HMMWV vehicle on a bumpy and textured road."],
["Set up a Curiosity rover driving simulation on flat, rigid ground in PyChrono."],
["Simulate a FEDA vehicle driving on rigid terrain in PyChrono."],
],
cache_examples=False,
)
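# A minimal local smoke test (a sketch; assumes a CUDA-capable machine, where
# the @spaces.GPU decorator is a no-op outside of Spaces):
#     for partial in stream_chat("Simulate a pendulum in PyChrono.", [], "You are a PyChrono expert."):
#         print(partial)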
if __name__ == "__main__":
demo.launch()