sayakpaul HF staff commited on
Commit
32e7fe7
1 Parent(s): 8342365

fix: device placements.

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -82,13 +82,11 @@ def load_pipeline(
82
  if "ControlNet" in pipeline_to_benchmark:
83
  controlnet_ckpt = pipeline_details[2]
84
  controlnet = ControlNetModel.from_pretrained(
85
- controlnet_ckpt, torch_dtype=torch.float16
86
  ).to(device)
87
  elif "Adapters" in pipeline_to_benchmark:
88
  adapter_clpt = pipeline_details[2]
89
- adapter = T2IAdapter.from_pretrained(
90
- adapter_clpt, torch_dtype=torch.float16
91
- ).to(device)
92
 
93
  # Load pipeline.
94
  if (
@@ -98,9 +96,13 @@ def load_pipeline(
98
  pipeline = pipeline_cls.from_pretrained(pipeline_ckpt, torch_dtype=dtype)
99
 
100
  elif "ControlNet" in pipeline_to_benchmark:
101
- pipeline = pipeline_cls.from_pretrained(pipeline_ckpt, controlnet=controlnet)
 
 
102
  elif "Adapters" in pipeline_to_benchmark:
103
- pipeline = pipeline_cls.from_pretrained(pipeline_ckpt, adapter=adapter)
 
 
104
 
105
  pipeline.to(device)
106
 
 
82
  if "ControlNet" in pipeline_to_benchmark:
83
  controlnet_ckpt = pipeline_details[2]
84
  controlnet = ControlNetModel.from_pretrained(
85
+ controlnet_ckpt, torch_dtype=dtype
86
  ).to(device)
87
  elif "Adapters" in pipeline_to_benchmark:
88
  adapter_clpt = pipeline_details[2]
89
+ adapter = T2IAdapter.from_pretrained(adapter_clpt, torch_dtype=dtype).to(device)
 
 
90
 
91
  # Load pipeline.
92
  if (
 
96
  pipeline = pipeline_cls.from_pretrained(pipeline_ckpt, torch_dtype=dtype)
97
 
98
  elif "ControlNet" in pipeline_to_benchmark:
99
+ pipeline = pipeline_cls.from_pretrained(
100
+ pipeline_ckpt, controlnet=controlnet, torch_dtype=dtype
101
+ )
102
  elif "Adapters" in pipeline_to_benchmark:
103
+ pipeline = pipeline_cls.from_pretrained(
104
+ pipeline_ckpt, adapter=adapter, torch_dtype=dtype
105
+ )
106
 
107
  pipeline.to(device)
108