multimodalart's picture
Update app.py
435c235 verified
raw
history blame contribute delete
No virus
3.21 kB
import gradio as gr
from diffusers import DiffusionPipeline
import spaces
import torch
from concurrent.futures import ProcessPoolExecutor
from huggingface_hub import hf_hub_download
# Base FLUX checkpoints on the Hugging Face Hub.
dev_model = "black-forest-labs/FLUX.1-dev"
schnell_model = "black-forest-labs/FLUX.1-schnell"

# NOTE: not referenced anywhere else in this file -- the pipelines are placed
# on "cuda" explicitly.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Step-distillation LoRA weights that run_parallel_models loads into pipe_dev.
# repo_name / ckpt_name are rebound for each download.
repo_name, ckpt_name = "ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"
hyper_lora = hf_hub_download(repo_id=repo_name, filename=ckpt_name)

repo_name, ckpt_name = "alimama-creative/FLUX.1-Turbo-Alpha", "diffusion_pytorch_model.safetensors"
turbo_lora = hf_hub_download(repo_id=repo_name, filename=ckpt_name)

# FLUX.1-dev in bfloat16, moved to the GPU at import time.
pipe_dev = DiffusionPipeline.from_pretrained(dev_model, torch_dtype=torch.bfloat16)
pipe_dev = pipe_dev.to("cuda")

# FLUX.1-schnell reuses pipe_dev's tokenizers and text encoders (passed in as
# components). It is not moved to the GPU here; run_parallel_models does that.
pipe_schnell = DiffusionPipeline.from_pretrained(
    schnell_model,
    torch_dtype=torch.bfloat16,
    tokenizer=pipe_dev.tokenizer,
    tokenizer_2=pipe_dev.tokenizer_2,
    text_encoder=pipe_dev.text_encoder,
    text_encoder_2=pipe_dev.text_encoder_2,
)
@spaces.GPU(duration=75)
def run_parallel_models(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate one image per low-step FLUX variant, yielding each as it finishes.

    Runs, in order and on the same prompt:
      1. FLUX.1-dev + Hyper-SD LoRA (8 steps, LoRA scale 0.125),
      2. FLUX.1-dev + Turbo-Alpha LoRA (8 steps),
      3. FLUX.1-schnell (4 steps).

    Args:
        prompt: Text prompt passed unchanged to every pipeline.
        progress: Gradio progress tracker; with ``track_tqdm=True`` it follows
            the pipelines' tqdm progress bars.

    Yields:
        A 3-tuple in ``(hyper, turbo, schnell)`` slot order: the freshly
        generated image in its own slot and ``gr.update()`` (leave the
        component unchanged) in the other two.
    """
    # A previous call ends with pipe_dev on the CPU and pipe_schnell on the GPU,
    # and nothing else resets that. If the state carried over, put pipe_dev back
    # on the GPU before using it. pipe_schnell is moved off first because it
    # shares pipe_dev's text encoders; moving pipe_dev afterwards returns those
    # shared encoders to the GPU.
    if pipe_dev.transformer.device.type != "cuda":
        pipe_schnell.to("cpu")
        pipe_dev.to("cuda")

    # Each LoRA is unloaded in `finally`, before yielding, so it cannot stay
    # loaded into pipe_dev if inference raises or the generator is closed at a
    # `yield` (e.g. the event is cancelled).
    pipe_dev.load_lora_weights(hyper_lora)
    try:
        image = pipe_dev(
            prompt,
            num_inference_steps=8,
            joint_attention_kwargs={"scale": 0.125},
        ).images[0]
    finally:
        pipe_dev.unload_lora_weights()
    yield image, gr.update(), gr.update()

    pipe_dev.load_lora_weights(turbo_lora)
    try:
        image = pipe_dev(prompt, num_inference_steps=8).images[0]
    finally:
        pipe_dev.unload_lora_weights()
    yield gr.update(), image, gr.update()

    # Move pipe_dev off the GPU before moving pipe_schnell on (presumably so the
    # two FLUX transformers never occupy GPU memory at the same time).
    pipe_dev.to("cpu")
    pipe_schnell.to("cuda")
    image = pipe_schnell(prompt, num_inference_steps=4).images[0]
    yield gr.update(), gr.update(), image
#run_parallel_models.zerogpu = True
# Styling for the prompt row: the Run button's column stretches to the row
# height and the button fills that column.
css = '''
#gen_btn{height: 100%}
#gen_column{align-self: stretch}
'''

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Low Step Flux Comparison")
    gr.Markdown("Compare the quality (not the speed) of FLUX Schnell (4 steps), FLUX.1[dev] HyperFLUX (8 steps), FLUX.1[dev]-Turbo-Alpha (8 steps). It runs a bit slow as it's inferencing the three models.")
    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(label="Prompt")
        with gr.Column(scale=1, min_width=120, elem_id="gen_column"):
            submit = gr.Button("Run", elem_id="gen_btn")
    with gr.Row():
        hyper = gr.Image(label="FLUX.1[dev] HyperFLUX (8 steps)")
        turbo = gr.Image(label="FLUX.1[dev]-Turbo-Alpha (8 steps)")
        schnell = gr.Image(label="FLUX Schnell (4 steps)")

    # Must match the slot order of the tuples run_parallel_models yields
    # (hyper, turbo, schnell). The examples and the Run handler share this one
    # list so their output orders cannot drift apart.
    output_images = [hyper, turbo, schnell]

    # NOTE(review): run_parallel_models is a generator whose yields fill one
    # slot at a time (gr.update() elsewhere); confirm that lazy example caching
    # keeps all three images rather than only what the last yield carries.
    gr.Examples(
        examples=[
            ["the spirit of a Tamagotchi wandering in the city of Vienna"],
            ["a photo of a lavender cat"],
            ["a tiny astronaut hatching from an egg on the moon"],
            ["a delicious ceviche cheesecake slice"],
            ["an insect robot preparing a delicious meal"],
            ["a Charmander fine dining with a view to la Sagrada Família"],
        ],
        fn=run_parallel_models,
        inputs=[prompt],
        outputs=output_images,
        cache_examples="lazy",
    )
    # Run on either a button click or Enter in the prompt box.
    gr.on(
        triggers=[submit.click, prompt.submit],
        fn=run_parallel_models,
        inputs=[prompt],
        outputs=output_images,
    )

demo.launch()