multimodalart's picture
Update app.py
21df05c verified
raw
history blame contribute delete
No virus
4.23 kB
import gradio as gr
import spaces
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler, AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# Shared VAE for every pipeline in this app: the "sdxl-vae-fp16-fix"
# checkpoint, loaded in half precision.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

### SDXL Turbo ###
# Loaded first because its text encoders and tokenizers are handed to the
# Lightning and Hyper-SD pipelines built afterwards.
pipe_turbo = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.float16,
    variant="fp16",
    vae=vae,
)
# DiffusionPipeline.to() moves the modules in place (and returns the same
# pipeline object), so no re-assignment is needed.
pipe_turbo.to("cuda")
### Distilled SDXL UNets (Lightning / Hyper-SD) ###
# Both models reuse the SDXL base architecture; only the UNet weights differ.
base = "stabilityai/stable-diffusion-xl-base-1.0"


def _load_distilled_unet(repo_id, filename):
    """Build an fp16 UNet with the SDXL-base architecture and load distilled weights.

    Args:
        repo_id: Hugging Face Hub repository holding the checkpoint.
        filename: ``.safetensors`` file inside ``repo_id`` with the UNet state dict.

    Returns:
        A ``UNet2DConditionModel`` in ``torch.float16`` (still on CPU).
    """
    # Passing a repo id directly to ``from_config`` is deprecated in diffusers;
    # load the config explicitly, then instantiate the model from it.
    unet_config = UNet2DConditionModel.load_config(base, subfolder="unet")
    unet = UNet2DConditionModel.from_config(unet_config).to(torch.float16)
    unet.load_state_dict(load_file(hf_hub_download(repo_id, filename)))
    return unet


def _build_sdxl_pipeline(unet):
    """Assemble an fp16 SDXL pipeline from ``base`` around a custom ``unet``.

    The shared ``vae`` and the text encoders / tokenizers already loaded by
    ``pipe_turbo`` are passed in so they are not loaded a second time.
    """
    return StableDiffusionXLPipeline.from_pretrained(
        base,
        unet=unet,
        vae=vae,
        text_encoder=pipe_turbo.text_encoder,
        text_encoder_2=pipe_turbo.text_encoder_2,
        tokenizer=pipe_turbo.tokenizer,
        tokenizer_2=pipe_turbo.tokenizer_2,
        torch_dtype=torch.float16,
        variant="fp16",
    )


### SDXL Lightning ###
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_1step_unet_x0.safetensors"
pipe_lightning = _build_sdxl_pipeline(_load_distilled_unet(repo, ckpt))
# The 1-step checkpoint is an "x0" variant, so the scheduler is told the model
# predicts the sample itself (prediction_type="sample"), with trailing spacing.
pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(
    pipe_lightning.scheduler.config,
    timestep_spacing="trailing",
    prediction_type="sample",
)
pipe_lightning.to("cuda")

### Hyper SDXL ###
repo_name = "ByteDance/Hyper-SD"
ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"
pipe_hyper = _build_sdxl_pipeline(_load_distilled_unet(repo_name, ckpt_name))
# NOTE(review): presumably LCMScheduler is used so an explicit `timesteps`
# list can be passed at inference time — confirm against the Hyper-SD card.
pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
pipe_hyper.to("cuda")
@spaces.GPU
def run_comparison(prompt, progress=gr.Progress(track_tqdm=True)):
    """Render ``prompt`` with each one-step model, streaming results to the UI.

    Yields a 3-tuple ``(turbo, lightning, hyper)`` after each model finishes;
    models that have not run yet are ``None`` in the tuple. Every model runs a
    single inference step with ``guidance_scale=0``; Hyper-SDXL is additionally
    given an explicit start timestep of 800.

    ``progress`` is not read directly; its ``track_tqdm=True`` default lets
    Gradio mirror the pipelines' tqdm bars.
    """
    # (pipeline, extra call kwargs), in the same order as the output slots.
    schedule = (
        (pipe_turbo, {}),
        (pipe_lightning, {}),
        (pipe_hyper, {"timesteps": [800]}),
    )
    finished = [None] * len(schedule)
    for slot, (pipeline, extra_kwargs) in enumerate(schedule):
        output = pipeline(
            prompt=prompt,
            num_inference_steps=1,
            guidance_scale=0,
            **extra_kwargs,
        )
        finished[slot] = output.images[0]
        yield tuple(finished)
# Sample prompts listed in the UI's examples section.
examples = [
    "A dignified beaver wearing glasses, a vest, and colorful neck tie.",
    "The spirit of a tamagotchi wandering in the city of Barcelona",
    "an ornate, high-backed mahogany chair with a red cushion",
    "a sketch of a camel next to a stream",
    "a delicate porcelain teacup sits on a saucer, its surface adorned with intricate blue patterns",
    # Fixed typo: "grafitti" -> "graffiti".
    "a baby swan graffiti",
    "A bald eagle made of chocolate powder, mango, and whipped cream",
]
with gr.Blocks() as demo:
    gr.Markdown("## One step SDXL comparison 🦶")
    gr.Markdown('Compare SDXL variants and distillations able to generate images in a single diffusion step')
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")

    # One column per model: the output image, then a link to its model card.
    # This order must match the (turbo, lightning, hyper) tuples that
    # run_comparison yields.
    image_outputs = []
    with gr.Row():
        for label, url in (
            ("SDXL Turbo", "https://huggingface.co/stabilityai/sdxl-turbo"),
            ("SDXL Lightning", "https://huggingface.co/ByteDance/SDXL-Lightning"),
            ("Hyper SDXL", "https://huggingface.co/ByteDance/Hyper-SD"),
        ):
            with gr.Column():
                image_outputs.append(gr.Image(label=label))
                gr.Markdown(f"## [{label}]({url})")
    image_turbo, image_lightning, image_hyper = image_outputs

    # Pressing Enter in the textbox and clicking "Run" both start a comparison.
    gr.on(
        triggers=[prompt.submit, run.click],
        fn=run_comparison,
        inputs=prompt,
        outputs=image_outputs,
    )
    # Clicking an example runs it immediately; results are not pre-cached.
    gr.Examples(
        examples=examples,
        fn=run_comparison,
        inputs=prompt,
        outputs=image_outputs,
        cache_examples=False,
        run_on_click=True,
    )

demo.launch()