import gradio as gr
import random
import torch
import io, json
from PIL import Image
import os.path

from weight_fusion import compose_concepts
from regionally_controlable_sampling import sample_image, build_model, prepare_text

device = "cuda" if torch.cuda.is_available() else "cpu"
power_device = "GPU" if torch.cuda.is_available() else "CPU"
MAX_SEED = 100_000


def generate(region1_concept,
             region2_concept,
             prompt,
             region1_prompt,
             region2_prompt,
             negative_prompt,
             region_neg_prompt,
             seed,
             randomize_seed,
             sketch_adaptor_weight,
             keypose_adaptor_weight):
    # Gradio callback: fuse the two concepts' ED-LoRA weights, then run
    # regionally controlled sampling with a fixed two-person pose condition.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    region1_concept, region2_concept = region1_concept.lower(), region2_concept.lower()
    pretrained_model = merge(region1_concept, region2_concept)

    # Hard-coded pose image and the two region boxes (left and right figure)
    # that match it.
    keypose_condition = 'multi-concept/pose_data/two_apart.png'
    region1 = '[0, 0, 512, 290]'
    region2 = '[0, 650, 512, 910]'

    # Prefix each regional prompt with its concept's ED-LoRA tokens, then pack
    # both regions into the '-*-'/'|'-delimited rewrite string that
    # prepare_text parses.
    region1_prompt = f'[<{region1_concept}1> <{region1_concept}2>, {region1_prompt}]'
    region2_prompt = f'[<{region2_concept}1> <{region2_concept}2>, {region2_prompt}]'
    prompt_rewrite = (f"{region1_prompt}-*-{region_neg_prompt}-*-{region1}"
                      f"|{region2_prompt}-*-{region_neg_prompt}-*-{region2}")

    result = infer(pretrained_model,
                   prompt,
                   prompt_rewrite,
                   negative_prompt,
                   seed,
                   keypose_condition,
                   keypose_adaptor_weight,
                   # sketch_condition,
                   # sketch_adaptor_weight,
                   )
    return result


def merge(concept1, concept2):
    # Fuse the two single-concept ED-LoRAs into one base model, caching the
    # merged weights on disk so repeated pairs skip the optimization.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    c1, c2 = sorted([concept1, concept2])
    if c1 == c2:
        raise gr.Error('Please choose two different characters.')
    merge_name = c1 + '_' + c2

    save_path = f'experiments/multi-concept/{merge_name}'
    if os.path.isdir(save_path):
        print(f'{save_path} already exists. Collecting merged weights from existing weights...')
    else:
        os.makedirs(save_path)
        json_path = os.path.join(save_path, 'merge_config.json')
        alpha = 1.8
        data = [
            {
                "lora_path": f"experiments/single-concept/{c1}/models/edlora_model-latest.pth",
                "unet_alpha": alpha,
                "text_encoder_alpha": alpha,
                "concept_name": f"<{c1}1> <{c1}2>"
            },
            {
                "lora_path": f"experiments/single-concept/{c2}/models/edlora_model-latest.pth",
                "unet_alpha": alpha,
                "text_encoder_alpha": alpha,
                "concept_name": f"<{c2}1> <{c2}2>"
            }
        ]
        with io.open(json_path, 'w', encoding='utf8') as outfile:
            json.dump(data, outfile, indent=4, ensure_ascii=False)

        compose_concepts(
            concept_cfg=json_path,
            optimize_textenc_iters=500,
            optimize_unet_iters=50,
            pretrained_model_path="nitrosocke/mo-di-diffusion",
            save_path=save_path,
            suffix='base',
            device=device,
        )
        print(f'Merged weight for {c1}+{c2} saved in {save_path}!\n\n')

    modelbase_path = os.path.join(save_path, 'combined_model_base')
    assert os.path.isdir(modelbase_path)
    return modelbase_path


def infer(pretrained_model,
          prompt,
          prompt_rewrite,
          negative_prompt='',
          seed=16141,
          keypose_condition=None,
          keypose_adaptor_weight=1.0,
          sketch_condition=None,
          sketch_adaptor_weight=0.0,
          region_sketch_adaptor_weight='',
          region_keypose_adaptor_weight=''):
    # Build the regionally controllable pipeline and sample one image.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    pipe = build_model(pretrained_model, device)

    if sketch_condition is not None and os.path.exists(sketch_condition):
        sketch_condition = Image.open(sketch_condition).convert('L')
        width_sketch, height_sketch = sketch_condition.size
        print('use sketch condition')
    else:
        sketch_condition, width_sketch, height_sketch = None, 0, 0
        print('skip sketch condition')

    if keypose_condition is not None and os.path.exists(keypose_condition):
        keypose_condition = Image.open(keypose_condition).convert('RGB')
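        # generate() always supplies a pose image; its resolution sets the
        # output canvas below, and any sketch condition must match it in size.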
        width_pose, height_pose = keypose_condition.size
        print('use pose condition')
    else:
        keypose_condition, width_pose, height_pose = None, 0, 0
        print('skip pose condition')

    if width_sketch != 0 and width_pose != 0:
        assert width_sketch == width_pose and height_sketch == height_pose, \
            'conditions should be same size'
    width, height = max(width_pose, width_sketch), max(height_pose, height_sketch)

    kwargs = {
        'sketch_condition': sketch_condition,
        'keypose_condition': keypose_condition,
        'height': height,
        'width': width,
    }

    prompts = [prompt]
    prompts_rewrite = [prompt_rewrite]
    input_prompt = [prepare_text(p, p_w, height, width)
                    for p, p_w in zip(prompts, prompts_rewrite)]
    save_prompt = input_prompt[0][0]
    print(save_prompt)

    image = sample_image(
        pipe,
        input_prompt=input_prompt,
        input_neg_prompt=[negative_prompt] * len(input_prompt),
        generator=torch.Generator(device).manual_seed(seed),
        sketch_adaptor_weight=sketch_adaptor_weight,
        region_sketch_adaptor_weight=region_sketch_adaptor_weight,
        keypose_adaptor_weight=keypose_adaptor_weight,
        region_keypose_adaptor_weight=region_keypose_adaptor_weight,
        **kwargs)
    return image[0]


examples_context = [
    'walking at Stanford university campus',
    'in a castle',
    'in the forest',
    'in front of Eiffel tower',
]
examples_region1 = ['wearing red hat, high resolution, best quality',
                    'bright smile, wearing pants, best quality']
examples_region2 = ['smiling, wearing blue shirt, high resolution, best quality']

css = """
#col-container {
    margin: 0 auto;
    max-width: 600px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
        # Orthogonal Adaptation
        Currently running on {power_device}.
        """)

        prompt = gr.Text(
            label="Context Prompt",
            show_label=False,
            max_lines=1,
            placeholder="Enter your context prompt for the overall image",
            container=False,
        )

        with gr.Row():
            region1_concept = gr.Dropdown(
                ["Elsa", "Moana"],
                label="Character 1",
                info="Will add more characters later!",
            )
            region2_concept = gr.Dropdown(
                ["Elsa", "Moana"],
                label="Character 2",
                info="Will add more characters later!",
            )
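        # generate() lower-cases the selected names and merge() sorts the pair,
        # so either ordering of the same two characters reuses one cached
        # merged checkpoint.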
        with gr.Row():
            region1_prompt = gr.Textbox(
                label="Region1 Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt for character 1",
                container=False,
            )
            region2_prompt = gr.Textbox(
                label="Region2 Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt for character 2",
                container=False,
            )

        run_button = gr.Button("Run", scale=1)
        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Context Negative prompt",
                max_lines=1,
                value='saturated, cropped, worst quality, low quality',
                visible=False,
            )
            region_neg_prompt = gr.Text(
                label="Regional Negative prompt",
                max_lines=1,
                value='shirtless, nudity, saturated, cropped, worst quality, low quality',
                visible=False,
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                sketch_adaptor_weight = gr.Slider(
                    label="Sketch Adapter Weight",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                    value=0,
                )
                keypose_adaptor_weight = gr.Slider(
                    label="Keypose Adapter Weight",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                    value=1.0,
                )

        gr.Examples(
            label='Context Prompt examples',
            examples=examples_context,
            inputs=[prompt],
        )
        with gr.Row():
            gr.Examples(
                label='Region1 Prompt examples',
                examples=examples_region1,
                inputs=[region1_prompt],
            )
            gr.Examples(
                label='Region2 Prompt examples',
                examples=examples_region2,
                inputs=[region2_prompt],
            )

    run_button.click(
        fn=generate,
        inputs=[region1_concept,
                region2_concept,
                prompt,
                region1_prompt,
                region2_prompt,
                negative_prompt,
                region_neg_prompt,
                seed,
                randomize_seed,
                # sketch_condition,
                # keypose_condition,
                sketch_adaptor_weight,
                keypose_adaptor_weight],
        outputs=[result],
    )

demo.queue().launch(share=True)
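# Usage (a sketch; the filename is an assumption): `python app.py`, then open
# the printed local or share URL. The script expects per-concept ED-LoRA
# checkpoints under experiments/single-concept/<name>/models/edlora_model-latest.pth
# and the pose image at multi-concept/pose_data/two_apart.png, and it caches
# merged weights under experiments/multi-concept/<name1>_<name2>/.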