sayakpaul committed
Commit 9d87a5d
1 Parent(s): f11095a

more updates

Files changed (1)
  1. app.py +8 -3
app.py CHANGED
@@ -20,7 +20,9 @@ import utils
 
 dtype = torch.float16
 device = torch.device("cuda")
-examples = [[True, True, "SD T2I", 4], [False, True, "Würstchen (T2I)", 4]]
+
+# pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile
+examples = [["SD T2I", 4, True, True], ["Würstchen (T2I)", 4, False, True]]
 
 pipeline_mapping = {
     "SD T2I": (DiffusionPipeline, "runwayml/stable-diffusion-v1-5"),
@@ -202,6 +204,8 @@ def generate(
     print(f"For {num_inference_steps} steps", end_time - start_time)
     print("Avg per step", (end_time - start_time) / num_inference_steps)
 
+    return f"Avg per step: {(end_time - start_time) / num_inference_steps} seconds."
+
 
 with gr.Blocks(css="style.css") as demo:
     do_torch_compile = gr.Checkbox(label="Enable torch.compile()?")
@@ -225,11 +229,12 @@ with gr.Blocks(css="style.css") as demo:
         rounded=(False, True, True, False),
         full_width=False,
     )
+    result = gr.Text(label="Result")
 
     gr.Examples(
         examples=examples,
         inputs=[pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile],
-        outputs="text",
+        outputs=result,
         fn=generate,
         cache_examples=True,
     )
@@ -237,7 +242,7 @@ with gr.Blocks(css="style.css") as demo:
     btn.click(
         fn=generate,
         inputs=[pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile],
-        outputs="text"
+        outputs=result,
     )
 
 demo.launch(show_error=True)
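
In context: the commit makes generate return its average-per-step figure as a string, adds a result = gr.Text(...) component, points both gr.Examples and btn.click at it, and reorders examples to match the inputs list. Below is a minimal, self-contained sketch of that wiring pattern; the stub generate, its fixed step count, and the component labels/defaults are placeholders for illustration, not the Space's real benchmarking code.

import time

import gradio as gr


def generate(pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile):
    # Stand-in for loading and running the selected diffusion pipeline.
    num_inference_steps = 4  # hypothetical fixed step count for the sketch
    start_time = time.perf_counter()
    time.sleep(0.01 * batch_size)  # placeholder for the actual pipeline call
    end_time = time.perf_counter()
    # Returning (not just printing) the string is what feeds the Result box.
    return f"Avg per step: {(end_time - start_time) / num_inference_steps} seconds."


with gr.Blocks() as demo:
    do_torch_compile = gr.Checkbox(label="Enable torch.compile()?")
    use_channels_last = gr.Checkbox(label="Use channels_last?")
    pipeline_to_benchmark = gr.Dropdown(["SD T2I", "Würstchen (T2I)"], label="Pipeline to benchmark")
    batch_size = gr.Slider(1, 8, value=4, step=1, label="Batch size")
    btn = gr.Button("Benchmark!")
    result = gr.Text(label="Result")

    # Example rows map positionally onto `inputs`, which is why the commit
    # reorders `examples` to [pipeline, batch_size, channels_last, compile].
    gr.Examples(
        examples=[["SD T2I", 4, True, True], ["Würstchen (T2I)", 4, False, True]],
        inputs=[pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile],
        outputs=result,
        fn=generate,
        cache_examples=False,  # the Space uses True to pre-compute example outputs at startup
    )
    btn.click(
        fn=generate,
        inputs=[pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile],
        outputs=result,
    )

demo.launch()

The earlier outputs="text" did not point at any component placed in the layout, so the timing never appeared in the UI; wiring a concrete gr.Text component and returning the string from generate is what makes the result visible.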
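
The return value added in the second hunk is computed from start_time/end_time captured around the pipeline call, which sits outside this diff. For completeness, a hedged sketch of how such a per-step average is commonly measured on CUDA; the helper name, synchronization calls, and keyword arguments below are assumptions, not code from app.py:

import time

import torch


def time_pipeline(pipe, num_inference_steps, **call_kwargs):
    # Hypothetical helper: flush pending GPU work before and after the call so
    # perf_counter captures execution time rather than just kernel launches.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    pipe(num_inference_steps=num_inference_steps, **call_kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end_time = time.perf_counter()
    print(f"For {num_inference_steps} steps", end_time - start_time)
    print("Avg per step", (end_time - start_time) / num_inference_steps)
    return f"Avg per step: {(end_time - start_time) / num_inference_steps} seconds."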