TDD / app.py
Ldhlwh's picture
Update app.py
61ce549 verified
raw
history blame contribute delete
No virus
7.73 kB
import spaces
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
from tdd_scheduler import TDDScheduler
from safetensors.torch import load_file
from PIL import Image
import transformers
# Invoke the transformers cache helper once at import time, before any model
# below is loaded.
# NOTE(review): presumably migrates an older on-disk cache layout — confirm it
# is still needed with the transformers version this Space pins.
transformers.utils.move_cache()
# When True, generated images are screened by the NSFW safety checker before
# they are returned to the UI.
SAFETY_CHECKER = True

# Name of the accelerator LoRA most recently activated by generate_image;
# None until the first request has been served.
loaded_acc = None

# Device that every model and input tensor in this app is moved to.
device = "cuda"

# Accelerator name (as offered in the UI dropdown) -> LoRA weight file name.
ACC_lora = dict(
    TDD="sdxl_tdd_wo_adv_lora.safetensors",
    TDD_adv="sdxl_tdd_lora_weights.safetensors",
)
def _load_tdd_pipeline(unet):
    """Build an fp16 SDXL pipeline around ``unet`` that is ready for TDD sampling.

    The pipeline components other than the UNet come from the SDXL 1.0 base
    repository. Every LoRA listed in ``ACC_lora`` is loaded from the
    ``RED-AIGC/TDD`` repository under an adapter name equal to its key, and
    the pipeline's scheduler is replaced by a ``TDDScheduler`` built from that
    pipeline's own scheduler config.

    Args:
        unet: The ``UNet2DConditionModel`` to use as the denoiser.

    Returns:
        The configured ``StableDiffusionXLPipeline``, moved to ``device``.
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        unet=unet,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to(device)
    for adapter_name, weight_name in ACC_lora.items():
        pipe.load_lora_weights(
            "RED-AIGC/TDD", weight_name=weight_name, adapter_name=adapter_name
        )
    # Derive the scheduler from this pipeline's own config rather than from a
    # sibling pipeline, so each pipeline is configured independently.
    pipe.scheduler = TDDScheduler.from_config(pipe.scheduler.config)
    return pipe


# Models are only loaded when CUDA is available; otherwise pipe_sdxl and
# pipe_sdxl_real stay undefined and generate_image cannot run.
if torch.cuda.is_available():
    # Denoiser of the stock SDXL 1.0 base model.
    base1 = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="unet",
        torch_dtype=torch.float16,
    ).to(device)
    # Denoiser of the RealVisXL checkpoint (the "Real" base-model choice).
    base2 = UNet2DConditionModel.from_pretrained(
        "frankjoshua/realvisxlV40_v40Bakedvae",
        subfolder="unet",
        torch_dtype=torch.float16,
    ).to(device)

    pipe_sdxl = _load_tdd_pipeline(base1)
    pipe_sdxl_real = _load_tdd_pipeline(base2)
def update_base_model(ckpt):
    """Load a fresh fp16 SDXL 1.0 base pipeline on the configured device.

    Args:
        ckpt: Requested base-model name. Currently unused: the SDXL 1.0 base
            repository is always the one loaded.

    Returns:
        A new ``StableDiffusionXLPipeline`` moved to ``device``, or ``None``
        when CUDA is not available.
    """
    if not torch.cuda.is_available():
        # Nothing can be loaded without a GPU; return None explicitly instead
        # of referencing a local that was never assigned.
        return None
    # Returned directly rather than bound to a local named like the
    # module-level pipeline, which it would otherwise shadow.
    return StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
    ).to(device)
if SAFETY_CHECKER:
    # Imported lazily so the checker dependencies are only required when
    # screening is switched on.
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    _checker_model = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    safety_checker = _checker_model.to(device)
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    def check_nsfw_images(
        images: list[Image.Image],
    ) -> tuple[list[Image.Image], list[bool]]:
        """Screen ``images`` with the safety checker.

        Args:
            images: Images to screen.

        Returns:
            A pair ``(images, has_nsfw_concepts)``: the input list unchanged,
            and whatever the safety checker returned for it.
        """
        # Preprocess the images into a tensor batch and move it to the device.
        clip_batch = feature_extractor(images, return_tensors="pt").to(device)
        pixel_values = clip_batch.pixel_values.to(device)
        # NOTE(review): the checker receives the image list wrapped in an
        # outer list — confirm the local safety_checker module expects that.
        nsfw_flags = safety_checker(images=[images], clip_input=pixel_values)
        return images, nsfw_flags
@spaces.GPU(enable_queue=True)
def generate_image(
    prompt,
    negative_prompt,
    ckpt,
    acc,
    num_inference_steps,
    guidance_scale,
    eta,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the selected base model and accelerator LoRA.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        ckpt: Base-model choice; ``"Real"`` selects ``pipe_sdxl_real``, any
            other value selects ``pipe_sdxl``.
        acc: Name of the LoRA adapter to activate (a key of ``ACC_lora``).
        num_inference_steps: Number of sampling steps.
        guidance_scale: Classifier-free guidance scale.
        eta: ``eta`` value forwarded to the pipeline call.
        seed: Seed for the random generator; converted to ``int`` first.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        The first generated image, or a blank 512x512 RGB image when the
        safety checker is enabled and flags any output.
    """
    global loaded_acc

    pipe = pipe_sdxl_real if ckpt == "Real" else pipe_sdxl

    # The two pipelines are built separately and each keeps its own adapter
    # state, so one shared "last adapter" flag cannot tell whether *this*
    # pipeline already has the requested adapter active. Select it on every
    # call instead of skipping when the flag happens to match.
    pipe.set_adapters([acc], adapter_weights=[1.0])
    if loaded_acc != acc:
        print(pipe.get_active_adapters())
        loaded_acc = acc

    # The seed comes from a numeric UI field and may arrive as a float, which
    # manual_seed does not accept.
    generator = torch.Generator(device=pipe.device).manual_seed(int(seed))

    results = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        eta=eta,
        generator=generator,
    )

    if SAFETY_CHECKER:
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            # Hand back a blank placeholder instead of the flagged image.
            return Image.new("RGB", (512, 512))
        return images[0]
    return results.images[0]
