y0un92 committed on
Commit
fb0809e
1 Parent(s): 27694ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -37
app.py CHANGED
@@ -7,7 +7,6 @@ import traceback
7
  import re
8
  import torch
9
  import argparse
10
- import numpy as np
11
  from transformers import AutoModel, AutoTokenizer
12
 
13
  # README, How to run demo on different devices
@@ -19,43 +18,270 @@ from transformers import AutoModel, AutoTokenizer
19
  # PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo_2.5.py --device mps
20
 
21
  # Argparser
22
- # test.py
23
- import torch
24
- from PIL import Image
25
- from transformers import AutoModel, AutoTokenizer
26
- import bitsandbytes as bnb
27
- import accelerate
28
 
 
29
  model = AutoModel.from_pretrained('openbmb/MiniCPM-Llama3-V-2_5-int4', trust_remote_code=True)
30
- tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-Llama3-V-2_5-int4', trust_remote_code=True)
31
  model.eval()
32
 
33
- image = Image.open('xx.jpg').convert('RGB')
34
- question = 'What is in the image?'
35
- msgs = [{'role': 'user', 'content': question}]
36
-
37
- res = model.chat(
38
- image=image,
39
- msgs=msgs,
40
- tokenizer=tokenizer,
41
- sampling=True, # if sampling=False, beam_search will be used by default
42
- temperature=0.7,
43
- # system_prompt='' # pass system_prompt if needed
44
- )
45
- print(res)
46
-
47
- ## if you want to use streaming, please make sure sampling=True and stream=True
48
- ## the model.chat will return a generator
49
- res = model.chat(
50
- image=image,
51
- msgs=msgs,
52
- tokenizer=tokenizer,
53
- sampling=True,
54
- temperature=0.7,
55
- stream=True
56
- )
57
-
58
- generated_text = ""
59
- for new_text in res:
60
- generated_text += new_text
61
- print(new_text, flush=True, end='')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import re
8
  import torch
9
  import argparse
 
10
  from transformers import AutoModel, AutoTokenizer
11
 
12
  # README, How to run demo on different devices
 
18
  # PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo_2.5.py --device mps
19
 
20
# Argparser
# Restrict --device to the values the demo supports. Using argparse's
# `choices` produces a proper usage error; the original's
# `assert device in [...]` is silently stripped under `python -O`.
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda',
                    choices=['cuda', 'mps'], help='cuda or mps')
args = parser.parse_args()
device = args.device
 
26
 
27
# Load model
# NOTE(review): the original passed an undefined `model_path` to
# AutoTokenizer.from_pretrained (NameError at import time). Define one
# shared constant so the model and tokenizer always come from the same
# checkpoint.
model_path = 'openbmb/MiniCPM-Llama3-V-2_5-int4'
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.eval()
31
 
32
+
33
+
34
# Message yielded to the UI when chat() hits an unexpected exception.
ERROR_MSG = "Error, please retry"
# Display name shown in the Chatbot widget label.
model_name = 'MiniCPM-Llama3-V 2.5-int4'

# Each dict below mirrors the constructor kwargs of the Gradio widget it
# configures; create_component() unpacks them into gr.Slider / gr.Radio.
form_radio = {
    'choices': ['Beam Search', 'Sampling'],
    #'value': 'Beam Search',
    'value': 'Sampling',
    'interactive': True,
    'label': 'Decode Type'
}
# Beam Form
num_beams_slider = {
    'minimum': 0,
    'maximum': 5,
    'value': 3,
    'step': 1,
    'interactive': True,
    'label': 'Num Beams'
}
# Repetition penalty used by the Beam Search accordion.
repetition_penalty_slider = {
    'minimum': 0,
    'maximum': 3,
    'value': 1.2,
    'step': 0.01,
    'interactive': True,
    'label': 'Repetition Penalty'
}
# Repetition penalty used by the Sampling accordion (lower default).
repetition_penalty_slider2 = {
    'minimum': 0,
    'maximum': 3,
    'value': 1.05,
    'step': 0.01,
    'interactive': True,
    'label': 'Repetition Penalty'
}
# NOTE(review): defined but never wired into the UI below — verify whether
# a Max New Tokens slider was meant to be added.
max_new_tokens_slider = {
    'minimum': 1,
    'maximum': 4096,
    'value': 1024,
    'step': 1,
    'interactive': True,
    'label': 'Max New Tokens'
}

