ujin-song committed on
Commit
408e8c2
•
1 Parent(s): 977d85a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -10
app.py CHANGED
@@ -24,6 +24,7 @@ def generate(region1_concept,
24
  region_neg_prompt,
25
  seed,
26
  randomize_seed,
 
27
  sketch_adaptor_weight,
28
  keypose_adaptor_weight
29
  ):
@@ -41,7 +42,7 @@ def generate(region1_concept,
41
  seed = random.randint(0, MAX_SEED)
42
 
43
  region1_concept, region2_concept = region1_concept.lower(), region2_concept.lower()
44
- pretrained_model = merge(region1_concept, region2_concept)
45
 
46
  with open('multi-concept/pose_data/pose.json') as f:
47
  d = json.load(f)
@@ -53,12 +54,12 @@ def generate(region1_concept,
53
  region1 = pose_image['region1']
54
  region2 = pose_image['region2']
55
 
56
- region_pos_prompt = "high resolution, best quality, highly detailed, sharp focus, expressive, 8k uhd, detailed, sophisticated"
57
  region1_prompt = f'<{region1_concept}1> <{region1_concept}2>, {region1_prompt}, {region_pos_prompt}'
58
  region2_prompt = f'<{region2_concept}1> <{region2_concept}2>, {region2_prompt}, {region_pos_prompt}'
59
  prompt_rewrite=f"{region1_prompt}-*-{region_neg_prompt}-*-{region1}|{region2_prompt}-*-{region_neg_prompt}-*-{region2}"
60
  print(prompt_rewrite)
61
- prompt+=", Disney style photo, High resolution"
62
 
63
  result = infer(pretrained_model,
64
  prompt,
@@ -73,13 +74,13 @@ def generate(region1_concept,
73
 
74
  return result
75
 
76
- def merge(concept1, concept2):
77
  device = "cuda" if torch.cuda.is_available() else "cpu"
78
  c1, c2 = sorted([concept1, concept2])
79
  assert c1!=c2
80
  merge_name = c1+'_'+c2
81
 
82
- save_path = f'experiments/multi-concept/{merge_name}'
83
 
84
  if os.path.isdir(save_path):
85
  print(f'{save_path} already exists. Collecting merged weights from existing weights...')
@@ -87,7 +88,7 @@ def merge(concept1, concept2):
87
  else:
88
  os.makedirs(save_path)
89
  json_path = os.path.join(save_path,'merge_config.json')
90
- alpha = 1.8
91
  data = [
92
  {
93
  "lora_path": f"experiments/single-concept/{c1}/models/edlora_model-latest.pth",
@@ -209,8 +210,8 @@ css="""
209
  with gr.Blocks(css=css) as demo:
210
  gr.Markdown(f"""
211
  # Orthogonal Adaptation
212
- Describe your world with a **🪄 text prompt (global and local)** and choose two characters to merge.
213
- Select their **👯 poses (spatial conditions)** for regionally controllable sampling to generate a unique image using our model.
214
  Let your creativity run wild! (Currently running on : {power_device} )
215
  """)
216
 
@@ -221,7 +222,7 @@ with gr.Blocks(css=css) as demo:
221
  # ### πŸͺ„ Global and Region prompts
222
  # """)
223
  # with gr.Group():
224
- with gr.Tab('🪄 Global and Region prompts'):
225
  prompt = gr.Text(
226
  label="ContextPrompt",
227
  show_label=False,
@@ -289,7 +290,7 @@ with gr.Blocks(css=css) as demo:
289
  # ### πŸ‘― Spatial Condition
290
  # """)
291
  # with gr.Group():
292
- with gr.Tab('👯 Spatial Condition '):
293
  gallery = gr.Gallery(label = "Select pose for characters",
294
  value = [obj[1]for obj in pose_image_list],
295
  elem_id = [obj[0]for obj in pose_image_list],
@@ -326,6 +327,14 @@ with gr.Blocks(css=css) as demo:
326
  )
327
 
328
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
 
 
 
 
 
 
 
329
 
330
  with gr.Row():
331
 
@@ -367,6 +376,7 @@ with gr.Blocks(css=css) as demo:
367
  randomize_seed,
368
  # sketch_condition,
369
  # keypose_condition,
 
370
  sketch_adaptor_weight,
371
  keypose_adaptor_weight
372
  ],
 
24
  region_neg_prompt,
25
  seed,
26
  randomize_seed,
27
+ alpha,
28
  sketch_adaptor_weight,
29
  keypose_adaptor_weight
30
  ):
 
42
  seed = random.randint(0, MAX_SEED)
43
 
44
  region1_concept, region2_concept = region1_concept.lower(), region2_concept.lower()
45
+ pretrained_model = merge(region1_concept, region2_concept, alpha)
46
 
47
  with open('multi-concept/pose_data/pose.json') as f:
48
  d = json.load(f)
 
54
  region1 = pose_image['region1']
55
  region2 = pose_image['region2']
56
 
57
+ region_pos_prompt = "high resolution, best quality, highly detailed, sharp focus, sophisticated, Good anatomy, Clear facial features, Proportional body, Detailed clothing, Smooth textures"
58
  region1_prompt = f'<{region1_concept}1> <{region1_concept}2>, {region1_prompt}, {region_pos_prompt}'
59
  region2_prompt = f'<{region2_concept}1> <{region2_concept}2>, {region2_prompt}, {region_pos_prompt}'
60
  prompt_rewrite=f"{region1_prompt}-*-{region_neg_prompt}-*-{region1}|{region2_prompt}-*-{region_neg_prompt}-*-{region2}"
61
  print(prompt_rewrite)
62
+ prompt+=", Disney style photo, High resolution, best quality, highly detailed, expressive,"
63
 
64
  result = infer(pretrained_model,
65
  prompt,
 
74
 
75
  return result
76
 
77
+ def merge(concept1, concept2, alpha):
78
  device = "cuda" if torch.cuda.is_available() else "cpu"
79
  c1, c2 = sorted([concept1, concept2])
80
  assert c1!=c2
81
  merge_name = c1+'_'+c2
82
 
83
+ save_path = f'experiments/multi-concept/{merge_name}--{int(alpha*10)}'
84
 
85
  if os.path.isdir(save_path):
86
  print(f'{save_path} already exists. Collecting merged weights from existing weights...')
 
88
  else:
89
  os.makedirs(save_path)
90
  json_path = os.path.join(save_path,'merge_config.json')
91
+ # alpha = 1.8
92
  data = [
93
  {
94
  "lora_path": f"experiments/single-concept/{c1}/models/edlora_model-latest.pth",
 
210
  with gr.Blocks(css=css) as demo:
211
  gr.Markdown(f"""
212
  # Orthogonal Adaptation
213
+ Describe your world with a ** [🪄 Text Prompts] **(global and regional prompts) and choose two characters to merge.
214
+ Select their ** [ 👯 Poses ] **(spatial conditions) for regionally controllable sampling to generate a unique image using our model.
215
  Let your creativity run wild! (Currently running on : {power_device} )
216
  """)
217
 
 
222
  # ### πŸͺ„ Global and Region prompts
223
  # """)
224
  # with gr.Group():
225
+ with gr.Tab('🪄 Text Prompts'):
226
  prompt = gr.Text(
227
  label="ContextPrompt",
228
  show_label=False,
 
290
  # ### πŸ‘― Spatial Condition
291
  # """)
292
  # with gr.Group():
293
+ with gr.Tab('👯 Poses '):
294
  gallery = gr.Gallery(label = "Select pose for characters",
295
  value = [obj[1]for obj in pose_image_list],
296
  elem_id = [obj[0]for obj in pose_image_list],
 
327
  )
328
 
329
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
330
+
331
+ alpha = gr.Slider(
332
+ label="Merge Weight",
333
+ minimum=1.2,
334
+ maximum=2.0,
335
+ step=0.1,
336
+ value=1.8,
337
+ )
338
 
339
  with gr.Row():
340
 
 
376
  randomize_seed,
377
  # sketch_condition,
378
  # keypose_condition,
379
+ alpha,
380
  sketch_adaptor_weight,
381
  keypose_adaptor_weight
382
  ],