qgyd2021's picture
[update]add code
3177298
raw
history blame
No virus
6.48 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
import gradio as gr
from transformers import AutoModel, AutoTokenizer
from transformers.models.auto import AutoModelForCausalLM, AutoTokenizer
# from transformers.utils.quantization_config import BitsAndBytesConfig
import torch
from project_settings import project_path
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--train_subset", default="train.jsonl", type=str)
parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
parser.add_argument(
"--pretrained_model_name_or_path",
# default="YeungNLP/firefly-chatglm2-6b",
default=(project_path / "trained_models/firefly_chatglm2_6b_intent").as_posix(),
type=str
)
parser.add_argument("--output_file", default="result.xlsx", type=str)
parser.add_argument("--max_new_tokens", default=512, type=int)
parser.add_argument("--top_p", default=0.9, type=float)
parser.add_argument("--temperature", default=0.35, type=float)
parser.add_argument("--repetition_penalty", default=1.0, type=float)
parser.add_argument('--device', default="cuda" if torch.cuda.is_available() else "cpu", type=str)
args = parser.parse_args()
return args
description = """
## ChatGLM-6B
基于 [firefly-chatglm2-6b](https://huggingface.co/YeungNLP/firefly-chatglm2-6b) 模型, 在 [telemarketing_intent](https://huggingface.co/datasets/qgyd2021/telemarketing_intent/tree/main/data/prompt) 的 prompt 数据集上训练, 目的是实现 `电话营销` 场景的 1-shot 意图识别.
该分类任务有一百多个类别, 但标注数据总是只有 3 万, 并且有一半是 "无关领域", 实现思路是:
1. 首先采用传统算法做硬分类, 然后提取概率 top 10 的标签.
2. 将 top 10 的标签作为候选标签, 并为每个标签提供一个句子示例.
3. 要求 LLM 输出目标句子的类别.
Gradio 布署代码参考了: https://huggingface.co/spaces/aodianyun/ChatGLM-6B
"""
examples = [
"""我们在做电话营销场景的意图识别任务, 可选的意图如下:
否定(不是); 礼貌用语; 否定答复; 肯定(需要); 用户正忙; 否定(不需要); 无关领域; 否定(没有); 否定(不用了); 价格太高
如果你认为给定的句子不属于这些意图中的任务一个, 你可以回答: 不知道.
Tips:
1. 如果候选意图中有 "无关领域", 当你不知道时, 则它有可能属于无关领域.
Examples:
---------
ExampleSentence: 其实不是
ExampleIntent: 否定(不是)
ExampleSentence: 嗯!嘿嘿!早点休息,晚安咯
ExampleIntent: 礼貌用语
ExampleSentence: 没问诶
ExampleIntent: 否定答复
ExampleSentence: 不好意思都需要谢谢
ExampleIntent: 肯定(需要)
ExampleSentence: 对呀我在忙
ExampleIntent: 用户正忙
ExampleSentence: 。嗯也也不需要吧唉呀现在不需要那个啊嗯
ExampleIntent: 否定(不需要)
ExampleSentence: 我的处理器需要很少的电源。
ExampleIntent: 无关领域
ExampleSentence: 。呃我好像没有在太平洋买过保险,吧拜拜
ExampleIntent: 否定(没有)
ExampleSentence: 嗯不用谢谢
ExampleIntent: 否定(不用了)
ExampleSentence: 费用贵。
ExampleIntent: 价格太高
---------
Sentence: 。嗯各位不需要,啊谢谢
Intent:""",
"""我们在做电话营销场景的意图识别任务, 可选的意图如下:
语音信箱; 无关领域; 查物品信息; 污言秽语; 疑问(时间); 疑问(数值); 答时间; 查收费方式; 价格太高; 答数值
如果你认为给定的句子不属于这些意图中的任务一个, 你可以回答: 不知道.
Tips:
1. 如果候选意图中有 "无关领域", 当你不知道时, 则它有可能属于无关领域.
Examples:
---------
ExampleSentence: 我们留言。
ExampleIntent: 语音信箱
ExampleSentence: 很刚刚打
ExampleIntent: 无关领域
ExampleSentence: 什么东西我听
ExampleIntent: 查物品信息
ExampleSentence: 知道!AV女优!日本人的骄傲!
ExampleIntent: 污言秽语
ExampleSentence: 最后期限
ExampleIntent: 疑问(时间)
ExampleSentence: 一共借了多少钱
ExampleIntent: 疑问(数值)
ExampleSentence: 22号
ExampleIntent: 答时间
ExampleSentence: 运费
ExampleIntent: 查收费方式
ExampleSentence: 利息高
ExampleIntent: 价格太高
ExampleSentence: 20。
ExampleIntent: 答数值
---------
Sentence: 。对啊什么东西啊我6月份出来的
Intent:"""
]
def main():
args = get_args()
use_cpu = os.environ.get("USE_CPU", "all")
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, trust_remote_code=True)
# QWenTokenizer比较特殊, pad_token_id, bos_token_id, eos_token_id 均 为None. eod_id对应的token为<|endoftext|>
if tokenizer.__class__.__name__ == "QWenTokenizer":
tokenizer.pad_token_id = tokenizer.eod_id
tokenizer.bos_token_id = tokenizer.eod_id
tokenizer.eos_token_id = tokenizer.eod_id
if not use_cpu:
model = AutoModel.from_pretrained(
args.pretrained_model_name_or_path,
trust_remote_code=True
).half().cuda()
else:
model = AutoModelForCausalLM.from_pretrained(
args.pretrained_model_name_or_path,
trust_remote_code=True,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16,
device_map="auto",
offload_folder="./offload",
offload_state_dict=True,
# load_in_4bit=True,
)
model = model.eval()
def fn(inputs, history=None):
if history is None:
history = list()
with torch.no_grad():
response, history = model.chat(tokenizer, inputs, history)
return history, history
with gr.Blocks() as blocks:
gr.Markdown(value=description)
state = gr.State([])
chatbot = gr.Chatbot([], elem_id="chatbot").style(height=400)
with gr.Row():
with gr.Column(scale=4):
text = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
with gr.Column(scale=1):
button = gr.Button("Generate")
gr.Examples(examples, text)
text.submit(fn, [text, state], [chatbot, state])
button.click(fn, [text, state], [chatbot, state])
blocks.queue().launch()
return
if __name__ == '__main__':
main()