# import gradio as gr
# import torch
# import spaces
# from diffusers import FluxPipeline, DiffusionPipeline
# from torchao.quantization import autoquant



# # # # normal FluxPipeline
# # pipeline_normal = FluxPipeline.from_pretrained(
# #     "sayakpaul/FLUX.1-merged",
# #     torch_dtype=torch.bfloat16
# # ).to("cuda")
# # pipeline_normal.transformer.to(memory_format=torch.channels_last)
# # pipeline_normal.transformer = torch.compile(pipeline_normal.transformer, mode="max-autotune", fullgraph=True)

# pipeline_normal = DiffusionPipeline.from_pretrained("sayakpaul/FLUX.1-merged")
# pipeline_normal.enable_model_cpu_offload()
# pipeline_normal.load_lora_weights("DarkMoonDragon/TurboRender-flux-dev")
# # # optimized FluxPipeline
# # pipeline_optimized = FluxPipeline.from_pretrained(
# #     "camenduru/FLUX.1-dev-diffusers",
# #     torch_dtype=torch.bfloat16
# # ).to("cuda")
# # pipeline_optimized.transformer.to(memory_format=torch.channels_last)
# # pipeline_optimized.transformer = torch.compile(
# #     pipeline_optimized.transformer,
# #     mode="max-autotune",
# #     fullgraph=True
# # )
# # # wrap the autoquant call in a try-except block to handle unsupported layers
# # for name, layer in pipeline_optimized.transformer.named_children():
# #     try:
# #         # apply autoquant to each layer
# #         pipeline_optimized.transformer._modules[name] = autoquant(layer, error_on_unseen=False)
# #         print(f"Successfully quantized {name}")
# #     except AttributeError as e:
# #         print(f"Skipping layer {name} due to error: {e}")
# #     except Exception as e:
# #         print(f"Unexpected error while quantizing {name}: {e}")

# # pipeline_optimized.transformer = autoquant(
# #     pipeline_optimized.transformer,
# #     error_on_unseen=False
# # )
# pipeline_optimized = pipeline_normal

# @spaces.GPU(duration=120)
# def generate_images(prompt, guidance_scale, num_inference_steps):
#     # # generate image with normal pipeline
#     # image_normal = pipeline_normal(
#     #     prompt=prompt,
#     #     guidance_scale=guidance_scale,
#     #     num_inference_steps=int(num_inference_steps)
#     # ).images[0]
    
#     # generate image with optimized pipeline
#     image_optimized = pipeline_optimized(
#         prompt=prompt,
#         guidance_scale=guidance_scale,
#         num_inference_steps=int(num_inference_steps)
#     ).images[0]
    
#     return image_optimized

# # set up Gradio interface
# demo = gr.Interface(
#     fn=generate_images,
#     inputs=[
#         gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
#         gr.Slider(1.0, 10.0, step=0.5, value=3.5, label="Guidance Scale"),
#         gr.Slider(10, 100, step=1, value=50, label="Number of Inference Steps")
#     ],
#     outputs=[
#         gr.Image(type="pil", label="Optimized FluxPipeline")
#     ],
#     title="FluxPipeline Comparison",
#     description="Compare images generated by the normal FluxPipeline and the optimized one using torchao and torch.compile()."
# )

# demo.launch()
import gradio as gr
import torch
from diffusers import FluxPipeline
from torchao import swap_conv2d_1x1_to_linear, apply_dynamic_quant
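# NOTE: torchao's public API has shifted across releases; in some versions these
# helpers live under torchao.quantization instead of the top-level package (and
# newer releases expose a quantize_() API). Adjust the import to match the
# installed torchao version.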

# Step 1: Enable PyTorch 2-specific optimizations
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
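# These Inductor flags broadly follow the "diffusion-fast" recipe: lower 1x1 convs
# to matmuls, enable coordinate-descent autotuning, and let int8 matmuls fuse with
# the subsequent multiply so the dynamic quantization applied below compiles into
# efficient fused kernels.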

# Step 2: Load the Flux pipeline with bfloat16 precision
pipe = FluxPipeline.from_pretrained(
    "sayakpaul/FLUX.1-merged",
    torch_dtype=torch.bfloat16
).to("cuda")
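# The full Flux pipeline in bfloat16 is large; on smaller GPUs, replacing .to("cuda")
# with pipe.enable_model_cpu_offload() (as in the commented-out version above) trades
# speed for lower peak VRAM.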

# Step 3: Apply attention optimizations
pipe.fuse_qkv_projections()
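# fuse_qkv_projections() merges the separate query/key/value projections in each
# attention block into a single larger matmul, cutting kernel launches per step.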

# Step 4: Change memory layout for performance boost
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

# Step 5: Swap Conv2D 1x1 layers to Linear and apply dynamic quantization
def dynamic_quant_filter_fn(mod, *args):
    return isinstance(mod, torch.nn.Linear) and mod.in_features > 16

def conv_filter_fn(mod, *args):
    return isinstance(mod, torch.nn.Conv2d) and mod.kernel_size == (1, 1)
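# The filters keep the rewrites targeted: only true 1x1 convolutions are converted
# to Linear layers, and dynamic int8 quantization is applied only to Linear layers
# wide enough (> 16 input features) for the quantization overhead to pay off.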

swap_conv2d_1x1_to_linear(pipe.transformer, conv_filter_fn)
swap_conv2d_1x1_to_linear(pipe.vae, conv_filter_fn)

apply_dynamic_quant(pipe.transformer, dynamic_quant_filter_fn)
apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)

# Step 6: Compile the transformer and VAE decoder for optimized kernels
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
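# torch.compile with mode="max-autotune" makes the first generation slow while
# kernels are autotuned and cached; later calls with the same shapes reuse the
# compiled graphs.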

# Image generation function
def generate_image(prompt, guidance_scale, num_inference_steps):
    # Run the optimized pipeline; Gradio sliders can return floats, so cast the step count to int
    image = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=int(num_inference_steps)).images[0]
    return image

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Optimized Flux Model Inference")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", placeholder="Enter your text prompt here")
        guidance_scale = gr.Slider(0.0, 15.0, value=7.5, step=0.1, label="Guidance Scale")
        steps = gr.Slider(5, 50, value=30, step=1, label="Inference Steps")

    image_output = gr.Image(type="pil", label="Generated Image")

    generate_button = gr.Button("Generate Image")
    generate_button.click(generate_image, inputs=[prompt, guidance_scale, steps], outputs=image_output)

demo.launch()