ujin-song commited on
Commit
b81062c
•
1 Parent(s): fc13d95

Update app.py

Browse files

- implemented error message
- json structure modification

Files changed (1) hide show
  1. app.py +20 -10
app.py CHANGED
@@ -17,7 +17,7 @@ MAX_SEED = 100_000
17
  def generate(region1_concept,
18
  region2_concept,
19
  prompt,
20
- pose_image_id,
21
  region1_prompt,
22
  region2_prompt,
23
  negative_prompt,
@@ -27,6 +27,15 @@ def generate(region1_concept,
27
  sketch_adaptor_weight,
28
  keypose_adaptor_weight
29
  ):
 
 
 
 
 
 
 
 
 
30
 
31
  if randomize_seed:
32
  seed = random.randint(0, MAX_SEED)
@@ -37,9 +46,10 @@ def generate(region1_concept,
37
  with open('multi-concept/pose_data/pose.json') as f:
38
  d = json.load(f)
39
 
40
- pose_image = {obj.pop('pose_id'):obj for obj in d}[int(pose_image_id)]
 
41
  print(pose_image)
42
- keypose_condition = pose_image['keypose_condition']
43
  region1 = pose_image['region1']
44
  region2 = pose_image['region2']
45
 
@@ -173,7 +183,7 @@ def infer(pretrained_model,
173
 
174
 
175
  def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData
176
- return ''.join(c for c in evt.value['image']['orig_name'] if c.isdigit())
177
 
178
  examples_context = [
179
  'walking at Stanford university campus',
@@ -187,7 +197,7 @@ examples_region2 = ['smilling, wearing blue shirt, high resolution, best quality
187
 
188
  with open('multi-concept/pose_data/pose.json') as f:
189
  d = json.load(f)
190
- pose_image_list = [(obj['pose_id'],obj['keypose_condition']) for obj in d]
191
 
192
  css="""
193
  #col-container {
@@ -210,7 +220,7 @@ with gr.Blocks(css=css) as demo:
210
  # gr.Markdown(f"""
211
  # ### 🪄 Global and Region prompts
212
  # """)
213
- # with gr.Group():
214
  with gr.Tab('🪄 Global and Region prompts'):
215
  prompt = gr.Text(
216
  label="ContextPrompt",
@@ -282,10 +292,10 @@ with gr.Blocks(css=css) as demo:
282
  value = [obj[1]for obj in pose_image_list],
283
  elem_id = [obj[0]for obj in pose_image_list],
284
  interactive=False, show_download_button=False,
285
- preview=True, height = 200, object_fit="scale-down")
286
 
287
- pose_image_id = gr.Textbox(visible=False)
288
- gallery.select(on_select, None, pose_image_id)
289
 
290
  run_button = gr.Button("Run", scale=1)
291
 
@@ -346,7 +356,7 @@ with gr.Blocks(css=css) as demo:
346
  inputs = [region1_concept,
347
  region2_concept,
348
  prompt,
349
- pose_image_id,
350
  region1_prompt,
351
  region2_prompt,
352
  negative_prompt,
 
17
  def generate(region1_concept,
18
  region2_concept,
19
  prompt,
20
+ pose_image_name,
21
  region1_prompt,
22
  region2_prompt,
23
  negative_prompt,
 
27
  sketch_adaptor_weight,
28
  keypose_adaptor_weight
29
  ):
30
+
31
+ if region1_concept==region2_concept:
32
+ raise gr.Error("Please choose two different characters for merging weights.")
33
+ if len(pose_image_name)==0:
34
+ raise gr.Error("Please select one spatial condition!")
35
+ if len(region1_prompt)==0 or len(region1_prompt)==0:
36
+ raise gr.Error("Your regional prompt cannot be empty.")
37
+ if len(prompt)==0:
38
+ raise gr.Error("Your global prompt cannot be empty.")
39
 
40
  if randomize_seed:
41
  seed = random.randint(0, MAX_SEED)
 
46
  with open('multi-concept/pose_data/pose.json') as f:
47
  d = json.load(f)
48
 
49
+ pose_image = {os.path.basename(obj['img_dir']):obj for obj in d}[pose_image_name]
50
+ # pose_image = {obj.pop('pose_id'):obj for obj in d}[int(pose_image_id)]
51
  print(pose_image)
52
+ keypose_condition = pose_image['img_dir']
53
  region1 = pose_image['region1']
54
  region2 = pose_image['region2']
55
 
 
183
 
184
 
185
  def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData
186
+ return evt.value['image']['orig_name']
187
 
188
  examples_context = [
189
  'walking at Stanford university campus',
 
197
 
198
  with open('multi-concept/pose_data/pose.json') as f:
199
  d = json.load(f)
200
+ pose_image_list = [(obj['img_id'],obj['img_dir']) for obj in d]
201
 
202
  css="""
203
  #col-container {
 
220
  # gr.Markdown(f"""
221
  # ### 🪄 Global and Region prompts
222
  # """)
223
+ # with gr.Group():
224
  with gr.Tab('🪄 Global and Region prompts'):
225
  prompt = gr.Text(
226
  label="ContextPrompt",
 
292
  value = [obj[1]for obj in pose_image_list],
293
  elem_id = [obj[0]for obj in pose_image_list],
294
  interactive=False, show_download_button=False,
295
+ preview=True, height = 400, object_fit="scale-down")
296
 
297
+ pose_image_name = gr.Textbox(visible=False)
298
+ gallery.select(on_select, None, pose_image_name)
299
 
300
  run_button = gr.Button("Run", scale=1)
301
 
 
356
  inputs = [region1_concept,
357
  region2_concept,
358
  prompt,
359
+ pose_image_name,
360
  region1_prompt,
361
  region2_prompt,
362
  negative_prompt,