#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import TextIteratorStreamer

from project_settings import project_path

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_subset", default="train.jsonl", type=str)
    parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
    parser.add_argument(
        "--pretrained_model_name_or_path",
        default=(project_path / "trained_models/qwen_7b_chinese_modern_poetry").as_posix(),
        type=str
    )
    parser.add_argument("--output_file", default="result.xlsx", type=str)
    parser.add_argument("--max_new_tokens", default=512, type=int)
    parser.add_argument("--top_p", default=0.9, type=float)
    parser.add_argument("--temperature", default=0.35, type=float)
    parser.add_argument("--repetition_penalty", default=1.0, type=float)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", type=str)
    args = parser.parse_args()
    return args

description = """
## Qwen-7B
基于 [Qwen-7B](https://huggingface.co/qgyd2021/Qwen-7B) 模型, 在 [chinese_modern_poetry](https://huggingface.co/datasets/Iess/chinese_modern_poetry) 数据集上训练了 2 个 epoch.
可用于生成现代诗. 如下:
使用下列意象写一首现代诗:智慧,刀刃.
"""
examples = [
    "使用下列意象写一首现代诗:石头,森林",
    "使用下列意象写一首现代诗:花,纱布",
    "使用下列意象写一首现代诗:山壁,彩虹,诗句,山坡,泪",
    "使用下列意象写一首现代诗:味道,黄金,名字,银子,女人",
    "使用下列意象写一首现代诗:乳房,触感,车速,星星,路灯"
]

def main():
    args = get_args()

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, trust_remote_code=True)

    # QWenTokenizer is special: its pad_token_id, bos_token_id and eos_token_id are all None;
    # eod_id corresponds to the <|endoftext|> token, so reuse it for all three.
    if tokenizer.__class__.__name__ == "QWenTokenizer":
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id

    model = AutoModelForCausalLM.from_pretrained(
        args.pretrained_model_name_or_path,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        offload_folder="./offload",
        offload_state_dict=True,
        # load_in_4bit=True,
    )
    model = model.bfloat16().eval()
    # Non-streaming variant: generate the full reply, then return it in one shot.
    # (Kept for reference; the UI below is wired to fn_stream.)
    def fn_non_stream(text: str):
        input_ids = tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=False,
        ).input_ids.to(args.device)

        # Wrap the prompt as <bos> prompt <eos>, presumably matching the fine-tuning format.
        bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(args.device)
        eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(args.device)
        input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                max_new_tokens=args.max_new_tokens,
                do_sample=True,
                top_p=args.top_p,
                temperature=args.temperature,
                repetition_penalty=args.repetition_penalty,
                eos_token_id=tokenizer.eos_token_id
            )

        # Drop the prompt tokens; decode only the newly generated ones.
        outputs = outputs.tolist()[0][len(input_ids[0]):]
        response = tokenizer.decode(outputs)
        response = response.strip().replace(tokenizer.eos_token, "").strip()
        return [(text, response)]
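    # Streaming variant: TextIteratorStreamer turns generate() into an iterator of
    # decoded text chunks. generate() blocks, so it runs in a background thread
    # while this generator yields partial output to the Chatbot as it arrives.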
    def fn_stream(text: str):
        text = str(text).strip()
        input_ids = tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=False,
        ).input_ids.to(args.device)

        bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(args.device)
        eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(args.device)
        input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)

        streamer = TextIteratorStreamer(tokenizer=tokenizer)

        generation_kwargs = dict(
            inputs=input_ids,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            top_p=args.top_p,
            temperature=args.temperature,
            repetition_penalty=args.repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        output = ""
        for output_ in streamer:
            # The streamer echoes the prompt first; strip it and any eos marker.
            output_ = output_.replace(text, "")
            output_ = output_.replace(tokenizer.eos_token, "")
            output += output_

            result = [(text, output)]
            chatbot.value = result
            yield result
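    # Gradio UI: a Chatbot pane, a textbox with submit/clear buttons,
    # and clickable example prompts.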
    with gr.Blocks() as blocks:
        gr.Markdown(value=description)

        chatbot = gr.Chatbot([], elem_id="chatbot").style(height=400)
        with gr.Row():
            with gr.Column(scale=4):
                text_box = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
            with gr.Column(scale=1):
                submit_button = gr.Button("💬Submit")
            with gr.Column(scale=1):
                clear_button = gr.Button("🗑️Clear", variant="secondary")
        gr.Examples(examples, text_box)

        text_box.submit(fn_stream, [text_box], [chatbot])
        submit_button.click(fn_stream, [text_box], [chatbot])
        clear_button.click(
            # Reset the textbox to an empty string and the chat history to an empty list.
            fn=lambda: ("", []),
            outputs=[text_box, chatbot],
            queue=False,
            api_name=False,
        )

    blocks.queue().launch()
    return

if __name__ == '__main__':
    main()
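# A minimal usage sketch (assuming this script is saved as main.py and the
# fine-tuned weights exist at the default path; adjust to your environment):
#
#     python3 main.py \
#         --pretrained_model_name_or_path trained_models/qwen_7b_chinese_modern_poetry \
#         --max_new_tokens 512 \
#         --top_p 0.9 \
#         --temperature 0.35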