Xu Cao committed
Commit 4362f0a
1 Parent(s): 2fa6db8

update demo

README.md CHANGED
@@ -1,3 +1,62 @@
- ---
- license: mit
- ---
+ # Stable Diffusion 3 Inpaint Pipeline
+
+ | input image | input mask image | output |
+ |:-------------------------:|:-------------------------:|:-------------------------:|
+ |<img src="./overture-creations-5sI6fQgYIuo.png" width = "400" /> | <img src="./overture-creations-5sI6fQgYIuo_mask.png" width = "400" /> | <img src="./overture-creations-5sI6fQgYIuo_output.jpg" width = "400" /> |
+
+ **Please make sure your diffusers version is >= 0.29.1.**
+
+ # Demo
+ ```python
+ import torch
+ from torchvision import transforms
+
+ from diffusers import StableDiffusion3InpaintPipeline
+ from diffusers.utils import load_image
+
+ def preprocess_image(image):
+     image = image.convert("RGB")
+     image = transforms.CenterCrop((image.size[1] // 64 * 64, image.size[0] // 64 * 64))(image)
+     image = transforms.ToTensor()(image)
+     image = image * 2 - 1
+     image = image.unsqueeze(0).to("cuda")
+     return image
+
+ def preprocess_mask(mask):
+     mask = mask.convert("L")
+     mask = transforms.CenterCrop((mask.size[1] // 64 * 64, mask.size[0] // 64 * 64))(mask)
+     mask = transforms.ToTensor()(mask)
+     mask = mask.to("cuda")
+     return mask
+
+ pipe = StableDiffusion3InpaintPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-3-medium-diffusers",
+     torch_dtype=torch.float16,
+ ).to("cuda")
+
+ prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ source_image = load_image(
+     "./overture-creations-5sI6fQgYIuo.png"
+ )
+ source = preprocess_image(source_image)
+ mask = preprocess_mask(
+     load_image(
+         "./overture-creations-5sI6fQgYIuo_mask.png"
+     )
+ )
+
+ image = pipe(
+     prompt=prompt,
+     image=source,
+     mask_image=1 - mask,
+     height=1024,
+     width=1024,
+     num_inference_steps=28,
+     guidance_scale=7.0,
+     strength=0.6,
+ ).images[0]
+
+ image.save("output.png")
+ ```
+
+
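If the demo above does not fit in GPU memory, two standard diffusers options may help; the snippet below is an untested sketch rather than part of this demo. Dropping the T5 text encoder (`text_encoder_3=None, tokenizer_3=None`) trades some prompt adherence for a much smaller footprint, and `enable_model_cpu_offload()` (which requires `accelerate`) keeps each sub-model on the GPU only while it is being used.

```python
import torch
from diffusers import StableDiffusion3InpaintPipeline

# Sketch: skip the large T5 text encoder and let diffusers offload sub-models to the
# CPU between forward passes instead of keeping the whole pipeline on the GPU.
pipe = StableDiffusion3InpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    text_encoder_3=None,
    tokenizer_3=None,
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()  # instead of pipe.to("cuda")
```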
demo.py ADDED
@@ -0,0 +1,49 @@
+ import torch
+ from torchvision import transforms
+
+ from pipeline_stable_diffusion_3_inpaint import StableDiffusion3InpaintPipeline
+ from diffusers.utils import load_image
+
+ def preprocess_image(image):
+     image = image.convert("RGB")
+     image = transforms.CenterCrop((image.size[1] // 64 * 64, image.size[0] // 64 * 64))(image)
+     image = transforms.ToTensor()(image)
+     image = image * 2 - 1
+     image = image.unsqueeze(0).to("cuda")
+     return image
+
+ def preprocess_mask(mask):
+     mask = mask.convert("L")
+     mask = transforms.CenterCrop((mask.size[1] // 64 * 64, mask.size[0] // 64 * 64))(mask)
+     mask = transforms.ToTensor()(mask)
+     mask = mask.to("cuda")
+     return mask
+
+ pipe = StableDiffusion3InpaintPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-3-medium-diffusers",
+     torch_dtype=torch.float16,
+ ).to("cuda")
+
+ prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ source_image = load_image(
+     "./overture-creations-5sI6fQgYIuo.png"
+ )
+ source = preprocess_image(source_image)
+ mask = preprocess_mask(
+     load_image(
+         "./overture-creations-5sI6fQgYIuo_mask.png"
+     )
+ )
+
+ image = pipe(
+     prompt=prompt,
+     image=source,
+     mask_image=1 - mask,
+     height=1024,
+     width=1024,
+     num_inference_steps=28,
+     guidance_scale=7.0,
+     strength=0.6,
+ ).images[0]
+
+ image.save("overture-creations-5sI6fQgYIuo_output.jpg")
overture-creations-5sI6fQgYIuo.png ADDED
overture-creations-5sI6fQgYIuo_mask.png ADDED
overture-creations-5sI6fQgYIuo_output.jpg ADDED
pipeline_stable_diffusion_3_inpaint.py ADDED
@@ -0,0 +1,1027 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ import inspect
14
+ from typing import Callable, Dict, List, Optional, Union
15
+
16
+ import PIL.Image
17
+ import torch
18
+ from transformers import (
19
+ CLIPTextModelWithProjection,
20
+ CLIPTokenizer,
21
+ T5EncoderModel,
22
+ T5TokenizerFast,
23
+ )
24
+
25
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
26
+ from diffusers.models.autoencoders import AutoencoderKL
27
+ from diffusers.models.transformers import SD3Transformer2DModel
28
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
29
+ from diffusers.utils import (
30
+ is_torch_xla_available,
31
+ logging,
32
+ replace_example_docstring,
33
+ )
34
+ from diffusers.utils.torch_utils import randn_tensor
35
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
36
+ from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
37
+
38
+
39
+ if is_torch_xla_available():
40
+ import torch_xla.core.xla_model as xm
41
+
42
+ XLA_AVAILABLE = True
43
+ else:
44
+ XLA_AVAILABLE = False
45
+
46
+
47
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48
+
49
+ EXAMPLE_DOC_STRING = """
50
+ Examples:
51
+ ```py
52
+ >>> import torch
53
+
54
+ >>> from diffusers import AutoPipelineForImage2Image
55
+ >>> from diffusers.utils import load_image
56
+
57
+ >>> device = "cuda"
58
+ >>> model_id_or_path = "stabilityai/stable-diffusion-3-medium-diffusers"
59
+ >>> pipe = AutoPipelineForImage2Image.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
60
+ >>> pipe = pipe.to(device)
61
+
62
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
63
+ >>> init_image = load_image(url).resize((512, 512))
64
+
65
+ >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
66
+
67
+ >>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5).images[0]
68
+ ```
69
+ """
70
+
71
+
72
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
73
+ def retrieve_latents(
74
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
75
+ ):
76
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
77
+ return encoder_output.latent_dist.sample(generator)
78
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
79
+ return encoder_output.latent_dist.mode()
80
+ elif hasattr(encoder_output, "latents"):
81
+ return encoder_output.latents
82
+ else:
83
+ raise AttributeError("Could not access latents of provided encoder_output")
84
+
85
+
86
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
87
+ def retrieve_timesteps(
88
+ scheduler,
89
+ num_inference_steps: Optional[int] = None,
90
+ device: Optional[Union[str, torch.device]] = None,
91
+ timesteps: Optional[List[int]] = None,
92
+ sigmas: Optional[List[float]] = None,
93
+ **kwargs,
94
+ ):
95
+ """
96
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
97
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
98
+
99
+ Args:
100
+ scheduler (`SchedulerMixin`):
101
+ The scheduler to get timesteps from.
102
+ num_inference_steps (`int`):
103
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
104
+ must be `None`.
105
+ device (`str` or `torch.device`, *optional*):
106
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
107
+ timesteps (`List[int]`, *optional*):
108
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
109
+ `num_inference_steps` and `sigmas` must be `None`.
110
+ sigmas (`List[float]`, *optional*):
111
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
112
+ `num_inference_steps` and `timesteps` must be `None`.
113
+
114
+ Returns:
115
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
116
+ second element is the number of inference steps.
117
+ """
118
+ if timesteps is not None and sigmas is not None:
119
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
120
+ if timesteps is not None:
121
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
122
+ if not accepts_timesteps:
123
+ raise ValueError(
124
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
125
+ f" timestep schedules. Please check whether you are using the correct scheduler."
126
+ )
127
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
128
+ timesteps = scheduler.timesteps
129
+ num_inference_steps = len(timesteps)
130
+ elif sigmas is not None:
131
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
132
+ if not accept_sigmas:
133
+ raise ValueError(
134
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
135
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
136
+ )
137
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
138
+ timesteps = scheduler.timesteps
139
+ num_inference_steps = len(timesteps)
140
+ else:
141
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
142
+ timesteps = scheduler.timesteps
143
+ return timesteps, num_inference_steps
144
+
145
+
146
+ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
147
+ r"""
148
+ Args:
149
+ transformer ([`SD3Transformer2DModel`]):
150
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
151
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
152
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
153
+ vae ([`AutoencoderKL`]):
154
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
155
+ text_encoder ([`CLIPTextModelWithProjection`]):
156
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
157
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
158
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
159
+ as its dimension.
160
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
161
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
162
+ specifically the
163
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
164
+ variant.
165
+ text_encoder_3 ([`T5EncoderModel`]):
166
+ Frozen text-encoder. Stable Diffusion 3 uses
167
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
168
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
169
+ tokenizer (`CLIPTokenizer`):
170
+ Tokenizer of class
171
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
172
+ tokenizer_2 (`CLIPTokenizer`):
173
+ Second Tokenizer of class
174
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
175
+ tokenizer_3 (`T5TokenizerFast`):
176
+ Tokenizer of class
177
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
178
+ """
179
+
180
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
181
+ _optional_components = []
182
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
183
+
184
+ def __init__(
185
+ self,
186
+ transformer: SD3Transformer2DModel,
187
+ scheduler: FlowMatchEulerDiscreteScheduler,
188
+ vae: AutoencoderKL,
189
+ text_encoder: CLIPTextModelWithProjection,
190
+ tokenizer: CLIPTokenizer,
191
+ text_encoder_2: CLIPTextModelWithProjection,
192
+ tokenizer_2: CLIPTokenizer,
193
+ text_encoder_3: T5EncoderModel,
194
+ tokenizer_3: T5TokenizerFast,
195
+ ):
196
+ super().__init__()
197
+
198
+ self.register_modules(
199
+ vae=vae,
200
+ text_encoder=text_encoder,
201
+ text_encoder_2=text_encoder_2,
202
+ text_encoder_3=text_encoder_3,
203
+ tokenizer=tokenizer,
204
+ tokenizer_2=tokenizer_2,
205
+ tokenizer_3=tokenizer_3,
206
+ transformer=transformer,
207
+ scheduler=scheduler,
208
+ )
209
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
210
+ self.image_processor = VaeImageProcessor(
211
+ vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
212
+ )
213
+ self.mask_processor = VaeImageProcessor(
214
+ vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels, do_normalize=False, do_binarize=True, do_convert_grayscale=True
215
+ )
216
+ self.tokenizer_max_length = self.tokenizer.model_max_length
217
+ self.default_sample_size = self.transformer.config.sample_size
218
+
219
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
220
+ def _get_t5_prompt_embeds(
221
+ self,
222
+ prompt: Union[str, List[str]] = None,
223
+ num_images_per_prompt: int = 1,
224
+ max_sequence_length: int = 256,
225
+ device: Optional[torch.device] = None,
226
+ dtype: Optional[torch.dtype] = None,
227
+ ):
228
+ device = device or self._execution_device
229
+ dtype = dtype or self.text_encoder.dtype
230
+
231
+ prompt = [prompt] if isinstance(prompt, str) else prompt
232
+ batch_size = len(prompt)
233
+
234
+ if self.text_encoder_3 is None:
235
+ return torch.zeros(
236
+ (
237
+ batch_size * num_images_per_prompt,
238
+ self.tokenizer_max_length,
239
+ self.transformer.config.joint_attention_dim,
240
+ ),
241
+ device=device,
242
+ dtype=dtype,
243
+ )
244
+
245
+ text_inputs = self.tokenizer_3(
246
+ prompt,
247
+ padding="max_length",
248
+ max_length=max_sequence_length,
249
+ truncation=True,
250
+ add_special_tokens=True,
251
+ return_tensors="pt",
252
+ )
253
+ text_input_ids = text_inputs.input_ids
254
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
255
+
256
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
257
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
258
+ logger.warning(
259
+ "The following part of your input was truncated because `max_sequence_length` is set to "
260
+ f" {max_sequence_length} tokens: {removed_text}"
261
+ )
262
+
263
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
264
+
265
+ dtype = self.text_encoder_3.dtype
266
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
267
+
268
+ _, seq_len, _ = prompt_embeds.shape
269
+
270
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
271
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
272
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
273
+
274
+ return prompt_embeds
275
+
276
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
277
+ def _get_clip_prompt_embeds(
278
+ self,
279
+ prompt: Union[str, List[str]],
280
+ num_images_per_prompt: int = 1,
281
+ device: Optional[torch.device] = None,
282
+ clip_skip: Optional[int] = None,
283
+ clip_model_index: int = 0,
284
+ ):
285
+ device = device or self._execution_device
286
+
287
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
288
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
289
+
290
+ tokenizer = clip_tokenizers[clip_model_index]
291
+ text_encoder = clip_text_encoders[clip_model_index]
292
+
293
+ prompt = [prompt] if isinstance(prompt, str) else prompt
294
+ batch_size = len(prompt)
295
+
296
+ text_inputs = tokenizer(
297
+ prompt,
298
+ padding="max_length",
299
+ max_length=self.tokenizer_max_length,
300
+ truncation=True,
301
+ return_tensors="pt",
302
+ )
303
+
304
+ text_input_ids = text_inputs.input_ids
305
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
306
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
307
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
308
+ logger.warning(
309
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
310
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
311
+ )
312
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
313
+ pooled_prompt_embeds = prompt_embeds[0]
314
+
315
+ if clip_skip is None:
316
+ prompt_embeds = prompt_embeds.hidden_states[-2]
317
+ else:
318
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
319
+
320
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
321
+
322
+ _, seq_len, _ = prompt_embeds.shape
323
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
324
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
325
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
326
+
327
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
328
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
329
+
330
+ return prompt_embeds, pooled_prompt_embeds
331
+
332
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
333
+ def encode_prompt(
334
+ self,
335
+ prompt: Union[str, List[str]],
336
+ prompt_2: Union[str, List[str]],
337
+ prompt_3: Union[str, List[str]],
338
+ device: Optional[torch.device] = None,
339
+ num_images_per_prompt: int = 1,
340
+ do_classifier_free_guidance: bool = True,
341
+ negative_prompt: Optional[Union[str, List[str]]] = None,
342
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
343
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
344
+ prompt_embeds: Optional[torch.FloatTensor] = None,
345
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
346
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
347
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
348
+ clip_skip: Optional[int] = None,
349
+ max_sequence_length: int = 256,
350
+ ):
351
+ r"""
352
+
353
+ Args:
354
+ prompt (`str` or `List[str]`, *optional*):
355
+ prompt to be encoded
356
+ prompt_2 (`str` or `List[str]`, *optional*):
357
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
358
+ used in all text-encoders
359
+ prompt_3 (`str` or `List[str]`, *optional*):
360
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
361
+ used in all text-encoders
362
+ device: (`torch.device`):
363
+ torch device
364
+ num_images_per_prompt (`int`):
365
+ number of images that should be generated per prompt
366
+ do_classifier_free_guidance (`bool`):
367
+ whether to use classifier free guidance or not
368
+ negative_prompt (`str` or `List[str]`, *optional*):
369
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
370
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
371
+ less than `1`).
372
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
373
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
374
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
375
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
376
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
377
+ `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
378
+ prompt_embeds (`torch.FloatTensor`, *optional*):
379
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
380
+ provided, text embeddings will be generated from `prompt` input argument.
381
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
382
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
383
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
384
+ argument.
385
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
386
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
387
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
388
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
389
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
390
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
391
+ input argument.
392
+ clip_skip (`int`, *optional*):
393
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
394
+ the output of the pre-final layer will be used for computing the prompt embeddings.
395
+ """
396
+ device = device or self._execution_device
397
+
398
+ prompt = [prompt] if isinstance(prompt, str) else prompt
399
+ if prompt is not None:
400
+ batch_size = len(prompt)
401
+ else:
402
+ batch_size = prompt_embeds.shape[0]
403
+
404
+ if prompt_embeds is None:
405
+ prompt_2 = prompt_2 or prompt
406
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
407
+
408
+ prompt_3 = prompt_3 or prompt
409
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
410
+
411
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
412
+ prompt=prompt,
413
+ device=device,
414
+ num_images_per_prompt=num_images_per_prompt,
415
+ clip_skip=clip_skip,
416
+ clip_model_index=0,
417
+ )
418
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
419
+ prompt=prompt_2,
420
+ device=device,
421
+ num_images_per_prompt=num_images_per_prompt,
422
+ clip_skip=clip_skip,
423
+ clip_model_index=1,
424
+ )
425
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
426
+
427
+ t5_prompt_embed = self._get_t5_prompt_embeds(
428
+ prompt=prompt_3,
429
+ num_images_per_prompt=num_images_per_prompt,
430
+ max_sequence_length=max_sequence_length,
431
+ device=device,
432
+ )
433
+
434
+ clip_prompt_embeds = torch.nn.functional.pad(
435
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
436
+ )
437
+
438
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
439
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
440
+
441
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
442
+ negative_prompt = negative_prompt or ""
443
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
444
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
445
+
446
+ # normalize str to list
447
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
448
+ negative_prompt_2 = (
449
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
450
+ )
451
+ negative_prompt_3 = (
452
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
453
+ )
454
+
455
+ if prompt is not None and type(prompt) is not type(negative_prompt):
456
+ raise TypeError(
457
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
458
+ f" {type(prompt)}."
459
+ )
460
+ elif batch_size != len(negative_prompt):
461
+ raise ValueError(
462
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
463
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
464
+ " the batch size of `prompt`."
465
+ )
466
+
467
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
468
+ negative_prompt,
469
+ device=device,
470
+ num_images_per_prompt=num_images_per_prompt,
471
+ clip_skip=None,
472
+ clip_model_index=0,
473
+ )
474
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
475
+ negative_prompt_2,
476
+ device=device,
477
+ num_images_per_prompt=num_images_per_prompt,
478
+ clip_skip=None,
479
+ clip_model_index=1,
480
+ )
481
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
482
+
483
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
484
+ prompt=negative_prompt_3,
485
+ num_images_per_prompt=num_images_per_prompt,
486
+ max_sequence_length=max_sequence_length,
487
+ device=device,
488
+ )
489
+
490
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
491
+ negative_clip_prompt_embeds,
492
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
493
+ )
494
+
495
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
496
+ negative_pooled_prompt_embeds = torch.cat(
497
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
498
+ )
499
+
500
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
501
+
502
+ def check_inputs(
503
+ self,
504
+ prompt,
505
+ prompt_2,
506
+ prompt_3,
507
+ strength,
508
+ negative_prompt=None,
509
+ negative_prompt_2=None,
510
+ negative_prompt_3=None,
511
+ prompt_embeds=None,
512
+ negative_prompt_embeds=None,
513
+ pooled_prompt_embeds=None,
514
+ negative_pooled_prompt_embeds=None,
515
+ callback_on_step_end_tensor_inputs=None,
516
+ max_sequence_length=None,
517
+ ):
518
+ if strength < 0 or strength > 1:
519
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
520
+
521
+ if callback_on_step_end_tensor_inputs is not None and not all(
522
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
523
+ ):
524
+ raise ValueError(
525
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
526
+ )
527
+
528
+ if prompt is not None and prompt_embeds is not None:
529
+ raise ValueError(
530
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
531
+ " only forward one of the two."
532
+ )
533
+ elif prompt_2 is not None and prompt_embeds is not None:
534
+ raise ValueError(
535
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
536
+ " only forward one of the two."
537
+ )
538
+ elif prompt_3 is not None and prompt_embeds is not None:
539
+ raise ValueError(
540
+ f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
541
+ " only forward one of the two."
542
+ )
543
+ elif prompt is None and prompt_embeds is None:
544
+ raise ValueError(
545
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
546
+ )
547
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
548
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
549
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
550
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
551
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
552
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
553
+
554
+ if negative_prompt is not None and negative_prompt_embeds is not None:
555
+ raise ValueError(
556
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
557
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
558
+ )
559
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
560
+ raise ValueError(
561
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
562
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
563
+ )
564
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
565
+ raise ValueError(
566
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
567
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
568
+ )
569
+
570
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
571
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
572
+ raise ValueError(
573
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
574
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
575
+ f" {negative_prompt_embeds.shape}."
576
+ )
577
+
578
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
579
+ raise ValueError(
580
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
581
+ )
582
+
583
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
584
+ raise ValueError(
585
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
586
+ )
587
+
588
+ if max_sequence_length is not None and max_sequence_length > 512:
589
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
590
+
591
+ def get_timesteps(self, num_inference_steps, strength, device):
592
+ # get the original timestep using init_timestep
593
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
594
+
595
+ t_start = int(max(num_inference_steps - init_timestep, 0))
596
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
597
+ if hasattr(self.scheduler, "set_begin_index"):
598
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
599
+
600
+ return timesteps, num_inference_steps - t_start
601
+
602
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
603
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
604
+ raise ValueError(
605
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
606
+ )
607
+
608
+ image = image.to(device=device, dtype=dtype)
609
+
610
+ batch_size = batch_size * num_images_per_prompt
611
+ if image.shape[1] == self.vae.config.latent_channels:
612
+ init_latents = image
613
+
614
+ else:
615
+ if isinstance(generator, list) and len(generator) != batch_size:
616
+ raise ValueError(
617
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
618
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
619
+ )
620
+
621
+ elif isinstance(generator, list):
622
+ init_latents = [
623
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
624
+ for i in range(batch_size)
625
+ ]
626
+ init_latents = torch.cat(init_latents, dim=0)
627
+ else:
628
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
629
+
630
+ init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
631
+
632
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
633
+ # expand init_latents for batch_size
634
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
635
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
636
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
637
+ raise ValueError(
638
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
639
+ )
640
+ else:
641
+ init_latents = torch.cat([init_latents], dim=0)
642
+
643
+ shape = init_latents.shape
644
+ init_latents_orig = init_latents
645
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
646
+
647
+ # get latents
648
+ init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
649
+ latents = init_latents.to(device=device, dtype=dtype)
650
+
651
+ return latents, init_latents_orig, noise
652
+
653
+ def prepare_mask_latents(
654
+ self, mask, masked_image, batch_size, num_images_per_prompt, height, width, dtype, device, generator
655
+ ):
656
+ # resize the mask to latents shape as we concatenate the mask to the latents
657
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
658
+ # and half precision
659
+ mask = torch.nn.functional.interpolate(
660
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
661
+ )
662
+ mask = mask.to(device=device, dtype=dtype)
663
+
664
+ batch_size = batch_size * num_images_per_prompt
665
+
666
+ masked_image = masked_image.to(device=device, dtype=dtype)
667
+
668
+ if masked_image.shape[1] == 4:
669
+ masked_image_latents = masked_image
670
+ else:
671
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
672
+
673
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
674
+ if mask.shape[0] < batch_size:
675
+ if not batch_size % mask.shape[0] == 0:
676
+ raise ValueError(
677
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
678
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
679
+ " of masks that you pass is divisible by the total requested batch size."
680
+ )
681
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
682
+ if masked_image_latents.shape[0] < batch_size:
683
+ if not batch_size % masked_image_latents.shape[0] == 0:
684
+ raise ValueError(
685
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
686
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
687
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
688
+ )
689
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
690
+
691
+ # mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
692
+ # masked_image_latents = (
693
+ # torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
694
+ # )
695
+
696
+ # aligning device to prevent device errors when concatenating it with the latent model input
697
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
698
+ return mask, masked_image_latents
699
+
700
+ @property
701
+ def guidance_scale(self):
702
+ return self._guidance_scale
703
+
704
+ @property
705
+ def clip_skip(self):
706
+ return self._clip_skip
707
+
708
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
709
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
710
+ # corresponds to doing no classifier free guidance.
711
+ @property
712
+ def do_classifier_free_guidance(self):
713
+ return self._guidance_scale > 1
714
+
715
+ @property
716
+ def num_timesteps(self):
717
+ return self._num_timesteps
718
+
719
+ @property
720
+ def interrupt(self):
721
+ return self._interrupt
722
+
723
+ @torch.no_grad()
724
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
725
+ def __call__(
726
+ self,
727
+ prompt: Union[str, List[str]] = None,
728
+ prompt_2: Optional[Union[str, List[str]]] = None,
729
+ prompt_3: Optional[Union[str, List[str]]] = None,
730
+ height: int = None,
731
+ width: int = None,
732
+ image: PipelineImageInput = None,
733
+ mask_image: PipelineImageInput = None,
734
+ masked_image_latents: PipelineImageInput = None,
735
+ strength: float = 0.6,
736
+ num_inference_steps: int = 50,
737
+ timesteps: List[int] = None,
738
+ guidance_scale: float = 7.0,
739
+ negative_prompt: Optional[Union[str, List[str]]] = None,
740
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
741
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
742
+ num_images_per_prompt: Optional[int] = 1,
743
+ add_predicted_noise: Optional[bool] = False,
744
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
745
+ latents: Optional[torch.FloatTensor] = None,
746
+ prompt_embeds: Optional[torch.FloatTensor] = None,
747
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
748
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
749
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
750
+ output_type: Optional[str] = "pil",
751
+ return_dict: bool = True,
752
+ clip_skip: Optional[int] = None,
753
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
754
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
755
+ max_sequence_length: int = 256,
756
+ ):
757
+ r"""
758
+ Function invoked when calling the pipeline for generation.
759
+
760
+ Args:
761
+ prompt (`str` or `List[str]`, *optional*):
762
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
763
+ instead.
764
+ prompt_2 (`str` or `List[str]`, *optional*):
765
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
766
+ will be used instead
767
+ prompt_3 (`str` or `List[str]`, *optional*):
768
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
769
+ will be used instead
770
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
771
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
772
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
773
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
774
+ num_inference_steps (`int`, *optional*, defaults to 50):
775
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
776
+ expense of slower inference.
777
+ timesteps (`List[int]`, *optional*):
778
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
779
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
780
+ passed will be used. Must be in descending order.
781
+ guidance_scale (`float`, *optional*, defaults to 7.0):
782
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
783
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
784
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
785
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
786
+ usually at the expense of lower image quality.
787
+ negative_prompt (`str` or `List[str]`, *optional*):
788
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
789
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
790
+ less than `1`).
791
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
792
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
793
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
794
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
795
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
796
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
797
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
798
+ The number of images to generate per prompt.
799
+ add_predicted_noise (`bool`, *optional*, defaults to False):
800
+ Use predicted noise instead of random noise when constructing noisy versions of the original image in
801
+ the reverse diffusion process
802
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
803
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
804
+ to make generation deterministic.
805
+ latents (`torch.FloatTensor`, *optional*):
806
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
807
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
808
+ tensor will be generated by sampling using the supplied random `generator`.
809
+ prompt_embeds (`torch.FloatTensor`, *optional*):
810
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
811
+ provided, text embeddings will be generated from `prompt` input argument.
812
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
813
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
814
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
815
+ argument.
816
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
817
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
818
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
819
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
820
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
821
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
822
+ input argument.
823
+ output_type (`str`, *optional*, defaults to `"pil"`):
824
+ The output format of the generated image. Choose between
825
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
826
+ return_dict (`bool`, *optional*, defaults to `True`):
827
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead
828
+ of a plain tuple.
829
+ callback_on_step_end (`Callable`, *optional*):
830
+ A function that is called at the end of each denoising step during inference. The function is called
831
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
832
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
833
+ `callback_on_step_end_tensor_inputs`.
834
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
835
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
836
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
837
+ `._callback_tensor_inputs` attribute of your pipeline class.
838
+ max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
839
+
840
+ Examples:
841
+
842
+ Returns:
843
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
844
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
845
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
846
+ """
847
+
848
+ # 1. Check inputs. Raise error if not correct
849
+ self.check_inputs(
850
+ prompt,
851
+ prompt_2,
852
+ prompt_3,
853
+ strength,
854
+ negative_prompt=negative_prompt,
855
+ negative_prompt_2=negative_prompt_2,
856
+ negative_prompt_3=negative_prompt_3,
857
+ prompt_embeds=prompt_embeds,
858
+ negative_prompt_embeds=negative_prompt_embeds,
859
+ pooled_prompt_embeds=pooled_prompt_embeds,
860
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
861
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
862
+ max_sequence_length=max_sequence_length,
863
+ )
864
+
865
+ self._guidance_scale = guidance_scale
866
+ self._clip_skip = clip_skip
867
+ self._interrupt = False
868
+
869
+ # 2. Define call parameters
870
+ if prompt is not None and isinstance(prompt, str):
871
+ batch_size = 1
872
+ elif prompt is not None and isinstance(prompt, list):
873
+ batch_size = len(prompt)
874
+ else:
875
+ batch_size = prompt_embeds.shape[0]
876
+
877
+ device = self._execution_device
878
+
879
+ (
880
+ prompt_embeds,
881
+ negative_prompt_embeds,
882
+ pooled_prompt_embeds,
883
+ negative_pooled_prompt_embeds,
884
+ ) = self.encode_prompt(
885
+ prompt=prompt,
886
+ prompt_2=prompt_2,
887
+ prompt_3=prompt_3,
888
+ negative_prompt=negative_prompt,
889
+ negative_prompt_2=negative_prompt_2,
890
+ negative_prompt_3=negative_prompt_3,
891
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
892
+ prompt_embeds=prompt_embeds,
893
+ negative_prompt_embeds=negative_prompt_embeds,
894
+ pooled_prompt_embeds=pooled_prompt_embeds,
895
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
896
+ device=device,
897
+ clip_skip=self.clip_skip,
898
+ num_images_per_prompt=num_images_per_prompt,
899
+ max_sequence_length=max_sequence_length,
900
+ )
901
+
902
+ if self.do_classifier_free_guidance:
903
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
904
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
905
+
906
+ # 3. Preprocess image
907
+ image = self.image_processor.preprocess(image, height, width)
908
+
909
+ # 4. Prepare timesteps
910
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
911
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
912
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
913
+
914
+ # 5. Prepare latent variables
915
+ if latents is None:
916
+ latents, init_latents_orig, noise = self.prepare_latents(
917
+ image,
918
+ latent_timestep,
919
+ batch_size,
920
+ num_images_per_prompt,
921
+ prompt_embeds.dtype,
922
+ device,
923
+ generator,
924
+ )
925
+
926
+ # 5.1. Prepare masked latent variables
927
+ mask_condition = self.mask_processor.preprocess(mask_image, height, width)
928
+
929
+ if masked_image_latents is None:
930
+ masked_image = image * (mask_condition < 0.5)
931
+ else:
932
+ masked_image = masked_image_latents
933
+
934
+ mask, masked_image_latents = self.prepare_mask_latents(
935
+ mask_condition,
936
+ masked_image,
937
+ batch_size,
938
+ num_images_per_prompt,
939
+ height,
940
+ width,
941
+ prompt_embeds.dtype,
942
+ device,
943
+ generator
944
+ )
945
+
946
+ # 6. Denoising loop
947
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
948
+ self._num_timesteps = len(timesteps)
949
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
950
+ for i, t in enumerate(timesteps):
951
+ if self.interrupt:
952
+ continue
953
+
954
+ # expand the latents if we are doing classifier free guidance
955
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
956
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
957
+ timestep = t.expand(latent_model_input.shape[0])
958
+
959
+ noise_pred = self.transformer(
960
+ hidden_states=latent_model_input,
961
+ timestep=timestep,
962
+ encoder_hidden_states=prompt_embeds,
963
+ pooled_projections=pooled_prompt_embeds,
964
+ return_dict=False,
965
+ )[0]
966
+
967
+ # perform guidance
968
+ if self.do_classifier_free_guidance:
969
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
970
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
971
+
972
+ # compute the previous noisy sample x_t -> x_t-1
973
+ latents_dtype = latents.dtype
974
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
975
+
976
+ if latents.dtype != latents_dtype:
977
+ if torch.backends.mps.is_available():
978
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
979
+ latents = latents.to(latents_dtype)
980
+
981
+ if callback_on_step_end is not None:
982
+ callback_kwargs = {}
983
+ for k in callback_on_step_end_tensor_inputs:
984
+ callback_kwargs[k] = locals()[k]
985
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
986
+
987
+ latents = callback_outputs.pop("latents", latents)
988
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
989
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
990
+ negative_pooled_prompt_embeds = callback_outputs.pop(
991
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
992
+ )
993
+
994
+ if add_predicted_noise:
995
+ init_latents_proper = self.scheduler.scale_noise(
996
+ init_latents_orig, torch.tensor([t]), noise_pred_uncond
997
+ )
998
+ else:
999
+ init_latents_proper = self.scheduler.scale_noise(init_latents_orig, torch.tensor([t]), noise)
1000
+
1001
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
1002
+
1003
+ # call the callback, if provided
1004
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1005
+ progress_bar.update()
1006
+
1007
+ if XLA_AVAILABLE:
1008
+ xm.mark_step()
1009
+
1010
+ latents = (init_latents_orig * mask) + (latents * (1 - mask))
1011
+
1012
+ if output_type == "latent":
1013
+ image = latents
1014
+
1015
+ else:
1016
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1017
+
1018
+ image = self.vae.decode(latents, return_dict=False)[0]
1019
+ image = self.image_processor.postprocess(image, output_type=output_type)
1020
+
1021
+ # Offload all models
1022
+ self.maybe_free_model_hooks()
1023
+
1024
+ if not return_dict:
1025
+ return (image,)
1026
+
1027
+ return StableDiffusion3PipelineOutput(images=image)
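A note on the mask convention used above: in the denoising loop the pipeline blends `latents = (init_latents_orig * mask) + (latents * (1 - mask))`, so mask values of 1 keep the original content and values of 0 are regenerated. That is why both demos pass `mask_image=1 - mask` for a conventional mask in which white marks the region to repaint. If a mask already uses white for "keep" and black for "repaint", the inversion can be dropped; a hypothetical variant of the demo call:

```python
# Hypothetical variant of the demo call for a mask where white = keep and black = repaint.
# All other arguments are unchanged from demo.py above.
image = pipe(
    prompt=prompt,
    image=source,
    mask_image=mask,  # no inversion needed for this mask convention
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=7.0,
    strength=0.6,
).images[0]
```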