multimodalart HF staff commited on
Commit
a3e9651
1 Parent(s): ad96459

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -2
app.py CHANGED
@@ -1,12 +1,53 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  def run_parallel_models(prompt):
 
 
 
 
 
 
4
  return gr.update(), gr.update(), gr.update()
5
 
 
 
6
  with gr.Blocks() as demo:
7
- gr.Markdown("#Fast Flux Comparison")
8
  with gr.Row():
9
- prompt = gr.Textbox()
10
  submit = gr.Button()
11
  with gr.Row():
12
  schnell = gr.Image(label="FLUX Schnell (4 steps)")
 
1
  import gradio as gr
2
+ from diffusers import DiffusionPipeline
3
+ import spaces
4
+
5
+ dev_model = "black-forest-labs/FLUX.1-dev"
6
+ schnell_model = "black-forest-labs/FLUX.1-schnell"
7
+
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ pipe_dev = DiffusionPipeline.from_pretrained(dev_model, torch_dtype=torch.bfloat16).to(device)
11
+ pipe_schnell = DiffusionPipeline.from_pretrained(schnell_model, torch_dtype=torch.bfloat16).to(device)
12
+
13
+ @spaces.GPU
14
+ def run_dev_hyper(prompt):
15
+ repo_name = "ByteDance/Hyper-SD"
16
+ ckpt_name = "Hyper-FLUX.1-dev-8steps-lora.safetensors"
17
+ pipe_dev.load_lora_weights(hf_hub_download(repo_name, ckpt_name))
18
+ image = pipe_dev(prompt, num_inference_steps=8, joint_attention_kwargs={"scale": 0.125}).images[0]
19
+ pipe_dev.unload_lora_weights()
20
+ return image
21
+
22
+ @spaces.GPU
23
+ def run_dev_turbo(prompt):
24
+ repo_name = "alimama-creative/FLUX.1-Turbo-Alpha"
25
+ ckpt_name = "diffusion_pytorch_model.safetensors"
26
+ pipe_dev.load_lora_weights(hf_hub_download(repo_name, ckpt_name))
27
+ image = pipe_dev(prompt, num_inference_steps=8).images[0]
28
+ pipe_dev.unload_lora_weights()
29
+ return image
30
+
31
+ @spaces.GPU
32
+ def run_schnell(prompt):
33
+ image = pipe_schnell(prompt).images[0]
34
+ return image
35
 
36
  def run_parallel_models(prompt):
37
+
38
+ with ProcessPoolExecutor(3) as e:
39
+ image_dev_hyper = run_dev_hyper(prompt)
40
+ image_dev_turbo = run_dev_turbo(prompt)
41
+ image_schnell = run_schnell(prompt)
42
+
43
  return gr.update(), gr.update(), gr.update()
44
 
45
+ run_parallel_models.zerogpu = True
46
+
47
  with gr.Blocks() as demo:
48
+ gr.Markdown("# Fast Flux Comparison")
49
  with gr.Row():
50
+ prompt = gr.Textbox(label="Prompt")
51
  submit = gr.Button()
52
  with gr.Row():
53
  schnell = gr.Image(label="FLUX Schnell (4 steps)")