sayakpaul HF staff commited on
Commit
fd14e0a
1 Parent(s): 57f479a

add support for other pipelines too.

Browse files
Files changed (2) hide show
  1. app.py +25 -8
  2. utils.py +1 -2
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  from diffusers import (
 
4
  StableDiffusionXLControlNetPipeline,
5
  DiffusionPipeline,
6
  StableDiffusionImg2ImgPipeline,
@@ -55,6 +56,8 @@ pipeline_mapping = {
55
  "stabilityai/stable-diffusion-xl-base-1.0",
56
  "TencentARC/t2i-adapter-canny-sdxl-1.0",
57
  ),
 
 
58
  }
59
 
60
 
@@ -93,22 +96,36 @@ def load_pipeline(
93
  pipeline = pipeline_cls.from_pretrained(pipeline_ckpt, controlnet=controlnet)
94
  elif "Adapters" in pipeline_to_benchmark:
95
  pipeline = pipeline_cls.from_pretrained(pipeline_ckpt, adapter=adapter)
 
96
  pipeline.to(device)
97
 
98
  # Optionally set memory layout.
99
  if use_channels_last:
100
- pipeline.unet.to(memory_format=torch.channels_last)
 
 
 
 
 
 
101
 
102
- if hasattr(pipeline, "controlnet"):
103
- pipeline.controlnet.to(memory_format=torch.channels_last)
104
- elif hasattr(pipeline, "adapter"):
105
- pipeline.adapter.to(memory_format=torch.channels_last)
106
 
107
  # Optional torch compilation.
108
  if do_torch_compile:
109
- pipeline.unet = torch.compile(
110
- pipeline.unet, mode="reduce-overhead", fullgraph=True
111
- )
 
 
 
 
 
 
 
112
  if hasattr(pipeline, "controlnet"):
113
  pipeline.controlnet = torch.compile(
114
  pipeline.controlnet, mode="reduce-overhead", fullgraph=True
 
1
  import gradio as gr
2
  import torch
3
  from diffusers import (
4
+ AutoPipelineForText2Image,
5
  StableDiffusionXLControlNetPipeline,
6
  DiffusionPipeline,
7
  StableDiffusionImg2ImgPipeline,
 
56
  "stabilityai/stable-diffusion-xl-base-1.0",
57
  "TencentARC/t2i-adapter-canny-sdxl-1.0",
58
  ),
59
+ "Kandinsky 2.2 (T2I)": (AutoPipelineForText2Image, "kandinsky-community/kandinsky-2-2-decoder"),
60
+ "Würstchen (T2I)": (AutoPipelineForText2Image, "warp-ai/wuerstchen")
61
  }
62
 
63
 
 
96
  pipeline = pipeline_cls.from_pretrained(pipeline_ckpt, controlnet=controlnet)
97
  elif "Adapters" in pipeline_to_benchmark:
98
  pipeline = pipeline_cls.from_pretrained(pipeline_ckpt, adapter=adapter)
99
+
100
  pipeline.to(device)
101
 
102
  # Optionally set memory layout.
103
  if use_channels_last:
104
+ if pipeline_to_benchmark not in ["Würstchen (T2I)", "Kandinsky 2.2 (T2I)"]:
105
+ pipeline.unet.to(memory_format=torch.channels_last)
106
+ elif pipeline_to_benchmark == "Würstchen (T2I)":
107
+ pipeline.prior.to(memory_format=torch.channels_last)
108
+ pipeline.decoder.to(memory_format=torch.channels_last)
109
+ elif pipeline_to_benchmark == "Kandinsky 2.2 (T2I)":
110
+ pipeline.unet.to(memory_format=torch.channels_last)
111
 
112
+ if hasattr(pipeline, "controlnet"):
113
+ pipeline.controlnet.to(memory_format=torch.channels_last)
114
+ elif hasattr(pipeline, "adapter"):
115
+ pipeline.adapter.to(memory_format=torch.channels_last)
116
 
117
  # Optional torch compilation.
118
  if do_torch_compile:
119
+ if pipeline_to_benchmark not in ["Würstchen (T2I)", "Kandinsky 2.2 (T2I)"]:
120
+ pipeline.unet = torch.compile(
121
+ pipeline.unet, mode="reduce-overhead", fullgraph=True
122
+ )
123
+ elif pipeline_to_benchmark == "Würstchen (T2I)":
124
+ pipeline.prior = torch.compile(pipeline.prior, mode="reduce-overhead", fullgraph=True)
125
+ pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullgraph=True)
126
+ elif pipeline_to_benchmark == "Kandinsky 2.2 (T2I)":
127
+ pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
128
+
129
  if hasattr(pipeline, "controlnet"):
130
  pipeline.controlnet = torch.compile(
131
  pipeline.controlnet, mode="reduce-overhead", fullgraph=True
utils.py CHANGED
@@ -5,8 +5,7 @@ def get_image_for_img_to_img(pipeline_to_benchmark):
5
  if pipeline_to_benchmark == "SD I2I":
6
  url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
7
  init_image = load_image(url).convert("RGB")
8
- size = (768, 512)
9
- init_image = init_image.resize(size)
10
  elif pipeline_to_benchmark == "SDXL I2I":
11
  url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-img2img.png"
12
  init_image = load_image(url).convert("RGB")
 
5
  if pipeline_to_benchmark == "SD I2I":
6
  url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
7
  init_image = load_image(url).convert("RGB")
8
+ init_image = init_image.resize((512, 512))
 
9
  elif pipeline_to_benchmark == "SDXL I2I":
10
  url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-img2img.png"
11
  init_image = load_image(url).convert("RGB")