File size: 4,117 Bytes
faf0564
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d52f0e
 
 
faf0564
 
 
 
 
 
 
 
 
 
 
bbf54e6
faf0564
 
 
 
 
f4b07c0
 
 
 
faf0564
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import torch
from PIL import Image
from diffusers.utils import numpy_to_pil
from diffusers import (
    T2IAdapter,
    StableDiffusionXLAdapterPipeline,
    AutoencoderKL,
    EulerAncestralDiscreteScheduler
)
from controlnet_aux import PidiNetDetector

# Lazily-initialized StableDiffusionXLAdapterPipeline; populated by load_pipe()
# on the first generate() call so the app starts without loading model weights.
pipe = None

def load_pipe():
    """Initialise the global SDXL + T2I-Adapter pipeline on first use.

    Subsequent calls are no-ops: the constructed pipeline is cached in the
    module-level ``pipe`` variable. Loads fp16 weights, an fp16-safe VAE,
    and moves everything to CUDA with xformers attention enabled.
    """
    global pipe
    if pipe is not None:
        return

    base_model = "stabilityai/stable-diffusion-xl-base-1.0"

    # Sketch-conditioning adapter for SDXL.
    sketch_adapter = T2IAdapter.from_pretrained(
        "Adapter/t2iadapter",
        subfolder="sketch_sdxl_1.0",
        torch_dtype=torch.float16,
        adapter_type="full_adapter_xl",
    )
    # Scheduler shipped with the base model, plus a VAE patched for fp16.
    scheduler = EulerAncestralDiscreteScheduler.from_pretrained(base_model, subfolder="scheduler")
    fp16_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

    pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
        base_model,
        adapter=sketch_adapter,
        vae=fp16_vae,
        scheduler=scheduler,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    pipe = pipeline.to("cuda")
    pipe.enable_xformers_memory_efficient_attention()

def preprocess_image(uploaded_file):
    """Convert an uploaded sketch/photo into a 1024px grayscale edge map.

    Args:
        uploaded_file: Filepath from the Gradio ``gr.File`` widget, or None.

    Returns:
        Tuple ``(image, error_message)`` — a PIL "L"-mode image and "" on
        success, or ``(None, <message>)`` when no file was uploaded.
    """
    if uploaded_file is None:
        return None, "Please upload an image."

    # Cache the detector on the function itself: the original re-instantiated
    # (and potentially re-downloaded) PidiNetDetector on every generation.
    detector = getattr(preprocess_image, "_detector", None)
    if detector is None:
        detector = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
        preprocess_image._detector = detector

    img_upload = Image.open(uploaded_file)
    img_preprocessed = detector(
        img_upload,
        detect_resolution=1024,
        image_resolution=1024,
        apply_filter=True).convert("L")
    return img_preprocessed, ""

def generate(prompt, uploaded_file, prompt_addition, negative_prompt, num_images, num_steps, guidance_scale, adapter_conditioning_scale, adapter_conditioning_factor):
    """Run the SDXL adapter pipeline on an uploaded sketch.

    Args:
        prompt: Main text prompt.
        uploaded_file: Filepath of the uploaded conditioning image (or None).
        prompt_addition: Optional suffix appended to the prompt.
        negative_prompt: Negative prompt text.
        num_images: Number of images to generate (slider value).
        num_steps: Diffusion steps (slider value).
        guidance_scale: Classifier-free guidance scale.
        adapter_conditioning_scale: Adapter strength, 0-100 slider (scaled to 0-1).
        adapter_conditioning_factor: Adapter factor, 0-100 slider (scaled to 0-1).

    Returns:
        A list of PIL images on success, or an error-message string when no
        image was uploaded.
    """
    # Validate the upload BEFORE touching the pipeline so a missing image
    # fails fast instead of first triggering a multi-gigabyte model load.
    img_preprocessed, error_message = preprocess_image(uploaded_file)
    if error_message:
        return error_message

    load_pipe()  # Ensure the model is loaded (no-op after the first call)

    # Gradio sliders can deliver floats; the pipeline expects integers here.
    num_images = int(num_images)
    num_steps = int(num_steps)

    params = {
        "image": img_preprocessed,
        "num_inference_steps": num_steps,
        "prompt": f"{prompt},{prompt_addition}" if prompt_addition.strip() else prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        # Sliders run 0-100 for usability; the pipeline expects 0.0-1.0.
        "adapter_conditioning_scale": adapter_conditioning_scale / 100,
        "adapter_conditioning_factor": adapter_conditioning_factor / 100,
        "num_images_per_prompt": num_images,
    }
    generated_images = pipe(**params).images
    return generated_images  # PIL images, rendered directly by the Gallery

# Gradio UI: left column collects the prompt/sliders/upload, right column
# shows the output gallery. NOTE: the order of the `inputs=` list below must
# match generate()'s positional parameters exactly.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="a robot elephant", placeholder="Enter a description for the image you want to generate")
            prompt_addition = gr.Textbox(label="Prompt addition", value="in outer space, in the style of picasso, highly detailed")
            negative_prompt = gr.Textbox(label="Negative prompt", value="disfigured, wrong number of digits, cropped, low quality")
            num_images = gr.Slider(minimum=1, maximum=10, value=1, label="Number of images to generate")
            num_steps = gr.Slider(minimum=1, maximum=100, value=20, label="Number of steps")
            guidance_scale = gr.Slider(minimum=6, maximum=10, value=7, label="Guidance scale")
            # 0-100 sliders; generate() divides both by 100 before calling the pipeline.
            adapter_conditioning_scale = gr.Slider(minimum=0, maximum=100, value=90, label="Adapter conditioning scale")
            adapter_conditioning_factor = gr.Slider(minimum=0, maximum=100, value=90, label="Adapter conditioning factor")
            # type='filepath' means generate() receives a path string, not a PIL image.
            uploaded_file = gr.File(label="Upload image", type='filepath') 

        with gr.Column():
            output_gallery = gr.Gallery(label="Generated images")
            generate_button = gr.Button("Generate")
            generate_button.click(
                generate,
                inputs=[prompt, uploaded_file, prompt_addition, negative_prompt, num_images, num_steps, guidance_scale, adapter_conditioning_scale, adapter_conditioning_factor],
                outputs=[output_gallery]
            )

demo.launch()