luodian commited on
Commit
20cefde
1 Parent(s): 229e9ea

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -0
README.md CHANGED
@@ -47,6 +47,7 @@ import torch
47
  import transformers
48
  from PIL import Image
49
  import sys
 
50
  sys.path.append("..")
51
  from otter.modeling_otter import OtterForConditionalGeneration
52
 
@@ -130,6 +131,7 @@ def get_response(input_data, prompt: str, model=None, image_processor=None, tens
130
  return_tensors="pt",
131
  )
132
 
 
133
  generated_text = model.generate(
134
  vision_x=vision_x.to(model.device, dtype=tensor_dtype),
135
  lang_x=lang_x["input_ids"].to(model.device),
@@ -137,6 +139,7 @@ def get_response(input_data, prompt: str, model=None, image_processor=None, tens
137
  max_new_tokens=512,
138
  num_beams=3,
139
  no_repeat_ngram_size=3,
 
140
  )
141
  parsed_output = (
142
  model.text_tokenizer.decode(generated_text[0])
@@ -151,6 +154,7 @@ def get_response(input_data, prompt: str, model=None, image_processor=None, tens
151
  )
152
  return parsed_output
153
 
 
154
  # ------------------- Main Function -------------------
155
  load_bit = "fp16"
156
  if load_bit == "fp16":
@@ -184,4 +188,5 @@ while True:
184
 
185
  if prompts_input.lower() == "quit":
186
  break
 
187
  ```
 
47
  import transformers
48
  from PIL import Image
49
  import sys
50
+
51
  sys.path.append("..")
52
  from otter.modeling_otter import OtterForConditionalGeneration
53
 
 
131
  return_tensors="pt",
132
  )
133
 
134
+ bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids
135
  generated_text = model.generate(
136
  vision_x=vision_x.to(model.device, dtype=tensor_dtype),
137
  lang_x=lang_x["input_ids"].to(model.device),
 
139
  max_new_tokens=512,
140
  num_beams=3,
141
  no_repeat_ngram_size=3,
142
+ bad_words_ids=bad_words_id,
143
  )
144
  parsed_output = (
145
  model.text_tokenizer.decode(generated_text[0])
 
154
  )
155
  return parsed_output
156
 
157
+
158
  # ------------------- Main Function -------------------
159
  load_bit = "fp16"
160
  if load_bit == "fp16":
 
188
 
189
  if prompts_input.lower() == "quit":
190
  break
191
+
192
  ```