Commit 591fa87 by duzx16
Parent(s): 85ba2d2

Add system prompt

Files changed (2):
  1. modeling_chatglm.py +14 -17
  2. tokenization_chatglm.py +7 -2
modeling_chatglm.py CHANGED
@@ -1001,19 +1001,15 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         response = response.replace("[[训练时间]]", "2023年")
         return response
 
-    def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
-        inputs = tokenizer.build_chat_input(query, history=history)
-        inputs = inputs.to(self.device)
-        return inputs
-
-    def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
-        inputs = tokenizer.build_chat_input(query)
+    def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system: str = None):
+        inputs = tokenizer.build_chat_input(query, history=history, system=system)
         inputs = inputs.to(self.device)
         return inputs
 
     @torch.inference_mode()
-    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
-             do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
+    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system: str = None,
+             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
+             **kwargs):
         if history is None:
             history = []
         if logits_processor is None:
@@ -1021,7 +1017,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         logits_processor.append(InvalidScoreLogitsProcessor())
         gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                       "temperature": temperature, "logits_processor": logits_processor, **kwargs}
-        inputs = self.build_inputs(tokenizer, query, history=history)
+        inputs = self.build_inputs(tokenizer, query, history=history, system=system)
         eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>")]
         outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
@@ -1031,21 +1027,22 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         return response, history
 
     @torch.inference_mode()
-    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None,
-                    max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
-                    return_past_key_values=False, **kwargs):
+    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system: str = None,
+                    past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
+                    logits_processor=None, return_past_key_values=False, **kwargs):
         if history is None:
             history = []
         if logits_processor is None:
             logits_processor = LogitsProcessorList()
         logits_processor.append(InvalidScoreLogitsProcessor())
-        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>")]
+        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
+                        tokenizer.get_command("<|observation|>")]
         gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                       "temperature": temperature, "logits_processor": logits_processor, **kwargs}
-        if past_key_values is None and not return_past_key_values:
-            inputs = self.build_inputs(tokenizer, query, history=history)
+        if past_key_values is None:
+            inputs = self.build_inputs(tokenizer, query, history=history, system=system)
         else:
-            inputs = self.build_stream_inputs(tokenizer, query, history=history)
+            inputs = self.build_inputs(tokenizer, query)
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[0]
             if self.transformer.pre_seq_len is not None:
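
Net effect on the modeling side: build_stream_inputs is folded into build_inputs, the chat and stream_chat entry points grow a system keyword, and stream_chat now also stops on <|observation|>. A minimal usage sketch of the updated API follows; the checkpoint id and device setup are illustrative assumptions, not part of this commit.

# Hypothetical usage of the new `system` keyword added by this commit.
# The checkpoint id below is an assumption; substitute the actual repo id.
from transformers import AutoModel, AutoTokenizer

model_id = "THUDM/chatglm3-6b"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).half().cuda().eval()

# `system` is threaded through build_inputs into tokenizer.build_chat_input,
# which prepends a <|system|> segment ahead of any history turns.
response, history = model.chat(
    tokenizer,
    "What can you do?",
    history=[],
    system="You are a helpful assistant.",
)
print(response)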
tokenization_chatglm.py CHANGED
@@ -67,7 +67,9 @@ class SPTokenizer:
 
     def convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
-        if index in self.index_special_tokens or index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
+        if index in self.index_special_tokens:
+            return self.index_special_tokens[index]
+        if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
             return ""
         return self.sp_model.IdToPiece(index)
 
@@ -171,10 +173,13 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
         return prefix_tokens
 
-    def build_chat_input(self, query, history=None):
+    def build_chat_input(self, query, history=None, system=None):
         if history is None:
             history = []
         input_ids = []
+        if system is not None:
+            input_ids.extend(
+                [self.get_command("<|system|>")] + self.tokenizer.encode("\n") + self.tokenizer.encode(system))
         for i, (old_query, old_response) in enumerate(history):
             input_ids.extend(
                 [self.get_command("<|user|>")] + self.tokenizer.encode("\n") + self.tokenizer.encode(old_query))
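
On the tokenizer side, convert_id_to_token now returns special tokens instead of dropping them, and build_chat_input emits the optional system segment before the history turns. A sketch of the resulting prompt layout, assuming the tokenizer from the sketch above and that build_chat_input returns PyTorch tensors (as its use with inputs.to(self.device) in modeling_chatglm.py suggests):

# Illustrative only: inspect the token layout produced with a system prompt.
inputs = tokenizer.build_chat_input(
    "And in Python?",
    history=[("How do I reverse a list?", "Use list.reverse() or reversed().")],
    system="You are a concise coding assistant.",
)
# Per the diff, the system segment comes first, then the history turns:
#   ... <|system|> \n {system} <|user|> \n {old_query} ...
print(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist()))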