# BOREA_DEMO / app.py
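"""Gradio chat demo for the Borea / EZO instruction-tuned models, intended to run on Hugging Face Spaces."""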
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import gradio as gr
from threading import Thread
# Available models (display name -> Hugging Face repo id)
MODELS = {
    "Borea-Phi-3.5-mini-Jp": "AXCXEPT/Borea-Phi-3.5-mini-Instruct-Jp",
    "EZO-Common-9B": "HODACHI/EZO-Common-9B-gemma-2-it",
    "Phi-3.5-mini": "microsoft/Phi-3.5-mini-instruct",
}
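# Hugging Face access token read from the environment (presumably intended for gated or private checkpoints)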
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Title and chat placeholder shown in the UI (kept in Japanese for this demo)
TITLE = "<h1><center>Borea/EZO デモアプリ</center></h1>"
PLACEHOLDER = """
<center>
<p>こんにちは、私はAIアシスタントです。何でも質問してください。</p>
</center>
"""
CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
"""
device = "cuda" if torch.cuda.is_available() else "cpu"  # informational only; device_map="auto" below handles placement
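# 4-bit NF4 quantization via bitsandbytes so the larger checkpoints fit in GPU memory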
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = None
tokenizer = None
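# Loaded lazily: the first request (or a model switch in the dropdown) triggers load_model().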
def load_model(model_name):
    """Load the selected model and tokenizer, replacing any previously loaded pair."""
    global model, tokenizer
    model_path = MODELS[model_name]
    # HF_TOKEN is passed through so gated/private checkpoints can be downloaded (no-op when unset).
    tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        quantization_config=quantization_config,
        token=HF_TOKEN,
    )
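# @spaces.GPU() allocates GPU time for this function when the Space runs on ZeroGPU hardware.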
@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.8,
    max_new_tokens: int = 1024,
    top_p: float = 1.0,
    top_k: int = 20,
    repetition_penalty: float = 1.2,
    model_name: str = "Phi-3.5-mini",
):
    global model, tokenizer
    # Reload only if nothing is loaded yet or the dropdown selection changed.
    if model is None or tokenizer is None or model.name_or_path != MODELS[model_name]:
        load_model(model_name)
    print(f'message: {message}')
    print(f'history: {history}')
    # Rebuild the conversation in the chat-template format expected by the tokenizer.
    conversation = [
        {"role": "system", "content": system_prompt}
    ]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
    # Stream generated tokens back to the UI as they arrive.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    # Run generation in a background thread; generate() already disables gradient tracking internally,
    # so an explicit torch.no_grad() context is not needed here.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
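
# --- Gradio UI: a ChatInterface with generation-parameter sliders and a model dropdown ---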
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
with gr.Blocks(css=CSS, theme='ParityError/Interstellar') as demo:
    gr.HTML(TITLE)
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs=[
            gr.Textbox(
                value="あなたは親切なアシスタントです。",
                label="システムプロンプト",
            ),
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="温度 (Temperature)",
            ),
            gr.Slider(
                minimum=128,
                maximum=8192,
                step=1,
                value=1024,
                label="最大新規トークン数",
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
            ),
            gr.Slider(
                minimum=1.0,
                maximum=2.0,
                step=0.1,
                value=1.2,
                label="繰り返しペナルティ",
            ),
            gr.Dropdown(
                choices=list(MODELS.keys()),
                value="Borea-Phi-3.5-mini-Jp",
                label="モデル選択",
            ),
        ],
        examples=[
            ["語彙の勉強を手伝ってください。空欄を埋めるための文章を書いてください。私は正しい選択肢を選びます。"],
            ["子供のアート作品でできる5つの創造的なことを教えてください。捨てたくはないのですが、散らかってしまいます。"],
            ["ローマ帝国についてのランダムな面白い事実を教えてください。"],
            ["ウェブサイトの固定ヘッダーのCSSとJavaScriptのコードスニペットを見せてください。"],
        ],
        cache_examples=False,
    )
if __name__ == "__main__":
    demo.launch()
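
# Assumed runtime dependencies for this Space (typically pinned in requirements.txt):
#   gradio, spaces, torch, transformers, accelerate, bitsandbytes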