ujin-song commited on
Commit
1509ef8
1 Parent(s): ff8f8b9

demo version2

Browse files

spatial condition added

Files changed (1) hide show
  1. app.py +134 -103
app.py CHANGED
@@ -17,6 +17,7 @@ MAX_SEED = 100_000
17
  def generate(region1_concept,
18
  region2_concept,
19
  prompt,
 
20
  region1_prompt,
21
  region2_prompt,
22
  negative_prompt,
@@ -33,13 +34,19 @@ def generate(region1_concept,
33
  region1_concept, region2_concept = region1_concept.lower(), region2_concept.lower()
34
  pretrained_model = merge(region1_concept, region2_concept)
35
 
36
- keypose_condition = 'multi-concept/pose_data/two_apart.png'
37
- region1 = '[0, 0, 512, 290]'
38
- region2 = '[0, 650, 512, 910]'
 
 
 
 
 
39
 
40
  region1_prompt = f'[<{region1_concept}1> <{region1_concept}2>, {region1_prompt}]'
41
  region2_prompt = f'[<{region2_concept}1> <{region2_concept}2>, {region2_prompt}]'
42
  prompt_rewrite=f"{region1_prompt}-*-{region_neg_prompt}-*-{region1}|{region2_prompt}-*-{region_neg_prompt}-*-{region2}"
 
43
 
44
  result = infer(pretrained_model,
45
  prompt,
@@ -164,6 +171,10 @@ def infer(pretrained_model,
164
 
165
  return image[0]
166
 
 
 
 
 
167
  examples_context = [
168
  'walking at Stanford university campus',
169
  'in a castle',
@@ -174,6 +185,10 @@ examples_context = [
174
  examples_region1 = ['wearing red hat, high resolution, best quality','bright smile, wearing pants, best quality']
175
  examples_region2 = ['smilling, wearing blue shirt, high resolution, best quality']
176
 
 
 
 
 
177
  css="""
178
  #col-container {
179
  margin: 0 auto;
@@ -182,124 +197,140 @@ css="""
182
  """
183
 
184
  with gr.Blocks(css=css) as demo:
 
 
 
 
185
 
186
- with gr.Column(elem_id="col-container"):
187
- gr.Markdown(f"""
188
- # Orthogonal Adaptation
189
- Currently running on {power_device}.
190
- """)
191
- prompt = gr.Text(
192
- label="ContextPrompt",
193
- show_label=False,
194
- max_lines=1,
195
- placeholder="Enter your context prompt for overall image",
196
- container=False,
197
- )
198
- with gr.Row():
199
-
200
- region1_concept = gr.Dropdown(
201
- ["Elsa", "Moana"],
202
- label="Character 1",
203
- info="Will add more characters later!"
204
- )
205
- region2_concept = gr.Dropdown(
206
- ["Elsa", "Moana"],
207
- label="Character 2",
208
- info="Will add more characters later!"
209
- )
210
-
211
- with gr.Row():
212
-
213
- region1_prompt = gr.Textbox(
214
- label="Region1 Prompt",
215
- show_label=False,
216
- max_lines=2,
217
- placeholder="Enter your prompt for character 1",
218
- container=False,
219
- )
220
-
221
- region2_prompt = gr.Textbox(
222
- label="Region2 Prompt",
223
- show_label=False,
224
- max_lines=2,
225
- placeholder="Enter your prompt for character 2",
226
- container=False,
227
- )
228
-
229
- run_button = gr.Button("Run", scale=1)
230
-
231
- result = gr.Image(label="Result", show_label=False)
232
-
233
- with gr.Accordion("Advanced Settings", open=False):
234
-
235
- negative_prompt = gr.Text(
236
- label="Context Negative prompt",
237
- max_lines=1,
238
- value = 'saturated, cropped, worst quality, low quality',
239
- visible=False,
240
- )
241
-
242
- region_neg_prompt = gr.Text(
243
- label="Regional Negative prompt",
244
- max_lines=1,
245
- value = 'shirtless, nudity, saturated, cropped, worst quality, low quality',
246
- visible=False,
247
- )
248
-
249
- seed = gr.Slider(
250
- label="Seed",
251
- minimum=0,
252
- maximum=MAX_SEED,
253
- step=1,
254
- value=0,
255
- )
256
-
257
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
258
 
 
 
 
 
 
 
 
259
  with gr.Row():
260
 
261
- sketch_adaptor_weight = gr.Slider(
262
- label="Sketch Adapter Weight",
263
- minimum = 0,
264
- maximum = 1,
265
- step=0.01,
266
- value=0,
267
  )
268
-
269
- keypose_adaptor_weight = gr.Slider(
270
- label="Keypose Adapter Weight",
271
- minimum = 0,
272
- maximum = 1,
273
- step= 0.01,
274
- value=1.0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  )
 
 
 
 
 
 
 
276
 
 
 
277
 
278
- gr.Examples(
279
- label = 'Context Prompt example',
280
- examples = examples_context,
281
- inputs = [prompt]
282
- )
283
 
284
- with gr.Row():
285
- gr.Examples(
286
- label = 'Region1 Prompt example',
287
- examples = examples_region1,
288
- inputs = [region1_prompt]
289
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
 
291
  gr.Examples(
292
- label = 'Region2 Prompt example',
293
- examples = [examples_region2],
294
- inputs = [region2_prompt]
295
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
296
 
 
297
 
298
  run_button.click(
299
  fn = generate,
300
  inputs = [region1_concept,
301
  region2_concept,
302
  prompt,
 
303
  region1_prompt,
304
  region2_prompt,
305
  negative_prompt,
 
17
  def generate(region1_concept,
18
  region2_concept,
19
  prompt,
20
+ pose_image_id,
21
  region1_prompt,
22
  region2_prompt,
23
  negative_prompt,
 
34
  region1_concept, region2_concept = region1_concept.lower(), region2_concept.lower()
35
  pretrained_model = merge(region1_concept, region2_concept)
36
 
37
+ with open('multi-concept/pose_data/pose.json') as f:
38
+ d = json.load(f)
39
+
40
+ pose_image = {obj.pop('pose_id'):obj for obj in d}[int(pose_image_id)]
41
+ print(pose_image)
42
+ keypose_condition = pose_image['keypose_condition']
43
+ region1 = pose_image['region1']
44
+ region2 = pose_image['region2']
45
 
46
  region1_prompt = f'[<{region1_concept}1> <{region1_concept}2>, {region1_prompt}]'
47
  region2_prompt = f'[<{region2_concept}1> <{region2_concept}2>, {region2_prompt}]'
48
  prompt_rewrite=f"{region1_prompt}-*-{region_neg_prompt}-*-{region1}|{region2_prompt}-*-{region_neg_prompt}-*-{region2}"
49
+ print(prompt_rewrite)
50
 
51
  result = infer(pretrained_model,
52
  prompt,
 
171
 
172
  return image[0]
173
 
174
+
175
+ def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData
176
+ return ''.join(c for c in evt.value['image']['orig_name'] if c.isdigit())
177
+
178
  examples_context = [
179
  'walking at Stanford university campus',
180
  'in a castle',
 
185
  examples_region1 = ['wearing red hat, high resolution, best quality','bright smile, wearing pants, best quality']
186
  examples_region2 = ['smilling, wearing blue shirt, high resolution, best quality']
187
 
188
+ with open('multi-concept/pose_data/pose.json') as f:
189
+ d = json.load(f)
190
+ pose_image_list = [(obj['pose_id'],obj['keypose_condition']) for obj in d]
191
+
192
  css="""
193
  #col-container {
194
  margin: 0 auto;
 
197
  """
198
 
199
  with gr.Blocks(css=css) as demo:
200
+ gr.Markdown(f"""
201
+ # Orthogonal Adaptation
202
+ Currently running on : {power_device}
203
+ """)
204
 
205
+ with gr.Row():
206
+ with gr.Column(elem_id="col-container", scale=2):
207
+ gr.Markdown(f"""
208
+ ### 🕹️ Global and Region prompts:
209
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
+ prompt = gr.Text(
212
+ label="ContextPrompt",
213
+ show_label=False,
214
+ max_lines=1,
215
+ placeholder="Enter your context(global) prompt",
216
+ container=False,
217
+ )
218
  with gr.Row():
219
 
220
+ region1_concept = gr.Dropdown(
221
+ ["Elsa", "Moana"],
222
+ label="Character 1",
223
+ info="Will add more characters later!"
 
 
224
  )
225
+ region2_concept = gr.Dropdown(
226
+ ["Elsa", "Moana"],
227
+ label="Character 2",
228
+ info="Will add more characters later!"
229
+ )
230
+
231
+ with gr.Row():
232
+
233
+ region1_prompt = gr.Textbox(
234
+ label="Region1 Prompt",
235
+ show_label=False,
236
+ max_lines=2,
237
+ placeholder="Enter your regional prompt for character 1",
238
+ container=False,
239
+ )
240
+
241
+ region2_prompt = gr.Textbox(
242
+ label="Region2 Prompt",
243
+ show_label=False,
244
+ max_lines=2,
245
+ placeholder="Enter your regional prompt for character 2",
246
+ container=False,
247
  )
248
+
249
+ gr.Markdown(f"### 🧭 Spatial Condition for regionally controllable sampling: ")
250
+ gallery = gr.Gallery(label = "Select pose for characters",
251
+ value = [obj[1]for obj in pose_image_list],
252
+ elem_id = [obj[0]for obj in pose_image_list],
253
+ interactive=False, show_download_button=False,
254
+ preview=True, height = 200, object_fit="scale-down")
255
 
256
+ pose_image_id = gr.Textbox(visible=False)
257
+ gallery.select(on_select, None, pose_image_id)
258
 
259
+ run_button = gr.Button("Run", scale=1)
 
 
 
 
260
 
261
+ with gr.Accordion("Advanced Settings", open=False):
262
+
263
+ negative_prompt = gr.Text(
264
+ label="Context Negative prompt",
265
+ max_lines=1,
266
+ value = 'saturated, cropped, worst quality, low quality',
267
+ visible=False,
268
+ )
269
+
270
+ region_neg_prompt = gr.Text(
271
+ label="Regional Negative prompt",
272
+ max_lines=1,
273
+ value = 'shirtless, nudity, saturated, cropped, worst quality, low quality',
274
+ visible=False,
275
+ )
276
+
277
+ seed = gr.Slider(
278
+ label="Seed",
279
+ minimum=0,
280
+ maximum=MAX_SEED,
281
+ step=1,
282
+ value=0,
283
+ )
284
+
285
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
286
+
287
+ with gr.Row():
288
+
289
+ sketch_adaptor_weight = gr.Slider(
290
+ label="Sketch Adapter Weight",
291
+ minimum = 0,
292
+ maximum = 1,
293
+ step=0.01,
294
+ value=0,
295
+ )
296
+
297
+ keypose_adaptor_weight = gr.Slider(
298
+ label="Keypose Adapter Weight",
299
+ minimum = 0,
300
+ maximum = 1,
301
+ step= 0.01,
302
+ value=1.0,
303
+ )
304
+
305
+ with gr.Column(scale=1):
306
 
307
  gr.Examples(
308
+ label = 'Global Prompt example',
309
+ examples = examples_context,
310
+ inputs = [prompt]
311
+ )
312
+
313
+ with gr.Row():
314
+ gr.Examples(
315
+ label = 'Region1 Prompt example',
316
+ examples = examples_region1,
317
+ inputs = [region1_prompt]
318
+ )
319
+
320
+ gr.Examples(
321
+ label = 'Region2 Prompt example',
322
+ examples = [examples_region2],
323
+ inputs = [region2_prompt]
324
+ )
325
 
326
+ result = gr.Image(label="Result", show_label=False)
327
 
328
  run_button.click(
329
  fn = generate,
330
  inputs = [region1_concept,
331
  region2_concept,
332
  prompt,
333
+ pose_image_id,
334
  region1_prompt,
335
  region2_prompt,
336
  negative_prompt,