File size: 2,555 Bytes
1887c49
a3e9651
 
9b8f82d
bf459df
c3a473c
a3e9651
 
 
 
 
 
1fe5ec5
ff4f40d
 
 
 
 
 
 
1fe5ec5
a3e9651
 
c3a473c
1fe5ec5
a3e9651
 
 
 
 
 
 
 
 
c3a473c
1fe5ec5
a3e9651
 
 
 
 
 
 
 
 
c3a473c
1fe5ec5
a3e9651
 
6631477
 
a3e9651
 
 
 
 
6aab45f
 
 
 
6631477
a3e9651
 
6631477
a3e9651
6631477
a3e9651
6631477
 
 
 
ad96459
6631477
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
from diffusers import DiffusionPipeline
import spaces
import torch
from concurrent.futures import ProcessPoolExecutor
from huggingface_hub import hf_hub_download

# Hugging Face Hub repository ids of the two FLUX.1 base checkpoints.
dev_model = "black-forest-labs/FLUX.1-dev"
schnell_model = "black-forest-labs/FLUX.1-schnell"

# "cuda" when torch reports an available CUDA device, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The dev pipeline is loaded first because the schnell pipeline below borrows
# its text encoders and tokenizers.
pipe_dev = DiffusionPipeline.from_pretrained(dev_model, torch_dtype=torch.bfloat16)

# Components handed over from pipe_dev so the schnell pipeline reuses the same
# objects instead of loading its own copies.
_shared_text_components = {
    "text_encoder": pipe_dev.text_encoder,
    "text_encoder_2": pipe_dev.text_encoder_2,
    "tokenizer": pipe_dev.tokenizer,
    "tokenizer_2": pipe_dev.tokenizer_2,
}
pipe_schnell = DiffusionPipeline.from_pretrained(
    schnell_model,
    torch_dtype=torch.bfloat16,
    **_shared_text_components,
)
@spaces.GPU
def run_dev_hyper(prompt):
    """Generate one image with FLUX.1-dev and the Hyper-SD 8-step LoRA.

    The LoRA checkpoint is fetched from the Hub, attached to the shared
    ``pipe_dev`` pipeline for the duration of the call, and always detached
    again before returning or raising.

    Args:
        prompt: Text prompt passed to the pipeline.

    Returns:
        The first image of the pipeline output.
    """
    print("dev_hyper")
    pipe_dev.to("cuda")
    repo_name = "ByteDance/Hyper-SD"
    ckpt_name = "Hyper-FLUX.1-dev-8steps-lora.safetensors"
    pipe_dev.load_lora_weights(hf_hub_download(repo_name, ckpt_name))
    try:
        image = pipe_dev(
            prompt,
            num_inference_steps=8,
            joint_attention_kwargs={"scale": 0.125},
        ).images[0]
    finally:
        # pipe_dev is also used by run_dev_turbo; unload even when inference
        # fails so this LoRA never stays attached for the next call.
        pipe_dev.unload_lora_weights()
    return image

@spaces.GPU
def run_dev_turbo(prompt):
    """Generate one image with FLUX.1-dev and the FLUX.1-Turbo-Alpha LoRA.

    The LoRA checkpoint is fetched from the Hub, attached to the shared
    ``pipe_dev`` pipeline for the duration of the call, and always detached
    again before returning or raising.

    Args:
        prompt: Text prompt passed to the pipeline.

    Returns:
        The first image of the pipeline output.
    """
    print("dev_turbo")
    pipe_dev.to("cuda")
    repo_name = "alimama-creative/FLUX.1-Turbo-Alpha"
    ckpt_name = "diffusion_pytorch_model.safetensors"
    pipe_dev.load_lora_weights(hf_hub_download(repo_name, ckpt_name))
    try:
        image = pipe_dev(prompt, num_inference_steps=8).images[0]
    finally:
        # pipe_dev is also used by run_dev_hyper; unload even when inference
        # fails so this LoRA never stays attached for the next call.
        pipe_dev.unload_lora_weights()
    return image

@spaces.GPU
def run_schnell(prompt):
    """Generate one image with the FLUX.1-schnell pipeline in 4 steps.

    Args:
        prompt: Text prompt passed to the pipeline.

    Returns:
        The first image of the pipeline output.
    """
    print("schnell")
    pipe_schnell.to("cuda")
    # The step count is set explicitly so the run matches the "(4 steps)"
    # label shown in the UI instead of relying on the pipeline's default.
    image = pipe_schnell(prompt, num_inference_steps=4).images[0]
    return image

def run_parallel_models(prompt):
    """Generate the three comparison images for a single prompt.

    The models run one after another. ``run_dev_hyper`` and ``run_dev_turbo``
    both load and unload LoRA weights on the same ``pipe_dev`` pipeline, so
    their executions must not overlap.

    Args:
        prompt: Text prompt forwarded unchanged to every model.

    Returns:
        A tuple ``(schnell_image, dev_hyper_image, dev_turbo_image)``. The
        order matches the ``[schnell, hyper, turbo]`` output components of
        the click handler that calls this function.
    """
    # Each helper returns the finished image itself (not a future), so the
    # values are returned as-is; there is nothing to call .result() on.
    image_dev_hyper = run_dev_hyper(prompt)
    image_dev_turbo = run_dev_turbo(prompt)
    image_schnell = run_schnell(prompt)
    return image_schnell, image_dev_hyper, image_dev_turbo


# NOTE(review): presumably flags this handler for the ZeroGPU runtime, since
# it calls the @spaces.GPU functions without being decorated itself — confirm.
run_parallel_models.zerogpu = True

# Page layout: a title, one row with the prompt box and the submit button,
# and one row with an image slot per model.
with gr.Blocks() as demo:
    gr.Markdown("# Fast Flux Comparison")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        submit = gr.Button()

    with gr.Row():
        schnell = gr.Image(label="FLUX Schnell (4 steps)")
        hyper = gr.Image(label="FLUX.1[dev] HyperFLUX (8 steps)")
        turbo = gr.Image(label="FLUX.1[dev]-Turbo-Alpha (8 steps)")

    # The handler's return values fill the output components positionally,
    # i.e. in schnell, hyper, turbo order.
    submit.click(fn=run_parallel_models, inputs=[prompt], outputs=[schnell, hyper, turbo])

demo.launch()