# Custom stylesheet passed to gr.Blocks: centers level-1 headings and caps the
# width of the Gradio container.
css = """
h1 {
text-align: center;
display:block;
}
.gradio-container {
max-width: 70.5rem !important;
}
"""
# Gradio UI: controls in the left column, generated image in the right column.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from flattened source — confirm it matches the intended page layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# ✨Target-Driven Distillation✨
Target-Driven Distillation (TDD) is a state-of-the-art consistency distillation model that largely accelerates the inference processes of diffusion models. Using its delicate strategies of *target timestep selection* and *decoupled guidance*, models distilled by TDD can generate highly detailed images with only a few steps.
[**Project Page**](https://redaigc.github.io/TDD/) **|** [**Paper**](https://arxiv.org/abs/2409.01347) **|** [**Code**](https://github.com/RedAIGC/Target-Driven-Distillation) **|** [**Model**](https://huggingface.co/RED-AIGC/TDD) **|** [🤗 **TDD-SDXL Demo**](https://huggingface.co/spaces/RED-AIGC/TDD) **|** [🤗 **TDD-SVD Demo**](https://huggingface.co/spaces/RED-AIGC/SVD-TDD)
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                with gr.Row():
                    prompt = gr.Textbox(label="Prompt")
                with gr.Row():
                    negative_prompt = gr.Textbox(label="Negative Prompt")
                with gr.Row():
                    steps = gr.Slider(
                        label="Sampling Steps",
                        minimum=4,
                        maximum=8,
                        step=1,
                        value=4,
                    )
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="CFG Scale",
                        minimum=1,
                        maximum=4,
                        step=0.1,
                        value=2.0,
                    )
                with gr.Row():
                    eta = gr.Slider(
                        label="eta",
                        minimum=0,
                        maximum=0.3,
                        step=0.01,
                        value=0.2,
                    )
                with gr.Row():
                    # precision=0 makes the field deliver an int, which is
                    # what the random generator's seed requires.
                    seed = gr.Number(label="Seed", value=-1, precision=0)
                with gr.Row():
                    ckpt = gr.Dropdown(
                        label="Base Model",
                        choices=["SDXL-1.0", "Real"],
                        value="SDXL-1.0",
                    )
                    acc = gr.Dropdown(
                        label="Accelerate Lora",
                        choices=["TDD", "TDD_adv"],
                        value="TDD_adv",
                    )
        with gr.Column(scale=1):
            with gr.Group():
                img = gr.Image(label="TDD Image", value="cat.png")
                submit_sdxl = gr.Button("Run on SDXL")

    # Each example row supplies values in the same order as `inputs`.
    gr.Examples(
        examples=[
            ["A photo of a cat made of water.", "", "SDXL-1.0", "TDD_adv", 4, 1.7, 0.2, 546237],
            ["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", "NSFW, low quality, low resolution, normal quality", "Real", "TDD_adv", 7, 2.0, 0.3, 68436],
            ["panda mad scientist mixing sparkling chemicals", "", "Real", "TDD_adv", 5, 1.7, 0.2, 3853],
            ["a dog in snow, 8k", "", "SDXL-1.0", "TDD", 6, 1.3, 0.2, 12138],
        ],
        inputs=[prompt, negative_prompt, ckpt, acc, steps, guidance_scale, eta, seed],
        outputs=[img],
        fn=generate_image,
        cache_examples="lazy",
    )

    # Regenerate when the base model changes, the prompt is submitted, or the
    # run button is clicked.
    gr.on(
        fn=generate_image,
        triggers=[ckpt.change, prompt.submit, submit_sdxl.click],
        inputs=[prompt, negative_prompt, ckpt, acc, steps, guidance_scale, eta, seed],
        outputs=[img],
    )

demo.queue(api_open=False).launch(show_api=False)