sayakpaul HF staff commited on
Commit
1cc3a64
1 Parent(s): 84892c9
Files changed (1) hide show
  1. app.py +2 -8
app.py CHANGED
@@ -110,13 +110,11 @@ def load_pipeline(
110
  # Optionally set memory layout.
111
  if use_channels_last:
112
  print("Setting memory layout.")
113
- if pipeline_to_benchmark not in ["Würstchen (T2I)", "Kandinsky 2.2 (T2I)"]:
114
  pipeline.unet.to(memory_format=torch.channels_last)
115
  elif pipeline_to_benchmark == "Würstchen (T2I)":
116
  pipeline.prior_prior.to(memory_format=torch.channels_last)
117
  pipeline.decoder.to(memory_format=torch.channels_last)
118
- elif pipeline_to_benchmark == "Kandinsky 2.2 (T2I)":
119
- pipeline.unet.to(memory_format=torch.channels_last)
120
 
121
  if hasattr(pipeline, "controlnet"):
122
  pipeline.controlnet.to(memory_format=torch.channels_last)
@@ -126,7 +124,7 @@ def load_pipeline(
126
  # Optional torch compilation.
127
  if do_torch_compile:
128
  print("Compiling pipeline.")
129
- if pipeline_to_benchmark not in ["Würstchen (T2I)", "Kandinsky 2.2 (T2I)"]:
130
  pipeline.unet = torch.compile(
131
  pipeline.unet, mode="reduce-overhead", fullgraph=True
132
  )
@@ -137,10 +135,6 @@ def load_pipeline(
137
  pipeline.decoder = torch.compile(
138
  pipeline.decoder, mode="reduce-overhead", fullgraph=True
139
  )
140
- elif pipeline_to_benchmark == "Kandinsky 2.2 (T2I)":
141
- pipeline.unet = torch.compile(
142
- pipeline.unet, mode="reduce-overhead", fullgraph=True
143
- )
144
 
145
  if hasattr(pipeline, "controlnet"):
146
  pipeline.controlnet = torch.compile(
 
110
  # Optionally set memory layout.
111
  if use_channels_last:
112
  print("Setting memory layout.")
113
+ if pipeline_to_benchmark != "Würstchen (T2I)":
114
  pipeline.unet.to(memory_format=torch.channels_last)
115
  elif pipeline_to_benchmark == "Würstchen (T2I)":
116
  pipeline.prior_prior.to(memory_format=torch.channels_last)
117
  pipeline.decoder.to(memory_format=torch.channels_last)
 
 
118
 
119
  if hasattr(pipeline, "controlnet"):
120
  pipeline.controlnet.to(memory_format=torch.channels_last)
 
124
  # Optional torch compilation.
125
  if do_torch_compile:
126
  print("Compiling pipeline.")
127
+ if pipeline_to_benchmark != "Würstchen (T2I)":
128
  pipeline.unet = torch.compile(
129
  pipeline.unet, mode="reduce-overhead", fullgraph=True
130
  )
 
135
  pipeline.decoder = torch.compile(
136
  pipeline.decoder, mode="reduce-overhead", fullgraph=True
137
  )
 
 
 
 
138
 
139
  if hasattr(pipeline, "controlnet"):
140
  pipeline.controlnet = torch.compile(