hysts commited on
Commit
99875b7
1 Parent(s): 8c83fc4
Files changed (1) hide show
  1. inference.py +3 -0
inference.py CHANGED
@@ -9,6 +9,7 @@ import gradio as gr
9
  import imageio
10
  import PIL.Image
11
  import torch
 
12
  from einops import rearrange
13
  from huggingface_hub import ModelCard
14
 
@@ -65,6 +66,8 @@ class InferencePipeline:
65
  torch_dtype=torch.float16,
66
  use_auth_token=self.hf_token)
67
  pipe = pipe.to(self.device)
 
 
68
  self.pipe = pipe
69
  self.model_id = model_id # type: ignore
70
 
 
9
  import imageio
10
  import PIL.Image
11
  import torch
12
+ from diffusers.utils.import_utils import is_xformers_available
13
  from einops import rearrange
14
  from huggingface_hub import ModelCard
15
 
 
66
  torch_dtype=torch.float16,
67
  use_auth_token=self.hf_token)
68
  pipe = pipe.to(self.device)
69
+ if is_xformers_available():
70
+ pipe.unet.enable_xformers_memory_efficient_attention()
71
  self.pipe = pipe
72
  self.model_id = model_id # type: ignore
73