CMLL committed on
Commit
4ed0b9b
1 Parent(s): fdf8c66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -61
app.py CHANGED
@@ -5,130 +5,125 @@ from typing import Iterator
5
  import gradio as gr
6
  import spaces
7
  import torch
8
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
 
10
  MAX_MAX_NEW_TOKENS = 2048
11
  DEFAULT_MAX_NEW_TOKENS = 1024
12
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
13
 
14
  DESCRIPTION = """\
15
- 仲景GPT-V2-1.8B
16
- 博极医源,精勤不倦。Unlocking the Wisdom of Traditional Chinese Medicine with AI.
17
  """
18
 
19
  LICENSE = """
20
  <p/>
21
  ---
22
- This demo is governed by the original licenses of [ZhongJing-2-1_8b](https://huggingface.co/CMLM/ZhongJing-2-1_8b) and [Qwen1.5-1.8B-Chat](https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat).
 
23
  """
24
 
25
- peft_model_id = "CMLM/ZhongJing-2-1_8b"
26
- base_model_id = "Qwen/Qwen1.5-1.8B-Chat"
27
- model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map="auto")
28
- model.load_adapter(peft_model_id)
29
- tokenizer = AutoTokenizer.from_pretrained(
30
- "CMLM/ZhongJing-2-1_8b",
31
- padding_side="right",
32
- trust_remote_code=True,
33
- pad_token=''
34
- )
35
 
36
- @spaces.gpu()
37
  def generate(
38
  message: str,
 
 
39
  max_new_tokens: int = 1024,
40
  temperature: float = 0.6,
41
- top_p: float = 0.9,
42
  top_k: int = 50,
43
  repetition_penalty: float = 1.2,
44
  ) -> Iterator[str]:
 
 
 
 
 
 
45
 
46
- prompt = f"Question: {message}"
47
- messages = [
48
- {"role": "system", "content": "You are a helpful TCM medical assistant named 仲景中医大语言模型, created by 医哲未来 of Fudan University."},
49
- {"role": "user", "content": prompt}
50
- ]
 
51
 
52
- text = tokenizer.apply_chat_template(
53
- messages,
54
- tokenize=False,
55
- add_generation_prompt=True
56
- )
57
- input_ids = tokenizer([text], return_tensors="pt").input_ids
 
 
58
 
59
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
60
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
61
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
62
 
63
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
64
- generate_kwargs = dict(
65
- input_ids=input_ids,
66
- streamer=streamer,
67
- max_new_tokens=max_new_tokens,
68
- do_sample=True,
69
- top_p=top_p,
70
- top_k=top_k,
71
- temperature=temperature,
72
- num_beams=1,
73
- repetition_penalty=repetition_penalty,
74
- )
75
- t = Thread(target=model.generate, kwargs=generate_kwargs)
76
  t.start()
77
 
78
  outputs = []
79
- for text in streamer:
80
- outputs.append(text)
81
  yield "".join(outputs)
82
 
83
- chat_interface = gr.Interface(
84
  fn=generate,
85
- inputs=[
86
- gr.components.Textbox(label="Enter your question"),
87
- gr.components.Slider(
88
  label="Max new tokens",
89
- minimum=1,
90
  maximum=MAX_MAX_NEW_TOKENS,
91
  step=1,
92
  value=DEFAULT_MAX_NEW_TOKENS,
93
  ),
94
- gr.components.Slider(
95
  label="Temperature",
96
  minimum=0.1,
97
  maximum=4.0,
98
- step=0.1,
99
  value=0.6,
100
  ),
101
- gr.components.Slider(
102
  label="Top-p (nucleus sampling)",
103
  minimum=0.05,
104
  maximum=1.0,
105
  step=0.05,
106
  value=0.9,
107
  ),
108
- gr.components.Slider(
109
  label="Top-k",
110
  minimum=1,
111
  maximum=1000,
112
  step=1,
113
  value=50,
114
  ),
115
- gr.components.Slider(
116
- label="Repetition penalty",
117
  minimum=1.0,
118
  maximum=2.0,
119
  step=0.05,
120
  value=1.2,
121
  ),
122
  ],
123
- outputs="text",
124
- title="仲景GPT-V2-1.8B",
125
- description=DESCRIPTION,
126
- allow_flagging=False,
127
  examples=[
128
  ["请问气虚体质有哪些症状表现?"],
129
  ["简单介绍一下中医的五行学说。"],
130
  ["桑螵蛸是什么?有什么功效作用?"],
131
- ],
 
132
  )
