import gradio as gr
import sys
import os
from io import BytesIO

import numpy as np
import torch
from PIL import Image

from diffusers import ControlNetModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import load_image

from pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the BLIP-Diffusion ControlNet pipeline and swap in the SD 1.5 inpainting ControlNet.
blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
    "Salesforce/blipdiffusion-controlnet"
)
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint")
blip_diffusion_pipe.controlnet = controlnet
blip_diffusion_pipe.to(device)


def make_inpaint_condition(image, image_mask):
    """Build the inpainting control image: pixels under the mask are set to -1."""
    image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
    assert image.shape[0:2] == image_mask.shape[0:2], "image and image_mask must have the same image size"
    image[image_mask > 0.5] = -1  # set as masked pixel
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)  # HWC -> NCHW
    image = torch.from_numpy(image)
    return image


css = '''
.container {max-width: 1150px; margin: auto; padding-top: 1.5rem}
.image_upload{min-height: 500px}
.image_upload [data-testid="image"], .image_upload [data-testid="image"] > div{min-height: 500px}
.image_upload [data-testid="target"], .image_upload [data-testid="target"] > div{min-height: 500px}
.image_upload .touch-none{display: flex}
#output_image{min-height: 500px; max-height: 500px;}
'''


def create_demo():
    # Layout: collect the source/target images and prompts from the user.
    HEIGHT, WIDTH = 512, 512

    with gr.Blocks(theme=gr.themes.Default(font=[gr.themes.GoogleFont("IBM Plex Mono"),
                                                 "ui-monospace", "monospace"],
                                           primary_hue="lime",
                                           secondary_hue="emerald",
                                           neutral_hue="slate"),
                   css=css) as demo:
        gr.Markdown('# BLIP-Diffusion')
        with gr.Accordion('Instructions', open=False):
            gr.Markdown('1. Upload the source image and draw a mask over the region to edit')
            gr.Markdown('2. Upload the target (reference) image')
            gr.Markdown('3. Enter the name of the target object and a text prompt')
            gr.Markdown('4. Click `Generate` when ready!')
        with gr.Group():
            with gr.Box():
                with gr.Column():
                    with gr.Row() as main_blocks:
                        # Left column: source image with a hand-drawn inpainting mask.
                        with gr.Column() as step_1:
                            gr.Markdown('### Source Input and Add Mask')
                            image = gr.Image(source='upload',
                                             shape=[HEIGHT, WIDTH],
                                             type='pil',
                                             elem_classes="image_upload",
                                             label='Source Image',
                                             tool='sketch',
                                             brush_radius=60).style(height=500)
                            src_input = image
                            text_prompt = gr.Textbox(label='Prompt')
                            run_button = gr.Button(value='Generate', variant="primary")
                        # Right column: target (reference) image of the subject to insert.
                        with gr.Column() as step_2:
                            gr.Markdown('### Target Input')
                            target = gr.Image(source='upload',
                                              shape=[HEIGHT, WIDTH],
                                              type='pil',
                                              elem_classes="image_upload",
                                              label='Target Image').style(height=500)
                            tgt_input = target
                            style_subject = gr.Textbox(label='Target Object')

                    with gr.Row() as output_blocks:
                        with gr.Column() as output_step:
                            gr.Markdown('### Output')
                            output_image = gr.Gallery(
                                label="Generated images",
                                show_label=False,
                                elem_id="output_image",
                            ).style(height=500, container=True)
                            with gr.Accordion('Advanced options', open=False):
                                num_inference_steps = gr.Slider(label='Steps', minimum=1, maximum=100, value=50, step=1)
                                guidance_scale = gr.Slider(label='Text Guidance Scale', minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                                seed = gr.Slider(label='Seed', minimum=-1, maximum=2147483647, step=1, randomize=True)

        # Model inputs collected from the UI.
        inputs = [
            src_input,
            tgt_input,
            text_prompt,
            style_subject,
            num_inference_steps,
            guidance_scale,
            seed,
        ]

        def generate(src_input, tgt_input, text_prompt, style_subject,
                     num_inference_steps, guidance_scale, seed):
            if src_input is None or tgt_input is None:
                raise gr.Error("You must upload an image first.")

            tgt_subject = style_subject
            generator = torch.Generator(device="cpu").manual_seed(int(seed))

            # The sketch tool returns a dict holding the uploaded image and the drawn mask.
            init_image = src_input['image']
            cldm_cond_image = src_input['mask']
            control_image = make_inpaint_condition(init_image, cldm_cond_image)
            style_image = tgt_input
            negative_prompt = ("over-exposure, under-exposure, saturated, duplicate, out of frame, "
                               "lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, "
                               "mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, "
                               "blurry, duplicate")

            output = blip_diffusion_pipe(
                text_prompt,
                style_image,
                control_image,
                style_subject,
                tgt_subject,
                generator=generator,
                image=init_image,
                mask_image=cldm_cond_image,
                guidance_scale=guidance_scale,
                num_inference_steps=int(num_inference_steps),
                neg_prompt=negative_prompt,
                height=HEIGHT,
                width=WIDTH,
            ).images
            return {output_image: output}

        run_button.click(fn=generate, inputs=inputs, outputs=[output_image])

    return demo


if __name__ == '__main__':
    demo = create_demo()
    demo.queue().launch()
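
# A minimal sketch of driving the pipeline directly, without the Gradio UI (e.g. for
# smoke-testing). The file names "source.png", "mask.png", "reference.png" and the
# subject word "cat" are illustrative assumptions, not part of this repo; the call
# mirrors the one made in generate() above.
#
#   src = Image.open("source.png").convert("RGB").resize((512, 512))
#   mask = Image.open("mask.png").convert("L").resize((512, 512))
#   ref = Image.open("reference.png").convert("RGB").resize((512, 512))
#   control = make_inpaint_condition(src, mask)
#   images = blip_diffusion_pipe(
#       "a cat sitting on the grass", ref, control, "cat", "cat",
#       generator=torch.Generator(device="cpu").manual_seed(42),
#       image=src, mask_image=mask, guidance_scale=7.5,
#       num_inference_steps=50, neg_prompt="", height=512, width=512,
#   ).images
#   images[0].save("output.png")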