import gradio as gr
import torch
from diffusers import (
AutoPipelineForText2Image,
StableDiffusionXLControlNetPipeline,
DiffusionPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionAdapterPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionXLAdapterPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
ControlNetModel,
T2IAdapter,
)
import time

import utils

dtype = torch.float16
device = torch.device("cuda")
# pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile
# examples = [["SD T2I", 4, True, True]]
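# Maps each dropdown label to (pipeline class, checkpoint[, ControlNet/adapter checkpoint]).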
pipeline_mapping = {
"SD T2I": (DiffusionPipeline, "runwayml/stable-diffusion-v1-5"),
"SD I2I": (StableDiffusionImg2ImgPipeline, "runwayml/stable-diffusion-v1-5"),
"SD Inpainting": (
StableDiffusionInpaintPipeline,
"runwayml/stable-diffusion-inpainting",
),
"SD ControlNet": (
StableDiffusionControlNetPipeline,
"runwayml/stable-diffusion-v1-5",
"lllyasviel/sd-controlnet-canny",
),
"SD T2I Adapters": (
StableDiffusionAdapterPipeline,
"CompVis/stable-diffusion-v1-4",
"TencentARC/t2iadapter_canny_sd14v1",
),
"SDXL T2I": (DiffusionPipeline, "stabilityai/stable-diffusion-xl-base-1.0"),
"SDXL I2I": (
StableDiffusionXLImg2ImgPipeline,
"stabilityai/stable-diffusion-xl-base-1.0",
),
"SDXL Inpainting": (
StableDiffusionXLInpaintPipeline,
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
),
"SDXL ControlNet": (
StableDiffusionXLControlNetPipeline,
"stabilityai/stable-diffusion-xl-base-1.0",
"diffusers/controlnet-canny-sdxl-1.0",
),
"SDXL T2I Adapters": (
StableDiffusionXLAdapterPipeline,
"stabilityai/stable-diffusion-xl-base-1.0",
"TencentARC/t2i-adapter-canny-sdxl-1.0",
),
"Kandinsky 2.2 (T2I)": (
AutoPipelineForText2Image,
"kandinsky-community/kandinsky-2-2-decoder",
),
"Würstchen (T2I)": (AutoPipelineForText2Image, "warp-ai/wuerstchen"),
}


def load_pipeline(
    pipeline_to_benchmark: str,
    use_channels_last: bool = False,
    do_torch_compile: bool = False,
):
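    """Load the selected pipeline in fp16, optionally switching its denoiser
    (and any ControlNet/adapter) to channels_last and wrapping it in torch.compile."""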
# Get pipeline details.
print(f"Loading pipeline: {pipeline_to_benchmark}")
pipeline_details = pipeline_mapping[pipeline_to_benchmark]
pipeline_cls = pipeline_details[0]
pipeline_ckpt = pipeline_details[1]
# Load adapter if needed.
if "ControlNet" in pipeline_to_benchmark:
controlnet_ckpt = pipeline_details[2]
controlnet = ControlNetModel.from_pretrained(
controlnet_ckpt, torch_dtype=dtype
).to(device)
elif "Adapters" in pipeline_to_benchmark:
        adapter_ckpt = pipeline_details[2]
        adapter = T2IAdapter.from_pretrained(adapter_ckpt, torch_dtype=dtype).to(device)
# Load pipeline.
if (
"ControlNet" not in pipeline_to_benchmark
and "Adapters" not in pipeline_to_benchmark
):
pipeline = pipeline_cls.from_pretrained(pipeline_ckpt, torch_dtype=dtype)
elif "ControlNet" in pipeline_to_benchmark:
pipeline = pipeline_cls.from_pretrained(
pipeline_ckpt, controlnet=controlnet, torch_dtype=dtype
)
elif "Adapters" in pipeline_to_benchmark:
pipeline = pipeline_cls.from_pretrained(
pipeline_ckpt, adapter=adapter, torch_dtype=dtype
)
pipeline.to(device)
# Optionally set memory layout.
if use_channels_last:
print("Setting memory layout.")
        if pipeline_to_benchmark != "Würstchen (T2I)":
            pipeline.unet.to(memory_format=torch.channels_last)
        else:
            pipeline.prior_prior.to(memory_format=torch.channels_last)
            pipeline.decoder.to(memory_format=torch.channels_last)
if hasattr(pipeline, "controlnet"):
pipeline.controlnet.to(memory_format=torch.channels_last)
elif hasattr(pipeline, "adapter"):
pipeline.adapter.to(memory_format=torch.channels_last)
# Optional torch compilation.
if do_torch_compile:
print("Compiling pipeline.")
        if pipeline_to_benchmark != "Würstchen (T2I)":
            pipeline.unet = torch.compile(
                pipeline.unet, mode="reduce-overhead", fullgraph=True
            )
        else:
            pipeline.prior_prior = torch.compile(
                pipeline.prior_prior, mode="reduce-overhead", fullgraph=True
            )
            pipeline.decoder = torch.compile(
                pipeline.decoder, mode="reduce-overhead", fullgraph=True
            )
if hasattr(pipeline, "controlnet"):
pipeline.controlnet = torch.compile(
pipeline.controlnet, mode="reduce-overhead", fullgraph=True
)
elif hasattr(pipeline, "adapter"):
pipeline.adapter = torch.compile(
pipeline.adapter, mode="reduce-overhead", fullgraph=True
)
print("Pipeline loaded.")
pipeline.set_progress_bar_config(disable=True)
return pipeline


def generate(
    pipeline_to_benchmark: str,
    num_images_per_prompt: int = 1,
    use_channels_last: bool = False,
    do_torch_compile: bool = False,
):
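    """Run the selected pipeline a few times and report the average wall-clock
    time per denoising step of the final run."""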
    if pipeline_to_benchmark is None or isinstance(pipeline_to_benchmark, list):
        # Gradio passes None (or an empty list) when nothing is selected in the dropdown.
        raise ValueError(
            "pipeline_to_benchmark cannot be None. Please select a pipeline to benchmark."
        )
print("Start...")
print("Torch version", torch.__version__)
print("Torch CUDA version", torch.version.cuda)
pipeline = load_pipeline(
pipeline_to_benchmark=pipeline_to_benchmark,
use_channels_last=use_channels_last,
do_torch_compile=do_torch_compile,
)
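    # Three timed iterations: with torch.compile enabled, the first call pays the
    # compilation cost, so the final iteration best reflects steady-state speed.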
for _ in range(3):
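        # Dummy 77-character prompt (CLIP's context length is 77 tokens); its
        # content is irrelevant for timing.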
prompt = 77 * "a"
num_inference_steps = 20
call_args = dict(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
num_inference_steps=num_inference_steps,
)
if pipeline_to_benchmark in ["SD I2I", "SDXL I2I"]:
image = utils.get_image_for_img_to_img(pipeline_to_benchmark)
call_args.update({"image": image})
elif "Inpainting" in pipeline_to_benchmark:
image, mask_image = utils.get_image_for_inpainting(pipeline_to_benchmark)
call_args.update({"image": image, "mask_image": mask_image})
elif "ControlNet" in pipeline_to_benchmark:
image = utils.get_image_for_controlnet(pipeline_to_benchmark)
call_args.update({"image": image})
elif "Adapters" in pipeline_to_benchmark:
image = utils.get_image_for_adapters(pipeline_to_benchmark)
call_args.update({"image": image})
        # Synchronize around the call so the measurement covers all queued GPU work.
        torch.cuda.synchronize()
        start_time = time.time()
        _ = pipeline(**call_args).images
        torch.cuda.synchronize()
        end_time = time.time()
        print(f"For {num_inference_steps} steps", end_time - start_time)
        print("Avg per step", (end_time - start_time) / num_inference_steps)
return (
f"Avg per step: {((end_time - start_time) / num_inference_steps):.4f} seconds."
)
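

# Gradio UI: choose a pipeline, batch size, and optimization flags, then benchmark.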
with gr.Blocks(css="style.css") as demo:
do_torch_compile = gr.Checkbox(label="Enable torch.compile()?")
use_channels_last = gr.Checkbox(label="Use `channels_last` memory layout?")
pipeline_to_benchmark = gr.Dropdown(
list(pipeline_mapping.keys()),
value=None,
multiselect=False,
label="Pipeline to benchmark",
)
batch_size = gr.Slider(
label="Number of images per prompt",
minimum=1,
maximum=16,
step=1,
value=1,
)
    # Button.style() was deprecated in Gradio 3.x and removed in 4.x; a plain Button is safe.
    btn = gr.Button("Benchmark!")
result = gr.Text(label="Result")
# gr.Examples(
# examples=examples,
# inputs=[pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile],
# outputs=result,
# fn=generate,
# cache_examples=True,
# )
btn.click(
fn=generate,
inputs=[pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile],
outputs=result,
)
demo.launch(show_error=True)