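"""Gradio Space comparing low-step FLUX variants.

Generates the same prompt with FLUX.1 [schnell] (4 steps) and with
FLUX.1 [dev] accelerated by two different 8-step LoRAs, so the outputs
can be compared side by side.
"""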
import gradio as gr
from diffusers import DiffusionPipeline
import spaces
import torch
from concurrent.futures import ProcessPoolExecutor
from huggingface_hub import hf_hub_download

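# Base checkpoints: FLUX.1 [dev] (guidance-distilled) and FLUX.1 [schnell] (timestep-distilled).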
dev_model = "black-forest-labs/FLUX.1-dev"
schnell_model = "black-forest-labs/FLUX.1-schnell"

device = "cuda" if torch.cuda.is_available() else "cpu"

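# ByteDance's 8-step Hyper-SD LoRA for FLUX.1 [dev].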
repo_name = "ByteDance/Hyper-SD"
ckpt_name = "Hyper-FLUX.1-dev-8steps-lora.safetensors"
hyper_lora = hf_hub_download(repo_name, ckpt_name)

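# alimama-creative's 8-step Turbo-Alpha LoRA for FLUX.1 [dev].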
repo_name = "alimama-creative/FLUX.1-Turbo-Alpha"
ckpt_name = "diffusion_pytorch_model.safetensors"
turbo_lora = hf_hub_download(repo_name, ckpt_name)

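# Load both pipelines once at startup; the schnell pipeline reuses dev's text
# encoders and tokenizers so each is held in memory only once.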
pipe_dev = DiffusionPipeline.from_pretrained(dev_model, torch_dtype=torch.bfloat16)
pipe_schnell = DiffusionPipeline.from_pretrained(
    schnell_model,
    text_encoder=pipe_dev.text_encoder,
    text_encoder_2=pipe_dev.text_encoder_2,
    tokenizer=pipe_dev.tokenizer,
    tokenizer_2=pipe_dev.tokenizer_2,
    torch_dtype=torch.bfloat16
)

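# FLUX.1 [dev] with the Hyper-SD LoRA at 8 steps. The 0.125 LoRA scale follows
# the value suggested on the Hyper-SD model card for its 8-step FLUX LoRA.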
@spaces.GPU
def run_dev_hyper(prompt):
    print("dev_hyper")
    pipe_dev.to("cuda")
    print(hyper_lora)
    pipe_dev.load_lora_weights(hyper_lora)
    print("Loaded hyper lora!")
    image = pipe_dev(prompt, num_inference_steps=8, joint_attention_kwargs={"scale": 0.125}).images[0]
    print("Ran!")
    pipe_dev.unload_lora_weights()
    return image

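# FLUX.1 [dev] with the Turbo-Alpha LoRA at 8 steps, applied at its default scale.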
@spaces.GPU
def run_dev_turbo(prompt):
    print("dev_turbo")
    pipe_dev.to("cuda")
    print(turbo_lora)
    pipe_dev.load_lora_weights(turbo_lora)
    print("Loaded turbo lora!")
    image = pipe_dev(prompt, num_inference_steps=8).images[0]
    print("Ran!")
    pipe_dev.unload_lora_weights()
    return image

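# FLUX.1 [schnell] at its native 4 steps; it is already timestep-distilled, so no LoRA is needed.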
@spaces.GPU
def run_schnell(prompt):
    print("schnell")
    pipe_schnell.to("cuda")
    print("schnell on gpu")
    image = pipe_schnell(prompt, num_inference_steps=4).images[0]
    print("Ran!")
    return image

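# Run the three generations in parallel processes; results are awaited in a fixed
# order and streamed to the matching output slot (gr.update() leaves the other
# slots untouched until their image arrives).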
def run_parallel_models(prompt):
    with ProcessPoolExecutor(max_workers=3) as executor:
        future_dev_hyper = executor.submit(run_dev_hyper, prompt)
        future_dev_turbo = executor.submit(run_dev_turbo, prompt)
        future_schnell = executor.submit(run_schnell, prompt)
        
        # Outputs are ordered [schnell, hyper, turbo], so each result must be
        # yielded into its matching slot.
        res_dev_hyper = future_dev_hyper.result()
        yield gr.update(), res_dev_hyper, gr.update()
        res_dev_turbo = future_dev_turbo.result()
        yield gr.update(), gr.update(), res_dev_turbo
        res_schnell = future_schnell.result()
        yield res_schnell, gr.update(), gr.update()

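# Attribute hint for the ZeroGPU runtime; presumably tells `spaces` to treat this
# generator as a GPU task without wrapping it in @spaces.GPU itself (the per-model
# workers above carry the decorator instead).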
run_parallel_models.zerogpu = True

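# UI: a prompt box with a Run button, one image output per model, cached examples,
# and a shared handler for both button click and textbox submit.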
with gr.Blocks() as demo:
    gr.Markdown("# Low Step Flux Comparison")
    gr.Markdown("Compare the quality (not the speed) of FLUX Schnell (4 steps), FLUX.1[dev] HyperFLUX (8 steps), FLUX.1[dev]-Turbo-Alpha (8 steps). It runs a bit slow as it's inferencing the three models.")
    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(label="Prompt")
        with gr.Column(scale=1, min_width=120):
            submit = gr.Button("Run")
    with gr.Row():
        schnell = gr.Image(label="FLUX Schnell (4 steps)")
        hyper = gr.Image(label="FLUX.1[dev] HyperFLUX (8 steps)")
        turbo = gr.Image(label="FLUX.1[dev]-Turbo-Alpha (8 steps)")
    
    gr.Examples(
        examples=[
            ["the spirit of a Tamagotchi wandering in the city of Vienna"],
            ["a photo of a lavender cat"],
            ["a tiny astronaut hatching from an egg on the moon"],
            ["a delicious ceviche cheesecake slice"],
            ["an insect robot preparing a delicious meal"],
            ["a Charmander fine dining with a view to la Sagrada Família"]],
        fn=run_parallel_models,
        inputs=[prompt],
        outputs=[schnell, hyper, turbo],
        cache_examples="lazy"
    )
    
    gr.on(
        triggers=[submit.click, prompt.submit],
        fn=run_parallel_models,
        inputs=[prompt],
        outputs=[schnell, hyper, turbo]
    )
demo.launch()