BK-Lee committed on
Commit 3fb84e5
1 Parent(s): f019fdd
Files changed (2)
  1. app.py +60 -42
  2. trol/load_trol.py +0 -14
app.py CHANGED
@@ -62,60 +62,78 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
     if "1.8B" in link:
         model = model_1_8
         tokenizer = tokenizer_1_8
+        path = "BK-Lee/TroL-1.8B"
     elif "3.8B" in link:
         model = model_3_8
         tokenizer = tokenizer_3_8
+        path = "BK-Lee/TroL-3.8B"
     elif "7B" in link:
         model = model_7
         tokenizer = tokenizer_7
+        path = "BK-Lee/TroL-7B"
 
+    # trol gating load
+    from huggingface_hub import hf_hub_download
+    try:
+        model.model.initialize_trol_gating()
+        model.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
+    except:
+        model.language_model.model.initialize_trol_gating()
+        model.language_model.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
+
+    # X -> float16 conversion
+    for param in model.parameters():
+        if 'float32' in str(param.dtype).lower() or 'float16' in str(param.dtype).lower():
+            param.data = param.data.to(torch.float16)
+
     # cpu -> gpu
     for param in model.parameters():
         if not param.is_cuda:
             param.data = param.to(accel.device)
 
-    # prompt type -> input prompt
-    image_token_number = None
-    if len(message['files']) == 1:
-        # Image Load
-        image = pil_to_tensor(Image.open(message['files'][0]).convert("RGB"))
-        if "3.8B" not in link:
-            image_token_number = 1225
-            image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
-        inputs = [{'image': image.to(accel.device), 'question': message['text']}]
-    elif len(message['files']) > 1:
-        raise Exception("No way!")
-    else:
-        inputs = [{'question': message['text']}]
-
-    # Text Generation
-    with torch.inference_mode():
-        # kwargs
-        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
-
-        # Threading generation
-        thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
-                                                               image_token_number=image_token_number,
-                                                               streamer=streamer,
-                                                               model=model,
-                                                               tokenizer=tokenizer,
-                                                               device=accel.device,
-                                                               temperature=temperature,
-                                                               new_max_token=new_max_token,
-                                                               top_p=top_p))
-        thread.start()
-
-        # generated text
-        generated_text = ""
-        for new_text in streamer:
-            generated_text += new_text
-        generated_text
-
-    # Text decoding
-    response = output_filtering(generated_text, model)
-
-    # except:
-    #     response = "There may be unsupported format: ex) pdf, video, sound. Only supported is a single image in this version."
+    try:
+        # prompt type -> input prompt
+        image_token_number = None
+        if len(message['files']) == 1:
+            # Image Load
+            image = pil_to_tensor(Image.open(message['files'][0]).convert("RGB"))
+            if "3.8B" not in link:
+                image_token_number = 1225
+                image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
+            inputs = [{'image': image.to(accel.device), 'question': message['text']}]
+        elif len(message['files']) > 1:
+            raise Exception("No way!")
+        else:
+            inputs = [{'question': message['text']}]
+
+        # Text Generation
+        with torch.inference_mode():
+            # kwargs
+            streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+
+            # Threading generation
+            thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
+                                                                   image_token_number=image_token_number,
+                                                                   streamer=streamer,
+                                                                   model=model,
+                                                                   tokenizer=tokenizer,
+                                                                   device=accel.device,
+                                                                   temperature=temperature,
+                                                                   new_max_token=new_max_token,
+                                                                   top_p=top_p))
+            thread.start()
+
+            # generated text
+            generated_text = ""
+            for new_text in streamer:
+                generated_text += new_text
+            generated_text
+
+        # Text decoding
+        response = output_filtering(generated_text, model)
+
+    except:
+        response = "There may be unsupported format: ex) pdf, video, sound. Only supported is a single image in this version."
 
     # private log print
     text = message['text']
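
Note on the generation block above: it uses the stock Transformers streaming pattern, in which generation (here wrapped in threading_function) runs on a worker thread while the main thread drains the TextIteratorStreamer token by token. A minimal self-contained sketch of that pattern, with an illustrative checkpoint name and prompt in place of the TroL-specific pieces:

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Illustrative checkpoint; the app uses its own TroL models instead.
name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

inputs = tokenizer("Describe the image:", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

# generate() blocks until finished, so it runs on a worker thread
# while the main thread consumes partial text as it is produced.
thread = Thread(target=model.generate,
                kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64))
thread.start()

generated_text = ""
for new_text in streamer:
    generated_text += new_text  # yield partial text to the UI here
thread.join()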
trol/load_trol.py CHANGED
@@ -80,18 +80,4 @@ def load_trol(link):
 
     # setting config
     setting_trol_config(trol, tok_trol, image_special_token)
-
-    # trol gating load
-    from huggingface_hub import hf_hub_download
-    try:
-        trol.model.initialize_trol_gating()
-        trol.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
-    except:
-        trol.language_model.model.initialize_trol_gating()
-        trol.language_model.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
-
-    # X -> float16 conversion
-    for param in trol.parameters():
-        if 'float32' in str(param.dtype).lower() or 'float16' in str(param.dtype).lower():
-            param.data = param.data.to(torch.float16)
     return trol, tok_trol
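
The block removed here survives in app.py (above), where the gating weights are fetched per request. For reference, a minimal sketch of that Hub-download-and-load pattern, assuming a model whose wrapper exposes initialize_trol_gating and a trol_gating submodule as in the diff; the helper name is hypothetical:

import torch
from huggingface_hub import hf_hub_download

def load_trol_gating(model, repo_id):
    # Hypothetical helper: fetch the auxiliary weight file from the Hub
    # (cached after the first call) and load it into the gating submodule.
    weight_path = hf_hub_download(repo_id=repo_id, filename="trol_gating.pt")
    model.model.initialize_trol_gating()
    model.model.trol_gating.load_state_dict(torch.load(weight_path, map_location="cpu"))

# Usage, mirroring the repo ids added in app.py:
# load_trol_gating(model, "BK-Lee/TroL-1.8B")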