import time

import gradio as gr
import torch
from diffusers import (
    AutoPipelineForText2Image,
    ControlNetModel,
    DiffusionPipeline,
    StableDiffusionAdapterPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipeline,
    StableDiffusionXLAdapterPipeline,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
    T2IAdapter,
)

import utils

dtype = torch.float16
device = torch.device("cuda")

# pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile
examples = [["SD T2I", 4, True, True], ["Würstchen (T2I)", 4, False, True]]

# Maps a human-readable pipeline name to (pipeline class, checkpoint id) and,
# for ControlNet/T2I-Adapter pipelines, the auxiliary model checkpoint id.
pipeline_mapping = {
    "SD T2I": (DiffusionPipeline, "runwayml/stable-diffusion-v1-5"),
    "SD I2I": (StableDiffusionImg2ImgPipeline, "runwayml/stable-diffusion-v1-5"),
    "SD Inpainting": (
        StableDiffusionInpaintPipeline,
        "runwayml/stable-diffusion-inpainting",
    ),
    "SD ControlNet": (
        StableDiffusionControlNetPipeline,
        "runwayml/stable-diffusion-v1-5",
        "lllyasviel/sd-controlnet-canny",
    ),
    "SD T2I Adapters": (
        StableDiffusionAdapterPipeline,
        "CompVis/stable-diffusion-v1-4",
        "TencentARC/t2iadapter_canny_sd14v1",
    ),
    "SDXL T2I": (DiffusionPipeline, "stabilityai/stable-diffusion-xl-base-1.0"),
    "SDXL I2I": (
        StableDiffusionXLImg2ImgPipeline,
        "stabilityai/stable-diffusion-xl-base-1.0",
    ),
    "SDXL Inpainting": (
        StableDiffusionXLInpaintPipeline,
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    ),
    "SDXL ControlNet": (
        StableDiffusionXLControlNetPipeline,
        "stabilityai/stable-diffusion-xl-base-1.0",
        "diffusers/controlnet-canny-sdxl-1.0",
    ),
    "SDXL T2I Adapters": (
        StableDiffusionXLAdapterPipeline,
        "stabilityai/stable-diffusion-xl-base-1.0",
        "TencentARC/t2i-adapter-canny-sdxl-1.0",
    ),
    "Kandinsky 2.2 (T2I)": (
        AutoPipelineForText2Image,
        "kandinsky-community/kandinsky-2-2-decoder",
    ),
    "Würstchen (T2I)": (AutoPipelineForText2Image, "warp-ai/wuerstchen"),
}


def load_pipeline(
    pipeline_to_benchmark: str,
    use_channels_last: bool = False,
    do_torch_compile: bool = False,
):
    # Get pipeline details.
    print(f"Loading pipeline: {pipeline_to_benchmark}")
    pipeline_details = pipeline_mapping[pipeline_to_benchmark]
    pipeline_cls = pipeline_details[0]
    pipeline_ckpt = pipeline_details[1]

    # Load the auxiliary model if the pipeline needs one.
    if "ControlNet" in pipeline_to_benchmark:
        controlnet_ckpt = pipeline_details[2]
        controlnet = ControlNetModel.from_pretrained(
            controlnet_ckpt, variant="fp16", torch_dtype=torch.float16
        ).to(device)
    elif "Adapters" in pipeline_to_benchmark:
        adapter_ckpt = pipeline_details[2]
        adapter = T2IAdapter.from_pretrained(
            adapter_ckpt, variant="fp16", torch_dtype=torch.float16
        ).to(device)

    # Load the pipeline, passing the auxiliary model when there is one. The
    # dtype must match the auxiliary model's fp16 weights in all branches.
    if (
        "ControlNet" not in pipeline_to_benchmark
        and "Adapters" not in pipeline_to_benchmark
    ):
        pipeline = pipeline_cls.from_pretrained(
            pipeline_ckpt, variant="fp16", torch_dtype=dtype
        )
    elif "ControlNet" in pipeline_to_benchmark:
        pipeline = pipeline_cls.from_pretrained(
            pipeline_ckpt, controlnet=controlnet, variant="fp16", torch_dtype=dtype
        )
    elif "Adapters" in pipeline_to_benchmark:
        pipeline = pipeline_cls.from_pretrained(
            pipeline_ckpt, adapter=adapter, variant="fp16", torch_dtype=dtype
        )
    pipeline.to(device)

    # Optionally set memory layout.
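    # A note on `channels_last` (illustrative, not specific to this app): it
    # reorders a 4D tensor's strides to NHWC so the channel dimension becomes
    # contiguous, which lets cuDNN pick convolution kernels that are often
    # faster on tensor cores. For example:
    #   x = torch.empty(1, 4, 64, 64).to(memory_format=torch.channels_last)
    #   x.stride()  # (16384, 1, 256, 4) -- channels are the fastest-moving dim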
    if use_channels_last:
        print("Setting memory layout.")
        if pipeline_to_benchmark not in ["Würstchen (T2I)", "Kandinsky 2.2 (T2I)"]:
            pipeline.unet.to(memory_format=torch.channels_last)
        elif pipeline_to_benchmark == "Würstchen (T2I)":
            pipeline.prior.to(memory_format=torch.channels_last)
            pipeline.decoder.to(memory_format=torch.channels_last)
        elif pipeline_to_benchmark == "Kandinsky 2.2 (T2I)":
            pipeline.unet.to(memory_format=torch.channels_last)

        # The auxiliary model, if any, should use the same layout.
        if hasattr(pipeline, "controlnet"):
            pipeline.controlnet.to(memory_format=torch.channels_last)
        elif hasattr(pipeline, "adapter"):
            pipeline.adapter.to(memory_format=torch.channels_last)

    # Optionally compile the denoising models with `torch.compile`.
    if do_torch_compile:
        print("Compiling pipeline.")
        if pipeline_to_benchmark not in ["Würstchen (T2I)", "Kandinsky 2.2 (T2I)"]:
            pipeline.unet = torch.compile(
                pipeline.unet, mode="reduce-overhead", fullgraph=True
            )
        elif pipeline_to_benchmark == "Würstchen (T2I)":
            pipeline.prior = torch.compile(
                pipeline.prior, mode="reduce-overhead", fullgraph=True
            )
            pipeline.decoder = torch.compile(
                pipeline.decoder, mode="reduce-overhead", fullgraph=True
            )
        elif pipeline_to_benchmark == "Kandinsky 2.2 (T2I)":
            pipeline.unet = torch.compile(
                pipeline.unet, mode="reduce-overhead", fullgraph=True
            )

        if hasattr(pipeline, "controlnet"):
            pipeline.controlnet = torch.compile(
                pipeline.controlnet, mode="reduce-overhead", fullgraph=True
            )
        elif hasattr(pipeline, "adapter"):
            pipeline.adapter = torch.compile(
                pipeline.adapter, mode="reduce-overhead", fullgraph=True
            )

    print("Pipeline loaded.")
    return pipeline


def generate(
    pipeline_to_benchmark: str,
    num_images_per_prompt: int = 1,
    use_channels_last: bool = False,
    do_torch_compile: bool = False,
):
    if pipeline_to_benchmark is None or isinstance(pipeline_to_benchmark, list):
        # Gradio passes a non-string when no pipeline has been selected.
        raise ValueError(
            "pipeline_to_benchmark cannot be None. Please select a pipeline to benchmark."
        )

    print("Start...")
    print("Torch version", torch.__version__)
    print("Torch CUDA version", torch.version.cuda)

    pipeline = load_pipeline(
        pipeline_to_benchmark=pipeline_to_benchmark,
        use_channels_last=use_channels_last,
        do_torch_compile=do_torch_compile,
    )

    # Run three times so the reported numbers reflect a warmed-up pipeline;
    # the first iterations absorb compilation and CUDA warm-up costs.
    for _ in range(3):
        prompt = 77 * "a"  # Fill the text encoder's 77-token context.
        num_inference_steps = 20
        call_args = dict(
            prompt=prompt,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=num_inference_steps,
        )

        # Pipelines other than plain text-to-image need image inputs.
        if pipeline_to_benchmark in ["SD I2I", "SDXL I2I"]:
            image = utils.get_image_for_img_to_img(pipeline_to_benchmark)
            call_args.update({"image": image})
        elif "Inpainting" in pipeline_to_benchmark:
            image, mask_image = utils.get_image_from_inpainting(pipeline_to_benchmark)
            call_args.update({"image": image, "mask_image": mask_image})
        elif "ControlNet" in pipeline_to_benchmark:
            image = utils.get_image_for_controlnet(pipeline_to_benchmark)
            call_args.update({"image": image})
        elif "Adapters" in pipeline_to_benchmark:
            image = utils.get_image_for_adapters(pipeline_to_benchmark)
            call_args.update({"image": image})

        start_time = time.time()
        _ = pipeline(**call_args).images
        end_time = time.time()
        print(f"For {num_inference_steps} steps", end_time - start_time)
        print("Avg per step", (end_time - start_time) / num_inference_steps)

    return f"Avg per step: {((end_time - start_time) / num_inference_steps):.4f} seconds."
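# For quick local runs without the UI, `generate` can be called directly. A
# minimal sketch, assuming a CUDA GPU and the checkpoints already cached:
#
#   if __name__ == "__main__":
#       print(generate("SD T2I", num_images_per_prompt=4,
#                      use_channels_last=True, do_torch_compile=True))
#
# Note that `time.time()` is a reasonable proxy here only because the pipeline
# call decodes latents to PIL images, which forces a CUDA sync; for more
# precise kernel timing, `torch.cuda.Event(enable_timing=True)` pairs around
# the call would be the usual alternative.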
with gr.Blocks(css="style.css") as demo:
    do_torch_compile = gr.Checkbox(label="Enable torch.compile()?")
    use_channels_last = gr.Checkbox(label="Use `channels_last` memory layout?")
    pipeline_to_benchmark = gr.Dropdown(
        list(pipeline_mapping.keys()),
        value=None,
        multiselect=False,
        label="Pipeline to benchmark",
    )
    batch_size = gr.Slider(
        label="Number of images per prompt",
        minimum=1,
        maximum=16,
        step=1,
        value=1,
    )
    btn = gr.Button("Benchmark!").style(
        margin=False,
        rounded=(False, True, True, False),
        full_width=False,
    )
    result = gr.Text(label="Result")

    gr.Examples(
        examples=examples,
        inputs=[pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile],
        outputs=result,
        fn=generate,
        cache_examples=True,
    )

    btn.click(
        fn=generate,
        inputs=[pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile],
        outputs=result,
    )

demo.launch(show_error=True)
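# Compilation warm-up can take minutes per pipeline, so when serving several
# users it may help to queue requests rather than launching directly -- a
# sketch using Gradio's queueing API (max_size chosen arbitrarily):
#
#   demo.queue(max_size=4).launch(show_error=True)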