File size: 2,917 Bytes
1887c49
a3e9651
 
9b8f82d
bf459df
c3a473c
a3e9651
 
 
 
 
 
d1118c6
 
 
 
 
 
 
 
1fe5ec5
ff4f40d
 
 
 
 
 
 
1fe5ec5
d1118c6
a3e9651
 
c3a473c
1fe5ec5
90792d2
d1118c6
90792d2
a3e9651
90792d2
a3e9651
 
 
 
 
c3a473c
1fe5ec5
90792d2
d1118c6
90792d2
a3e9651
90792d2
a3e9651
 
 
 
 
c3a473c
1fe5ec5
90792d2
91dbe72
90792d2
a3e9651
6631477
 
67df9db
641b3df
 
 
 
 
 
 
 
6aab45f
6631477
a3e9651
 
6631477
67df9db
6631477
a3e9651
6631477
 
 
 
ad96459
6631477
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download

# Hub model ids for the two FLUX variants compared by this demo.
dev_model = "black-forest-labs/FLUX.1-dev"
schnell_model = "black-forest-labs/FLUX.1-schnell"

# NOTE(review): `device` is computed but never used below — the @spaces.GPU
# functions call .to("cuda") explicitly instead. Confirm whether it can go.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fetch the Hyper-SD 8-step LoRA checkpoint for FLUX.1-dev.
repo_name = "ByteDance/Hyper-SD"
ckpt_name = "Hyper-FLUX.1-dev-8steps-lora.safetensors"
hyper_lora = hf_hub_download(repo_name, ckpt_name)

# Fetch the FLUX.1-Turbo-Alpha LoRA (repo_name/ckpt_name are reused here).
repo_name = "alimama-creative/FLUX.1-Turbo-Alpha"
ckpt_name = "diffusion_pytorch_model.safetensors"
turbo_lora = hf_hub_download(repo_name, ckpt_name)

# Load both pipelines in bfloat16. The schnell pipeline is constructed with
# the dev pipeline's text encoders and tokenizers so those components are
# shared rather than loaded twice.
pipe_dev = DiffusionPipeline.from_pretrained(dev_model, torch_dtype=torch.bfloat16)
pipe_schnell = DiffusionPipeline.from_pretrained(
    schnell_model,
    text_encoder=pipe_dev.text_encoder,
    text_encoder_2=pipe_dev.text_encoder_2,
    tokenizer=pipe_dev.tokenizer,
    tokenizer_2=pipe_dev.tokenizer_2,
    torch_dtype=torch.bfloat16
)

@spaces.GPU
def run_dev_hyper(prompt):
    """Generate one image with FLUX.1-dev plus the Hyper-SD 8-step LoRA.

    Args:
        prompt: Text prompt forwarded to the pipeline.

    Returns:
        The first image produced by the pipeline.
    """
    print("dev_hyper")
    pipe_dev.to("cuda")
    print(hyper_lora)
    pipe_dev.load_lora_weights(hyper_lora)
    print("Loaded hyper lora!")
    try:
        # scale=0.125 tempers the LoRA's influence on the attention layers.
        image = pipe_dev(prompt, num_inference_steps=8, joint_attention_kwargs={"scale": 0.125}).images[0]
        print("Ran!")
    finally:
        # Always detach the LoRA: pipe_dev is a shared module-level pipeline,
        # and a failed run would otherwise leave these weights loaded and
        # corrupt the next request (including run_dev_turbo's).
        pipe_dev.unload_lora_weights()
    return image

@spaces.GPU
def run_dev_turbo(prompt):
    """Generate one image with FLUX.1-dev plus the FLUX.1-Turbo-Alpha LoRA.

    Args:
        prompt: Text prompt forwarded to the pipeline.

    Returns:
        The first image produced by the pipeline.
    """
    print("dev_turbo")
    pipe_dev.to("cuda")
    print(turbo_lora)
    pipe_dev.load_lora_weights(turbo_lora)
    print("Loaded turbo lora!")
    try:
        image = pipe_dev(prompt, num_inference_steps=8).images[0]
        print("Ran!")
    finally:
        # Always detach the LoRA: pipe_dev is a shared module-level pipeline,
        # and a failed run would otherwise leave these weights loaded and
        # corrupt the next request (including run_dev_hyper's).
        pipe_dev.unload_lora_weights()
    return image

@spaces.GPU
def run_schnell(prompt):
    """Generate one image with FLUX.1-schnell using 4 inference steps.

    Args:
        prompt: Text prompt forwarded to the pipeline.

    Returns:
        The first image produced by the pipeline.
    """
    print("schnell")
    pipe_schnell.to("cuda")
    print("schnell on gpu")
    result = pipe_schnell(prompt, num_inference_steps=4)
    image = result.images[0]
    print("Ran!")
    return image

def run_parallel_models(prompt):
    """Run all three variants concurrently and return their images.

    Uses a ThreadPoolExecutor instead of a ProcessPoolExecutor: the
    pipelines are module-level globals living in this process, and the
    worker functions are @spaces.GPU-wrapped closures over them. Forked
    child processes cannot re-initialize CUDA ("Cannot re-initialize CUDA
    in forked subprocess"), and spawned children would have to re-import
    this module (reloading every model) and cannot pickle the wrapped
    callables. Threads share the loaded pipelines directly, and the GIL is
    released while the GPU work runs.

    Args:
        prompt: Text prompt sent to all three variants.

    Returns:
        (schnell_image, hyper_image, turbo_image) — matching the UI's
        outputs=[schnell, hyper, turbo] wiring. The original returned
        (hyper, turbo, schnell), which displayed each image under the
        wrong label.
    """
    print(prompt)
    with ThreadPoolExecutor(max_workers=3) as executor:
        future_dev_hyper = executor.submit(run_dev_hyper, prompt)
        future_dev_turbo = executor.submit(run_dev_turbo, prompt)
        future_schnell = executor.submit(run_schnell, prompt)

        res_dev_hyper = future_dev_hyper.result()
        res_dev_turbo = future_dev_turbo.result()
        res_schnell = future_schnell.result()
    # Order must line up with outputs=[schnell, hyper, turbo] in the UI.
    return res_schnell, res_dev_hyper, res_dev_turbo

run_parallel_models.zerogpu = True

# --- Gradio UI: one prompt in, three side-by-side generations out. ---------
with gr.Blocks() as demo:
    gr.Markdown("# Low Step Flux Comparison")
    with gr.Row():
        # A single prompt drives all three model variants.
        prompt = gr.Textbox(label="Prompt")
        submit = gr.Button()
    with gr.Row():
        # One output slot per variant.
        schnell = gr.Image(label="FLUX Schnell (4 steps)")
        hyper = gr.Image(label="FLUX.1[dev] HyperFLUX (8 steps)")
        turbo = gr.Image(label="FLUX.1[dev]-Turbo-Alpha (8 steps)")

    # NOTE(review): outputs are filled positionally from the fn's return
    # tuple — verify run_parallel_models returns its images in
    # (schnell, hyper, turbo) order to match this list.
    submit.click(
        fn=run_parallel_models,
        inputs=[prompt],
        outputs=[schnell, hyper, turbo]
    )
demo.launch()