133
 
134
  with gr.Blocks(css="style.css") as demo:
 
5
  import gradio as gr
6
  import spaces
7
  import torch
8
+ from transformers import pipeline, AutoTokenizer
9
 
10
  MAX_MAX_NEW_TOKENS = 2048
11
  DEFAULT_MAX_NEW_TOKENS = 1024
12
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
13
 
14
  DESCRIPTION = """\
15
+ # ZhongJing 2 1.8B Merge
16
+ This Space demonstrates model [CMLL/ZhongJing-2-1_8b-merge](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge) for text generation. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
17
  """
18
 
19
  LICENSE = """
20
  <p/>
21
  ---
22
+ As a derivative work of [CMLL/ZhongJing-2-1_8b-merge](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge),
23
+ this demo is governed by the original [license](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge/LICENSE).
24
  """
25
 
26
+ if not torch.cuda.is_available():
27
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
28
+
29
+ if torch.cuda.is_available():
30
+ model_id = "CMLL/ZhongJing-2-1_8b-merge"
31
+ pipe = pipeline("text-generation", model=model_id)
32
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
33
+ tokenizer.use_default_system_prompt = False
 
 
34
 
35
+ @spaces.GPU
36
  def generate(
37
  message: str,
38
+ chat_history: list[tuple[str, str]],
39
+ system_prompt: str,
40
  max_new_tokens: int = 1024,
41
  temperature: float = 0.6,
42
+ top_p: float = 0.9,
43
  top_k: int = 50,
44
  repetition_penalty: float = 1.2,
45
  ) -> Iterator[str]:
46
+ conversation = []
47
+ if system_prompt:
48
+ conversation.append({"role": "system", "content": system_prompt})
49
+ for user, assistant in chat_history:
50
+ conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
51
+ conversation.append({"role": "user", "content": message})
52
 
53
+ input_text = "\n".join([f"{entry['role']}: {entry['content']}" for entry in conversation])
54
+ inputs = tokenizer(input_text, return_tensors="pt")
55
+ if inputs.input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
56
+ inputs = {k: v[:, -MAX_INPUT_TOKEN_LENGTH:] for k, v in inputs.items()}
57
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
58
+ inputs = inputs.to(pipe.device)
59
 
60
+ generate_kwargs = {
61
+ "max_new_tokens": max_new_tokens,
62
+ "do_sample": True,
63
+ "top_p": top_p,
64
+ "top_k": top_k,
65
+ "temperature": temperature,
66
+ "repetition_penalty": repetition_penalty,
67
+ }
68
 
69
+ def run_generation():
70
+ return pipe(inputs.input_ids, **generate_kwargs)
 
71
 
72
+ t = Thread(target=run_generation)
 
 
 
 
 
 
 
 
 
 
 
 
73
  t.start()
74
 
75
  outputs = []
76
+ for text in run_generation():
77
+ outputs.append(text['generated_text'])
78
  yield "".join(outputs)
79
 
80
+ chat_interface = gr.ChatInterface(
81
  fn=generate,
82
+ additional_inputs=[
83
+ gr.Textbox(label="System prompt", lines=6),
84
+ gr.Slider(
85
  label="Max new tokens",
86
+ minimum=1,
87
  maximum=MAX_MAX_NEW_TOKENS,
88
  step=1,
89
  value=DEFAULT_MAX_NEW_TOKENS,
90
  ),
91
+ gr.Slider(
92
  label="Temperature",
93
  minimum=0.1,
94
  maximum=4.0,
95
+ step=0.1,
96
  value=0.6,
97
  ),
98
+ gr.Slider(
99
  label="Top-p (nucleus sampling)",
100
  minimum=0.05,
101
  maximum=1.0,
102
  step=0.05,
103
  value=0.9,
104
  ),
105
+ gr.Slider(
106
  label="Top-k",
107
  minimum=1,
108
  maximum=1000,
109
  step=1,
110
  value=50,
111
  ),
112
+ gr.Slider(
113
+ label="Repetition penalty",
114
  minimum=1.0,
115
  maximum=2.0,
116
  step=0.05,
117
  value=1.2,
118
  ),
119
  ],
120
+ stop_btn=None,
 
 
 
121
  examples=[
122
  ["请问气虚体质有哪些症状表现?"],
123
  ["简单介绍一下中医的五行学说。"],
124
  ["桑螵蛸是什么?有什么功效作用?"],
125
+ ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
126
+ ],
127
  )
128
 
129
  with gr.Blocks(css="style.css") as demo: