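# app.py: Gradio demo for subject-driven inpainting with BLIP-Diffusion and a
# ControlNet inpaint model. Draw a mask on a source image, provide a target
# subject image, and the masked region is regenerated to show that subject.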
import gradio as gr
import numpy as np
import torch

from diffusers import ControlNetModel
from pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
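# Load the BLIP-Diffusion ControlNet pipeline, then swap in the SD 1.5
# inpainting ControlNet so generation is confined to the masked region.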
blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
"Salesforce/blipdiffusion-controlnet"
)
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint")
blip_diffusion_pipe.controlnet = controlnet
blip_diffusion_pipe.to(device)
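
# The inpaint ControlNet expects a conditioning image scaled to [0, 1] in which
# masked pixels are overwritten with -1; those -1 pixels mark the area to fill.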
def make_inpaint_condition(image, image_mask):
    image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
    # Compare height and width, not just the first dimension.
    assert image.shape[:2] == image_mask.shape[:2], "image and image_mask must have the same image size"
    image[image_mask > 0.5] = -1.0  # set masked pixels
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)  # HWC -> NCHW
    return torch.from_numpy(image)
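
# For example, a 512x512 source image and its sketch mask yield a float32
# tensor of shape (1, 3, 512, 512), matching what the ControlNet consumes.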
css='''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
.image_upload{min-height:500px}
.image_upload [data-testid="image"], .image_upload [data-testid="image"] > div{min-height: 500px}
.image_upload [data-testid="target"], .image_upload [data-testid="target"] > div{min-height: 500px}
.image_upload .touch-none{display: flex}
#output_image{min-height: 500px; max-height: 500px;}
'''
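# The CSS above pins the upload/sketch widgets and the output gallery to a
# uniform 500px height so the three panels line up.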
def create_demo():
    # Fixed resolution at which the pipeline generates.
    HEIGHT, WIDTH = 512, 512
with gr.Blocks(theme=gr.themes.Default(font=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace","monospace"],
primary_hue="lime",
secondary_hue="emerald",
neutral_hue="slate",
), css=css) as demo:
gr.Markdown('# BLIP-Diffusion')
with gr.Accordion('Instructions', open=False):
            gr.Markdown('1. Upload a source image and draw a mask over the region to replace')
            gr.Markdown('2. Upload a target (subject) image')
            gr.Markdown('3. Enter the name of the target object and a description prompt')
            gr.Markdown('4. Click `Generate` when ready!')
with gr.Group():
with gr.Box():
with gr.Column():
with gr.Row() as main_blocks:
                        # Left column: source image, sketch mask, and prompt.
with gr.Column() as step_1:
gr.Markdown('### Source Input and Add Mask')
                            image = gr.Image(source='upload',
                                             shape=[HEIGHT, WIDTH],
                                             type='pil',
                                             elem_classes="image_upload",
                                             label='Source Image',
                                             tool='sketch',
                                             brush_radius=60).style(height=500)
                            src_input = image
                            text_prompt = gr.Textbox(label='Prompt')
                            run_button = gr.Button(value='Generate', variant="primary")
                        # Right column: target subject image and its name.
with gr.Column() as step_2:
gr.Markdown('### Target Input')
                            target = gr.Image(source='upload',
                                              shape=[HEIGHT, WIDTH],
                                              type='pil',
                                              elem_classes="image_upload",
                                              label='Target Image'
                                              ).style(height=500)
                            tgt_input = target
style_subject = gr.Textbox(label='Target Object')
with gr.Row() as output_blocks:
with gr.Column() as output_step:
gr.Markdown('### Output')
                            output_image = gr.Gallery(
                                label="Generated images",
                                show_label=False,
                                elem_id="output_image",
                            ).style(height=500, container=True)
with gr.Accordion('Advanced options', open=False):
num_inference_steps = gr.Slider(label='Steps',
minimum=1,
maximum=100,
value=50,
step=1)
guidance_scale = gr.Slider(label='Text Guidance Scale',
minimum=0.1,
maximum=30.0,
value=7.5,
step=0.1)
seed = gr.Slider(label='Seed',
minimum=-1,
maximum=2147483647,
step=1,
randomize=True)
        # Collect the UI components that feed the model call.
inputs = [
src_input,
tgt_input,
text_prompt,
style_subject,
num_inference_steps,
guidance_scale,
seed,
]
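        # The order of `inputs` must match the generate() signature below.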
def generate(src_input,
tgt_input,
text_prompt,
style_subject,
num_inference_steps,
guidance_scale,
seed,
):
            if src_input is None or tgt_input is None:
                # gr.Error must be raised, not just constructed, to show in the UI.
                raise gr.Error("You must upload both a source and a target image first.")
            # Prepare the model inputs.
            tgt_subject = style_subject
            generator = torch.Generator(device="cpu").manual_seed(int(seed))
            init_image = src_input['image']
            cldm_cond_image = src_input['mask']
            control_image = make_inpaint_condition(init_image, cldm_cond_image)
            style_image = tgt_input
            negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, ugly, bad anatomy, bad proportions, deformed, blurry"
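            # The custom BlipDiffusionControlNetPipeline call below takes the
            # text prompt, the subject (style) image, the inpaint conditioning
            # image, and the source/target subject names, and returns PIL images.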
output = blip_diffusion_pipe(
text_prompt,
style_image,
control_image,
style_subject,
tgt_subject,
generator=generator,
image=init_image,
mask_image=cldm_cond_image,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
neg_prompt=negative_prompt,
height=HEIGHT,
width=WIDTH,
).images
            return {output_image: output}
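        # Wire the button: Gradio passes `inputs` positionally and routes the
        # returned dict to the output gallery.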
run_button.click(fn=generate, inputs=inputs, outputs=[output_image])
return demo
if __name__ == '__main__':
demo = create_demo()
demo.queue().launch()
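    # queue() makes concurrent requests wait their turn on the single GPU;
    # launch(share=True) would additionally expose a temporary public link.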