top_p_slider = {
    'minimum': 0,
    'maximum': 1,
    'value': 0.8,
    'step': 0.05,
    'interactive': True,
    'label': 'Top P'
}
top_k_slider = {
    'minimum': 0,
    'maximum': 200,
    'value': 100,
    'step': 1,
    'interactive': True,
    'label': 'Top K'
}
temperature_slider = {
    'minimum': 0,
    'maximum': 2,
    'value': 0.7,
    'step': 0.05,
    'interactive': True,
    'label': 'Temperature'
}
102
+
103
+
104
def create_component(params, comp='Slider'):
    """Instantiate a Gradio widget from a parameter dict.

    `comp` selects the widget type: 'Slider' (default), 'Radio' or
    'Button'. Any other value falls through and returns None, matching
    the original behaviour.
    """
    if comp == 'Radio':
        return gr.Radio(choices=params['choices'],
                        value=params['value'],
                        interactive=params['interactive'],
                        label=params['label'])
    if comp == 'Button':
        # Buttons are always interactive regardless of the dict contents.
        return gr.Button(value=params['value'], interactive=True)
    if comp == 'Slider':
        return gr.Slider(minimum=params['minimum'],
                         maximum=params['maximum'],
                         value=params['value'],
                         step=params['step'],
                         interactive=params['interactive'],
                         label=params['label'])
126
+
127
@spaces.GPU(duration=120)
def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
    """Generator: stream the model's answer for `img` + `msgs`.

    `ctx` and `vision_hidden_states` are accepted for interface
    compatibility but are unused here. Errors are yielded as text
    instead of raised so the Gradio callback never crashes the UI.
    (Removed a block of commented-out <box>/<ref> post-processing that
    was dead code.)
    """
    default_params = {"stream": False, "sampling": False, "num_beams": 3, "repetition_penalty": 1.2, "max_new_tokens": 1024}
    if params is None:
        params = default_params
    if img is None:
        yield "Error, invalid image, please upload a new image"
    else:
        try:
            image = img.convert('RGB')
            answer = model.chat(
                image=image,
                msgs=msgs,
                tokenizer=tokenizer,
                **params
            )
            # With stream=True `answer` is a generator of text chunks;
            # with stream=False it is a plain string. Iterating handles
            # both: chunks or single characters stream to the caller.
            for char in answer:
                yield char
        except Exception as err:
            print(err)
            traceback.print_exc()
            yield ERROR_MSG
156
+
157
+
158
def upload_img(image, _chatbot, _app_session):
    """Store a freshly uploaded picture in the session and greet the user.

    `image` arrives as a numpy array from gr.Image; it is converted to a
    PIL image once here so chat() can later call .convert('RGB') on it.
    Resets the session status and context for a new conversation.
    """
    pil_image = Image.fromarray(image)
    _app_session['sts'] = None
    _app_session['ctx'] = []
    _app_session['img'] = pil_image
    _chatbot.append(('', 'Image uploaded successfully, you can talk to me now'))
    return _chatbot, _app_session
