multimodalart committed on
Commit 1fe5ec5
1 Parent(s): ff4f40d

Update app.py

Files changed (1)
  1. app.py +5 -3
app.py CHANGED
@@ -9,7 +9,7 @@ schnell_model = "black-forest-labs/FLUX.1-schnell"
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-pipe_dev = DiffusionPipeline.from_pretrained(dev_model, torch_dtype=torch.bfloat16).to(device)
+pipe_dev = DiffusionPipeline.from_pretrained(dev_model, torch_dtype=torch.bfloat16)
 pipe_schnell = DiffusionPipeline.from_pretrained(
     schnell_model,
     text_encoder=pipe_dev.text_encoder,
@@ -17,10 +17,10 @@ pipe_schnell = DiffusionPipeline.from_pretrained(
     tokenizer=pipe_dev.tokenizer,
     tokenizer_2=pipe_dev.tokenizer_2,
     torch_dtype=torch.bfloat16
-).to(device)
-
+)
 @spaces.GPU
 def run_dev_hyper(prompt):
+    pipe_dev.to("cuda")
     repo_name = "ByteDance/Hyper-SD"
     ckpt_name = "Hyper-FLUX.1-dev-8steps-lora.safetensors"
     pipe_dev.load_lora_weights(hf_hub_download(repo_name, ckpt_name))
@@ -30,6 +30,7 @@ def run_dev_hyper(prompt):
 
 @spaces.GPU
 def run_dev_turbo(prompt):
+    pipe_dev.to("cuda")
     repo_name = "alimama-creative/FLUX.1-Turbo-Alpha"
     ckpt_name = "diffusion_pytorch_model.safetensors"
     pipe_dev.load_lora_weights(hf_hub_download(repo_name, ckpt_name))
@@ -39,6 +40,7 @@ def run_dev_turbo(prompt):
 
 @spaces.GPU
 def run_schnell(prompt):
+    pipe_schnell.to("cuda")
     image = pipe_schnell(prompt).images[0]
     return image
 
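
In short, the commit stops calling .to(device) on both pipelines at import time and instead moves each pipeline to CUDA inside the @spaces.GPU-decorated handlers, presumably so that CUDA is only touched while a GPU is actually attached to the Space. Below is a minimal sketch of how the top of app.py looks after this change, reconstructed from the visible hunks only; the imports, the value of dev_model, and everything outside the hunks are assumptions, not part of this diff.

import torch
import spaces
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download

dev_model = "black-forest-labs/FLUX.1-dev"  # assumed value; only schnell_model appears in the hunk header
schnell_model = "black-forest-labs/FLUX.1-schnell"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Both pipelines are now loaded without .to(device), so they stay on CPU at import time.
pipe_dev = DiffusionPipeline.from_pretrained(dev_model, torch_dtype=torch.bfloat16)
pipe_schnell = DiffusionPipeline.from_pretrained(
    schnell_model,
    text_encoder=pipe_dev.text_encoder,
    tokenizer=pipe_dev.tokenizer,
    tokenizer_2=pipe_dev.tokenizer_2,
    torch_dtype=torch.bfloat16
)

@spaces.GPU
def run_dev_hyper(prompt):
    pipe_dev.to("cuda")  # moved to the GPU only inside the GPU-decorated call
    repo_name = "ByteDance/Hyper-SD"
    ckpt_name = "Hyper-FLUX.1-dev-8steps-lora.safetensors"
    pipe_dev.load_lora_weights(hf_hub_download(repo_name, ckpt_name))
    ...  # rest of the function is outside the visible hunks and not shown in this diff

# run_dev_turbo follows the same pattern, loading the alimama-creative/FLUX.1-Turbo-Alpha LoRA.

@spaces.GPU
def run_schnell(prompt):
    pipe_schnell.to("cuda")
    image = pipe_schnell(prompt).images[0]
    return image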