import spaces
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
import gradio as gr
import torch
import uuid
import os

# Load the base & refiner pipelines
base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", 
    torch_dtype=torch.float16, 
    variant="fp16", 
    use_safetensors=True
)
base.to("cuda:0")

# Load the SSD-1B pipeline for the plain (unrefined) text-to-image path
pipe = StableDiffusionXLPipeline.from_pretrained(
    "segmind/SSD-1B", 
    torch_dtype=torch.float16, 
    use_safetensors=True, 
    variant="fp16"
)
pipe.to("cuda:0")

refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
refiner.to("cuda:0")
refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
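# torch.compile with mode="reduce-overhead" optimizes the refiner UNet lazily, so the
# first refined generation pays a one-time compilation delay; subsequent calls reuse
# the compiled graph and run faster.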

@spaces.GPU  # Request GPU access for this call (ZeroGPU Spaces)
def generate_and_save_image(prompt, negative_prompt=''):
    # Generate image using the provided prompts
    image = pipe(prompt=prompt, negative_prompt=negative_prompt).images[0]

    # Generate a unique UUID for the filename
    unique_id = str(uuid.uuid4())
    image_path = f"generated_images/{unique_id}.jpeg"

    # Save generated image locally
    os.makedirs('generated_images', exist_ok=True)
    image.save(image_path, format='JPEG')

    # Return the path of the saved image to display in Gradio interface
    return image_path

@spaces.GPU  # This function also runs on the GPU, so it needs the decorator as well
def generate_image_with_refinement(prompt):
    n_steps = 40
    high_noise_frac = 0.8
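    # high_noise_frac splits the denoising schedule: the base model handles the
    # first 80% of the steps and outputs latents, and the refiner finishes the
    # remaining 20% (SDXL's "ensemble of expert denoisers" setup).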

    # run both experts
    image = base(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_end=high_noise_frac,
        output_type="latent",
    ).images
    image = refiner(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_start=high_noise_frac,
        image=image,
    ).images[0]

    # Save the image as before
    unique_id = str(uuid.uuid4())
    image_path = f"generated_images_refined/{unique_id}.jpeg"
    os.makedirs('generated_images_refined', exist_ok=True)
    image.save(image_path, format='JPEG')

    return image_path

# Start of the Gradio Blocks interface
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Image Generation with SSD-1B")
        gr.Markdown("Enter a prompt and (optionally) a negative prompt to generate an image.")
        
        # Input fields for positive and negative prompts
        with gr.Row():
            prompt1 = gr.Textbox(label="Enter prompt")
            negative_prompt = gr.Textbox(label="Enter negative prompt (optional)")
        
        # Button for generating the image
        generate_button1 = gr.Button("Generate Image")

        # Output image display, set to a larger default size
        output_image1 = gr.Image(label="Generated Image")

        # Click event for the generate button
        generate_button1.click(
            generate_and_save_image, 
            inputs=[prompt1, negative_prompt], 
            outputs=output_image1
        )

    with gr.Column():
        gr.Markdown("## Refined Image Generation")
        gr.Markdown("Enter a prompt to generate a refined image.")
        
        # Input field for the prompt
        prompt2 = gr.Textbox(label="Enter prompt for refined generation")
        
        # Button for generating the refined image
        generate_button2 = gr.Button("Generate Refined Image")
        
        # Output refined image display, set to a larger default size
        output_image2 = gr.Image(label="Generated Refined Image")

        # Click event for the generate button
        generate_button2.click(
            generate_image_with_refinement, 
            inputs=[prompt2], 
            outputs=output_image2
        )

# Launch the combined Gradio app
demo.launch()