166
+
167
+
168
def respond(_chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
    # Generator callback: answer the newest user turn in `_chat_bot`,
    # streaming partial output back to the Chatbot widget on every yield.
    _question = _chat_bot[-1][0]
    print('<Question>:', _question)
    if _app_cfg.get('ctx', None) is None:
        # 'ctx' is only initialised (to []) by upload_img, so None means
        # no image has been uploaded yet.
        _chat_bot[-1][1] = 'Please upload an image to start'
        yield (_chat_bot, _app_cfg)
    else:
        _context = _app_cfg['ctx'].copy()
        if _context:
            _context.append({"role": "user", "content": _question})
        else:
            _context = [{"role": "user", "content": _question}]
        # Decode parameters depend on the radio selection; Beam Search
        # disables streaming, Sampling enables it.
        if params_form == 'Beam Search':
            params = {
                'sampling': False,
                'stream': False,
                'num_beams': num_beams,
                'repetition_penalty': repetition_penalty,
                "max_new_tokens": 896
            }
        else:
            params = {
                'sampling': True,
                'stream': True,
                'top_p': top_p,
                'top_k': top_k,
                'temperature': temperature,
                'repetition_penalty': repetition_penalty_2,
                "max_new_tokens": 896
            }

        gen = chat(_app_cfg['img'], _context, None, params)
        _chat_bot[-1][1] = ""
        for _char in gen:
            _chat_bot[-1][1] += _char
            # NOTE(review): this appends the streamed answer characters
            # onto the *user* message dict (_context[-1] is the question),
            # and _app_cfg['ctx'] is never updated with the finished turn —
            # confirm whether multi-turn history is intended to persist.
            _context[-1]["content"] += _char
            yield (_chat_bot, _app_cfg)
205
+
206
+
207
def request(_question, _chat_bot, _app_cfg):
    """Queue the user's question as a new chat turn (answer filled later by respond())."""
    _chat_bot.append((_question, None))
    return '', _chat_bot, _app_cfg


def regenerate_button_clicked(_question, _chat_bot, _app_cfg):
    """Drop the last answered turn and re-ask its question."""
    if len(_chat_bot) <= 1:
        # Nothing has been asked yet.
        _chat_bot.append(('Regenerate', 'No question for regeneration.'))
        return '', _chat_bot, _app_cfg
    if _chat_bot[-1][0] == 'Regenerate':
        # Already showing the "nothing to regenerate" notice — no-op.
        return '', _chat_bot, _app_cfg
    last_question = _chat_bot[-1][0]
    # Remove the stale turn from both the visible history and the context.
    _chat_bot = _chat_bot[:-1]
    _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
    return request(last_question, _chat_bot, _app_cfg)
224
+
225
+
226
def clear_button_clicked(_question, _chat_bot, _app_cfg, _bt_pic):
    """Reset the conversation: wipe the chat history, session state and image."""
    _chat_bot.clear()
    for key in ('sts', 'ctx', 'img'):
        _app_cfg[key] = None
    # Returning None as the last output clears the image widget.
    return '', _chat_bot, _app_cfg, None
233
+
234
+
235
# Gradio UI: decode-parameter controls on the left, chat area on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            params_form = create_component(form_radio, comp='Radio')
            # NOTE(review): "according" is presumably a typo for "accordion".
            with gr.Accordion("Beam Search") as beams_according:
                num_beams = create_component(num_beams_slider)
                repetition_penalty = create_component(repetition_penalty_slider)
            with gr.Accordion("Sampling") as sampling_according:
                top_p = create_component(top_p_slider)
                top_k = create_component(top_k_slider)
                temperature = create_component(temperature_slider)
                repetition_penalty_2 = create_component(repetition_penalty_slider2)
            regenerate = create_component({'value': 'Regenerate'}, comp='Button')
            clear = create_component({'value': 'Clear'}, comp='Button')
        with gr.Column(scale=3, min_width=500):
            # Per-session state: 'sts' (status), 'ctx' (chat context),
            # 'img' (uploaded PIL image).
            app_session = gr.State({'sts':None,'ctx':None,'img':None})
            bt_pic = gr.Image(label="Upload an image to start")
            chat_bot = gr.Chatbot(label=f"Chat with {model_name}")
            txt_message = gr.Textbox(label="Input text")

    clear.click(
        clear_button_clicked,
        [txt_message, chat_bot, app_session, bt_pic],
        [txt_message, chat_bot, app_session, bt_pic],
        queue=False
    )
    # Submit: first append the question (request), then stream the answer (respond).
    txt_message.submit(
        request,
        #[txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
        [txt_message, chat_bot, app_session],
        [txt_message, chat_bot, app_session],
        queue=False
    ).then(
        respond,
        [chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
        [chat_bot, app_session]
    )
    regenerate.click(
        regenerate_button_clicked,
        [txt_message, chat_bot, app_session],
        [txt_message, chat_bot, app_session],
        queue=False
    ).then(
        respond,
        [chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
        [chat_bot, app_session]
    )
    # On upload: clear the chatbot first, then stash the image in the session.
    bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(upload_img, inputs=[bt_pic,chat_bot,app_session], outputs=[chat_bot,app_session])

# launch
#demo.launch(share=False, debug=True, show_api=False, server_port=8080, server_name="0.0.0.0")
demo.queue()
demo.launch()