CMLL committed
Commit 92b045a
1 Parent(s): ff9b690

Update app.py

Files changed (1)
  1. app.py +8 -35
app.py CHANGED
@@ -3,10 +3,12 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import gradio as gr
 
 # Initialize
 peft_model_id = "CMLM/ZhongJing-2-1_8b"
 base_model_id = "Qwen/Qwen1.5-1.8B-Chat"
-model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map="auto")
+
+device = "cuda"
+model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map={"": device}).to(device)
 model.load_adapter(peft_model_id)
 tokenizer = AutoTokenizer.from_pretrained(
     "CMLM/ZhongJing-2-1_8b",
@@ -15,30 +17,9 @@ tokenizer = AutoTokenizer.from_pretrained(
     pad_token=''
 )
 
-# Single turn chat
-@spaces.GPU
-def single_turn_chat(question):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model.to(device)
-
-    prompt = f"Question: {question}"
-    messages = [
-        {"role": "system", "content": "You are a helpful TCM medical assistant named 仲景中医大语言模型, created by 医哲未来 of Fudan University."},
-        {"role": "user", "content": prompt}
-    ]
-    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    model_inputs = tokenizer([input_text], return_tensors="pt").to(device)
-    generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512)
-    generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
-    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    return response
-
 # Multi-turn chat
 @spaces.GPU
 def multi_turn_chat(question, chat_history=None):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model.to(device)
-
     if not isinstance(question, str):
         raise ValueError("The question must be a string.")
 
@@ -76,16 +57,7 @@ def multi_turn_chat(question, chat_history=None):
 def clear_history():
     return [], []
 
-# Single turn interface
-single_turn_interface = gr.Interface(
-    fn=single_turn_chat,
-    inputs=["text"],
-    outputs="text",
-    title="仲景GPT-V2-1.8B 单轮对话",
-    description="博极医源,精勤不倦。Unlocking the Wisdom of Traditional Chinese Medicine with AI."
-)
-
 # Multi-turn interface
 with gr.Blocks() as multi_turn_interface:
     chatbot = gr.Chatbot(label="仲景GPT-V2-1.8B 多轮对话")
     state = gr.State([])
@@ -97,7 +69,8 @@ with gr.Blocks() as multi_turn_interface:
 
     submit_button.click(multi_turn_chat, [user_input, state], [chatbot, state])
     user_input.submit(multi_turn_chat, [user_input, state], [chatbot, state])
+    clear_button = gr.Button("清除对话历史")
+    clear_button.click(clear_history, [], [chatbot, state])
 
-single_turn_interface.launch()
 multi_turn_interface.launch()
 
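The first hunk carries the substance of the change: instead of each handler selecting a device per request (torch.device("cuda" if torch.cuda.is_available() else "cpu") followed by model.to(device)), placement now happens once at import time. device_map={"": device} pins the root module, and therefore every weight, to the named device, which also makes the trailing .to(device) redundant. A minimal sketch of the resulting load path, with a CPU fallback added here as an assumption for local runs (the commit hardcodes "cuda" for the GPU-backed Space):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

peft_model_id = "CMLM/ZhongJing-2-1_8b"
base_model_id = "Qwen/Qwen1.5-1.8B-Chat"

# The "" key maps the root module, so the whole model lands on one device
# (device_map requires the accelerate package).
device = "cuda" if torch.cuda.is_available() else "cpu"  # the commit uses "cuda" only
model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map={"": device})
model.load_adapter(peft_model_id)  # PEFT adapter over the Qwen base; needs the peft package
tokenizer = AutoTokenizer.from_pretrained(peft_model_id, pad_token='')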
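The remaining hunks delete the single-turn path entirely, and the last one adds a 清除对话历史 ("clear chat history") button wired to clear_history. The body of multi_turn_chat is elided by the diff (it sits between hunks), so the sketch below is an illustrative reconstruction, not the commit's code: it reuses the generate pattern visible in the deleted single_turn_chat and matches the Blocks wiring, which feeds [user_input, state] in and expects [chatbot, state] back.

# Hypothetical reconstruction -- the commit does not show this function body.
def multi_turn_chat(question, chat_history=None):
    if not isinstance(question, str):
        raise ValueError("The question must be a string.")
    chat_history = chat_history or []

    # Replay prior turns through Qwen's chat template, then append the new question.
    messages = [{"role": "system", "content": "You are a helpful TCM medical assistant named 仲景中医大语言模型, created by 医哲未来 of Fudan University."}]
    for user_msg, bot_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": f"Question: {question}"})

    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([input_text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512)
    # Keep only the newly generated tokens, dropping the echoed prompt.
    generated_ids = [out[len(inp):] for inp, out in zip(model_inputs.input_ids, generated_ids)]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    chat_history.append((question, response))
    return chat_history, chat_history  # one copy for the Chatbot, one for the State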