BK-Lee committed on
Commit a56928d
1 Parent(s): eb8fafa
Files changed (3)
  1. .gitignore +3 -0
  2. app.py +3 -5
  3. trol/load_trol.py +1 -1
.gitignore ADDED
@@ -0,0 +1,3 @@
+ .vscode
+ __pycache__
+ */__pycache__
app.py CHANGED
@@ -38,8 +38,6 @@ model_3_8, tokenizer_3_8 = load_trol(link='TroL-3.8B')
  # loading model
  model_7, tokenizer_7 = load_trol(link='TroL-7B')
 
- print()
-
  def threading_function(inputs, image_token_number, streamer, device, model, tokenizer, temperature, new_max_token, top_p):
 
  # propagation
@@ -85,7 +83,7 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
  if "3.8B" not in link:
  image_token_number = 1225
  image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
- inputs = [{'image': image, 'question': message['text']}]
+ inputs = [{'image': image.to(accel.device), 'question': message['text']}]
  elif len(message['files']) > 1:
  raise Exception("No way!")
  else:
@@ -116,7 +114,7 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
 
  # Text decoding
  response = output_filtering(generated_text, model)
-
+
  except:
  response = "There may be unsupported format: ex) pdf, video, sound. Only supported is a single image in this version."
 
@@ -138,7 +136,7 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
  yield buffer
 
  demo = gr.ChatInterface(fn=bot_streaming,
- additional_inputs = [gr.Radio(["3.8B"], label="Size", info="Select one model size", value="3.8B"), gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
+ additional_inputs = [gr.Radio(["1.8B", "3.8B"], label="Size", info="Select one model size", value="3.8B"), gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
  additional_inputs_accordion="Generation Hyperparameters",
  theme=gr.themes.Soft(),
  title="TroL",
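The one functional app.py change above is the .to(accel.device) call: the preprocessed image tensor is moved onto the accelerator device before being packed into inputs, so it lands on the same device as the model weights. A minimal sketch of the pattern, assuming accel is an accelerate.Accelerator created elsewhere in app.py (only the .to(accel.device) call itself is visible in this diff):

    import torch
    import torch.nn.functional as F
    from accelerate import Accelerator

    accel = Accelerator()  # resolves to cuda/mps/cpu, whichever is available

    # placeholder CHW tensor standing in for the uploaded image
    image = torch.rand(3, 224, 224)

    # resize to 490x490, the resolution behind image_token_number = 1225
    # (1225 = 35*35, presumably a 35x35 grid of image patches)
    image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)

    # without .to(accel.device) the tensor stays on CPU while the model sits on
    # the accelerator, raising a device-mismatch error during propagation
    inputs = [{'image': image.to(accel.device), 'question': 'example question'}]

The other app.py edits are cosmetic: a stray print() is removed, a blank line's whitespace changes, and the size selector gains a "1.8B" option alongside "3.8B".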
trol/load_trol.py CHANGED
@@ -23,7 +23,7 @@ def load_trol(link):
  from transformers import LlamaTokenizerFast as TroLTokenizer
  bits = 8
  path = TROL_3_8B
- bit_quant_skip = ["vision_model", "mlp1", "lm_head"]
+ bit_quant_skip = ["vision_model", "vision_proj", "lm_head"]
 
  elif link == 'TroL-7B':
  from .arch_internlm2.modeling_trol import TroLForCausalLM
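The load_trol.py change replaces "mlp1" with "vision_proj" in bit_quant_skip, the list of submodules excluded from 8-bit quantization on the TroL-3.8B path (bits = 8). Presumably "vision_proj" is the projector's actual attribute name in this architecture, so the old "mlp1" entry never matched anything and the projector was quantized along with the language model. A plausible sketch of how such a list is consumed, assuming load_trol forwards it to transformers' llm_int8_skip_modules (the consuming code is outside this diff):

    from transformers import BitsAndBytesConfig

    bit_quant_skip = ["vision_model", "vision_proj", "lm_head"]  # new 3.8B value

    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,                     # matches bits = 8 in the 3.8B branch
        llm_int8_skip_modules=bit_quant_skip,  # keep these submodules in full precision
    )

Skipping the vision tower, the vision-to-language projector, and the output head is a common pattern: quantization error is usually most damaging at a model's input and output boundaries.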