import gradio as gr
import anyio
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# The hosted InferenceClient variant is kept below for reference:
# from huggingface_hub import InferenceClient
# """
# For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
# """
# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


class ChatClient:
    def __init__(self, model_path):
        """
        Initialize the client: load the model and tokenizer onto the GPU
        if one is available, otherwise fall back to CPU.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_path).to(self.device)
        self.model.eval()  # inference mode: disables dropout etc.

    async def chat_completion(self, messages, max_tokens, stream=False, temperature=1.0, top_p=1.0):
        """
        Generate a chat reply for the given prompt.
        """
        # `messages` is already a single prompt string here, not a list of
        # role/content dicts.
        input_text = messages
        print(input_text)
        # Tokenize the input text and move it to the model's device
        inputs = self.tokenizer(input_text, return_tensors='pt').to(self.device)

        # Sampling parameters for generation
        gen_kwargs = {
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True
        }

        # Generate the full output sequence (a blocking call)
        with torch.no_grad():
            output_sequences = self.model.generate(**inputs, **gen_kwargs)

        # Decode and yield each generated sequence. Note this is not true
        # token-level streaming: the whole text is produced before any yield.
        for sequence in output_sequences:
            result_text = self.tokenizer.decode(sequence, skip_special_tokens=True)
            await anyio.sleep(0)  # yield control to the event loop
            yield result_text
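
    # A sketch of true token-level streaming using transformers'
    # TextIteratorStreamer, as an alternative to the one-shot decode above.
    # generate() blocks, so it runs in a background thread while this
    # coroutine drains the streamer. Illustrative only; not exercised by the
    # Gradio app below, and untested against this specific model.
    async def chat_completion_stream(self, prompt, max_tokens, temperature=1.0, top_p=1.0):
        from threading import Thread
        from transformers import TextIteratorStreamer

        inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        gen_kwargs = dict(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            streamer=streamer,
        )
        Thread(target=self.model.generate, kwargs=gen_kwargs).start()

        partial = ""
        for new_text in streamer:
            partial += new_text
            await anyio.sleep(0)  # yield control to the event loop
            yield partial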

# Create the client instance, pointing at the local model directory
model_path = 'model/v3/'
client = ChatClient(model_path)
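
# A minimal smoke test for the client outside Gradio (hypothetical prompt
# text; uncomment to run). anyio.run drives the async generator to completion.
# async def _smoke_test():
#     async for text in client.chat_completion("こんにちは", max_tokens=32):
#         print(text)
# anyio.run(_smoke_test)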






async def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # The chat-style message list below is kept for reference; the current
    # model takes a plain prompt string instead.
    # messages = [{"role": "system", "content": system_message}]
    #
    # for val in history:
    #     if val[0]:
    #         messages.append({"role": "user", "content": val[0]})
    #     if val[1]:
    #         messages.append({"role": "assistant", "content": val[1]})
    #
    # messages.append({"role": "user", "content": message})

    # Naive prompt construction: concatenate the system prompt and the
    # latest user message into a single string.
    messages = system_message + message

    async for reply in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # Token-by-token accumulation, kept for reference from the
        # InferenceClient version:
        # token = reply.choices[0].delta.content
        # response += token
        # yield response

        yield reply

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="Yahoo!ショッピングについての質問を回答してください。", label="System message"),  # default: "Please answer questions about Yahoo! Shopping."
        gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
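    # On older Gradio 3.x versions, streaming generator outputs required
    # enabling the request queue explicitly (demo.queue().launch()); recent
    # versions queue by default. Left as a note since this Space's pinned
    # Gradio version is not shown here.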
    demo.launch()