IrohXu commited on
Commit
9729d10
1 Parent(s): 3c12e7b

update the code

Browse files
demo.py CHANGED
@@ -8,7 +8,6 @@ def preprocess_image(image):
8
  image = image.convert("RGB")
9
  image = transforms.CenterCrop((image.size[1] // 64 * 64, image.size[0] // 64 * 64))(image)
10
  image = transforms.ToTensor()(image)
11
- image = image * 2 - 1
12
  image = image.unsqueeze(0).to("cuda")
13
  return image
14
 
@@ -24,7 +23,7 @@ pipe = StableDiffusion3InpaintPipeline.from_pretrained(
24
  torch_dtype=torch.float16,
25
  ).to("cuda")
26
 
27
- prompt = "Face of a cat, high resolution, sitting on a park bench"
28
  source_image = load_image(
29
  "./overture-creations-5sI6fQgYIuo.png"
30
  )
@@ -38,10 +37,10 @@ mask = preprocess_mask(
38
  image = pipe(
39
  prompt=prompt,
40
  image=source,
41
- mask_image=1-mask,
42
  height=1024,
43
  width=1024,
44
- num_inference_steps=28,
45
  guidance_scale=7.0,
46
  strength=0.6,
47
  ).images[0]
 
8
  image = image.convert("RGB")
9
  image = transforms.CenterCrop((image.size[1] // 64 * 64, image.size[0] // 64 * 64))(image)
10
  image = transforms.ToTensor()(image)
 
11
  image = image.unsqueeze(0).to("cuda")
12
  return image
13
 
 
23
  torch_dtype=torch.float16,
24
  ).to("cuda")
25
 
26
+ prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
27
  source_image = load_image(
28
  "./overture-creations-5sI6fQgYIuo.png"
29
  )
 
37
  image = pipe(
38
  prompt=prompt,
39
  image=source,
40
+ mask_image=mask,
41
  height=1024,
42
  width=1024,
43
+ num_inference_steps=50,
44
  guidance_scale=7.0,
45
  strength=0.6,
46
  ).images[0]
overture-creations-5sI6fQgYIuo_output.jpg CHANGED
pipeline_stable_diffusion_3_inpaint.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  # Licensed under the Apache License, Version 2.0 (the "License");
2
  # you may not use this file except in compliance with the License.
3
  # You may obtain a copy of the License at
@@ -13,7 +15,6 @@
13
  import inspect
14
  from typing import Callable, Dict, List, Optional, Union
15
 
16
- import PIL.Image
17
  import torch
18
  from transformers import (
19
  CLIPTextModelWithProjection,
@@ -22,6 +23,7 @@ from transformers import (
22
  T5TokenizerFast,
23
  )
24
 
 
25
  from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
26
  from diffusers.models.autoencoders import AutoencoderKL
27
  from diffusers.models.transformers import SD3Transformer2DModel
@@ -50,21 +52,20 @@ EXAMPLE_DOC_STRING = """
50
  Examples:
51
  ```py
52
  >>> import torch
53
-
54
- >>> from diffusers import AutoPipelineForImage2Image
55
  >>> from diffusers.utils import load_image
56
 
57
- >>> device = "cuda"
58
- >>> model_id_or_path = "stabilityai/stable-diffusion-3-medium-diffusers"
59
- >>> pipe = AutoPipelineForImage2Image.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
60
- >>> pipe = pipe.to(device)
61
-
62
- >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
63
- >>> init_image = load_image(url).resize((512, 512))
64
-
65
- >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
66
-
67
- >>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5).images[0]
68
  ```
69
  """
70
 
@@ -211,7 +212,11 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
211
  vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
212
  )
213
  self.mask_processor = VaeImageProcessor(
214
- vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels, do_normalize=False, do_binarize=True, do_convert_grayscale=True
 
 
 
 
215
  )
216
  self.tokenizer_max_length = self.tokenizer.model_max_length
217
  self.default_sample_size = self.transformer.config.sample_size
@@ -499,6 +504,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
499
 
500
  return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
501
 
 
502
  def check_inputs(
503
  self,
504
  prompt,
@@ -588,6 +594,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
588
  if max_sequence_length is not None and max_sequence_length > 512:
589
  raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
590
 
 
591
  def get_timesteps(self, num_inference_steps, strength, device):
592
  # get the original timestep using init_timestep
593
  init_timestep = min(num_inference_steps * strength, num_inference_steps)
@@ -599,59 +606,93 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
599
 
600
  return timesteps, num_inference_steps - t_start
601
 
602
- def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
603
- if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
604
  raise ValueError(
605
- f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
 
606
  )
607
 
608
- image = image.to(device=device, dtype=dtype)
 
609
 
610
- batch_size = batch_size * num_images_per_prompt
611
- if image.shape[1] == self.vae.config.latent_channels:
612
- init_latents = image
 
 
613
 
 
 
 
 
614
  else:
615
- if isinstance(generator, list) and len(generator) != batch_size:
616
- raise ValueError(
617
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
618
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
619
- )
620
 
621
- elif isinstance(generator, list):
622
- init_latents = [
623
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
624
- for i in range(batch_size)
625
- ]
626
- init_latents = torch.cat(init_latents, dim=0)
627
- else:
628
- init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
629
 
630
- init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
 
631
 
632
- if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
633
- # expand init_latents for batch_size
634
- additional_image_per_prompt = batch_size // init_latents.shape[0]
635
- init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
636
- elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
637
- raise ValueError(
638
- f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
639
- )
 
 
 
 
640
  else:
641
- init_latents = torch.cat([init_latents], dim=0)
642
 
643
- shape = init_latents.shape
644
- init_latents_orig = init_latents
645
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
646
 
647
- # get latents
648
- init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
649
- latents = init_latents.to(device=device, dtype=dtype)
650
 
651
- return latents, init_latents_orig, noise
652
-
653
  def prepare_mask_latents(
654
- self, mask, masked_image, batch_size, num_images_per_prompt, height, width, dtype, device, generator
 
 
 
 
 
 
 
 
 
 
655
  ):
656
  # resize the mask to latents shape as we concatenate the mask to the latents
657
  # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
@@ -660,16 +701,18 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
660
  mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
661
  )
662
  mask = mask.to(device=device, dtype=dtype)
663
-
664
  batch_size = batch_size * num_images_per_prompt
665
 
666
  masked_image = masked_image.to(device=device, dtype=dtype)
667
 
668
- if masked_image.shape[1] == 4:
669
  masked_image_latents = masked_image
670
  else:
671
  masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
672
 
 
 
673
  # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
674
  if mask.shape[0] < batch_size:
675
  if not batch_size % mask.shape[0] == 0:
@@ -688,15 +731,15 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
688
  )
689
  masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
690
 
691
- # mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
692
- # masked_image_latents = (
693
- # torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
694
- # )
695
 
696
  # aligning device to prevent device errors when concating it with the latent model input
697
  masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
698
  return mask, masked_image_latents
699
-
700
  @property
701
  def guidance_scale(self):
702
  return self._guidance_scale
@@ -727,11 +770,12 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
727
  prompt: Union[str, List[str]] = None,
728
  prompt_2: Optional[Union[str, List[str]]] = None,
729
  prompt_3: Optional[Union[str, List[str]]] = None,
730
- height: int = None,
731
- width: int = None,
732
  image: PipelineImageInput = None,
733
  mask_image: PipelineImageInput = None,
734
  masked_image_latents: PipelineImageInput = None,
 
 
 
735
  strength: float = 0.6,
736
  num_inference_steps: int = 50,
737
  timesteps: List[int] = None,
@@ -740,13 +784,12 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
740
  negative_prompt_2: Optional[Union[str, List[str]]] = None,
741
  negative_prompt_3: Optional[Union[str, List[str]]] = None,
742
  num_images_per_prompt: Optional[int] = 1,
743
- add_predicted_noise: Optional[bool] = False,
744
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
745
- latents: Optional[torch.FloatTensor] = None,
746
- prompt_embeds: Optional[torch.FloatTensor] = None,
747
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
748
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
749
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
750
  output_type: Optional[str] = "pil",
751
  return_dict: bool = True,
752
  clip_skip: Optional[int] = None,
@@ -767,10 +810,39 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
767
  prompt_3 (`str` or `List[str]`, *optional*):
768
  The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
769
  will be used instead
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
770
  height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
771
  The height in pixels of the generated image. This is set to 1024 by default for the best results.
772
  width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
773
  The width in pixels of the generated image. This is set to 1024 by default for the best results.
 
 
 
 
 
 
 
 
 
 
 
 
 
774
  num_inference_steps (`int`, *optional*, defaults to 50):
775
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
776
  expense of slower inference.
@@ -796,9 +868,6 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
796
  `text_encoder_3`. If not defined, `negative_prompt` is used instead
797
  num_images_per_prompt (`int`, *optional*, defaults to 1):
798
  The number of images to generate per prompt.
799
- add_predicted_noise (`bool`, *optional*, defaults to True):
800
- Use predicted noise instead of random noise when constructing noisy versions of the original image in
801
- the reverse diffusion process
802
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
803
  One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
804
  to make generation deterministic.
@@ -845,6 +914,12 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
845
  `tuple`. When returning a tuple, the first element is a list with the generated images.
846
  """
847
 
 
 
 
 
 
 
848
  # 1. Check inputs. Raise error if not correct
849
  self.check_inputs(
850
  prompt,
@@ -903,34 +978,70 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
903
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
904
  pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
905
 
906
- # 3. Preprocess image
907
- image = self.image_processor.preprocess(image, height, width)
908
-
909
- # 4. Prepare timesteps
910
  timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
911
  timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
 
 
 
 
 
 
912
  latent_timestep = timesteps[:1].repeat(batch_size * num_inference_steps)
913
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
914
  # 5. Prepare latent variables
915
- if latents is None:
916
- latents, init_latents_orig, noise = self.prepare_latents(
917
- image,
918
- latent_timestep,
919
- batch_size,
920
- num_images_per_prompt,
921
- prompt_embeds.dtype,
922
- device,
923
- generator,
924
- )
925
-
926
- # 5.1. Prepare masked latent variables
927
- mask_condition = self.mask_processor.preprocess(mask_image, height, width)
928
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
929
  if masked_image_latents is None:
930
- masked_image = image * (mask_condition < 0.5)
931
  else:
932
  masked_image = masked_image_latents
933
-
934
  mask, masked_image_latents = self.prepare_mask_latents(
935
  mask_condition,
936
  masked_image,
@@ -940,10 +1051,32 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
940
  width,
941
  prompt_embeds.dtype,
942
  device,
943
- generator
 
944
  )
945
 
946
- # 6. Denoising loop
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
947
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
948
  self._num_timesteps = len(timesteps)
949
  with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -955,7 +1088,10 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
955
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
956
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
957
  timestep = t.expand(latent_model_input.shape[0])
958
-
 
 
 
959
  noise_pred = self.transformer(
960
  hidden_states=latent_model_input,
961
  timestep=timestep,
@@ -972,6 +1108,20 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
972
  # compute the previous noisy sample x_t -> x_t-1
973
  latents_dtype = latents.dtype
974
  latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
975
 
976
  if latents.dtype != latents_dtype:
977
  if torch.backends.mps.is_available():
@@ -990,15 +1140,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
990
  negative_pooled_prompt_embeds = callback_outputs.pop(
991
  "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
992
  )
993
-
994
- if add_predicted_noise:
995
- init_latents_proper = self.scheduler.scale_noise(
996
- init_latents_orig, torch.tensor([t]), noise_pred_uncond
997
- )
998
- else:
999
- init_latents_proper = self.scheduler.scale_noise(init_latents_orig, torch.tensor([t]), noise)
1000
-
1001
- latents = (init_latents_proper * mask) + (latents * (1 - mask))
1002
 
1003
  # call the callback, if provided
1004
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -1007,16 +1150,19 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
1007
  if XLA_AVAILABLE:
1008
  xm.mark_step()
1009
 
1010
- latents = (init_latents_orig * mask) + (latents * (1 - mask))
1011
-
1012
- if output_type == "latent":
 
 
1013
  image = latents
1014
 
1015
- else:
1016
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
 
1017
 
1018
- image = self.vae.decode(latents, return_dict=False)[0]
1019
- image = self.image_processor.postprocess(image, output_type=output_type)
1020
 
1021
  # Offload all models
1022
  self.maybe_free_model_hooks()
@@ -1024,4 +1170,4 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline):
1024
  if not return_dict:
1025
  return (image,)
1026
 
1027
- return StableDiffusion3PipelineOutput(images=image)
 
1
+ # Copyright 2024 Stability AI and The HuggingFace Team and IrohXu. All rights reserved.
2
+ #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
5
  # You may obtain a copy of the License at
 
15
  import inspect
16
  from typing import Callable, Dict, List, Optional, Union
17
 
 
18
  import torch
19
  from transformers import (
20
  CLIPTextModelWithProjection,
 
23
  T5TokenizerFast,
24
  )
25
 
26
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
27
  from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
28
  from diffusers.models.autoencoders import AutoencoderKL
29
  from diffusers.models.transformers import SD3Transformer2DModel
 
52
  Examples:
53
  ```py
54
  >>> import torch
55
+ >>> from diffusers import StableDiffusion3InpaintPipeline
 
56
  >>> from diffusers.utils import load_image
57
 
58
+ >>> pipe = StableDiffusion3InpaintPipeline.from_pretrained(
59
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
60
+ ... )
61
+ >>> pipe.to("cuda")
62
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
63
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
64
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
65
+ >>> source = load_image(img_url)
66
+ >>> mask = load_image(mask_url)
67
+ >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0]
68
+ >>> image.save("sd3_inpainting.png")
69
  ```
70
  """
71
 
 
212
  vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
213
  )
214
  self.mask_processor = VaeImageProcessor(
215
+ vae_scale_factor=self.vae_scale_factor,
216
+ vae_latent_channels=self.vae.config.latent_channels,
217
+ do_normalize=False,
218
+ do_binarize=True,
219
+ do_convert_grayscale=True,
220
  )
221
  self.tokenizer_max_length = self.tokenizer.model_max_length
222
  self.default_sample_size = self.transformer.config.sample_size
 
504
 
505
  return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
506
 
507
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.check_inputs
508
  def check_inputs(
509
  self,
510
  prompt,
 
594
  if max_sequence_length is not None and max_sequence_length > 512:
595
  raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
596
 
597
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
598
  def get_timesteps(self, num_inference_steps, strength, device):
599
  # get the original timestep using init_timestep
600
  init_timestep = min(num_inference_steps * strength, num_inference_steps)
 
606
 
607
  return timesteps, num_inference_steps - t_start
608
 
609
+ def prepare_latents(
610
+ self,
611
+ batch_size,
612
+ num_channels_latents,
613
+ height,
614
+ width,
615
+ dtype,
616
+ device,
617
+ generator,
618
+ latents=None,
619
+ image=None,
620
+ timestep=None,
621
+ is_strength_max=True,
622
+ return_noise=False,
623
+ return_image_latents=False,
624
+ ):
625
+ shape = (
626
+ batch_size,
627
+ num_channels_latents,
628
+ int(height) // self.vae_scale_factor,
629
+ int(width) // self.vae_scale_factor,
630
+ )
631
+ if isinstance(generator, list) and len(generator) != batch_size:
632
+ raise ValueError(
633
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
634
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
635
+ )
636
+
637
+ if (image is None or timestep is None) and not is_strength_max:
638
  raise ValueError(
639
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
640
+ "However, either the image or the noise timestep has not been provided."
641
  )
642
 
643
+ if return_image_latents or (latents is None and not is_strength_max):
644
+ image = image.to(device=device, dtype=dtype)
645
 
646
+ if image.shape[1] == 16:
647
+ image_latents = image
648
+ else:
649
+ image_latents = self._encode_vae_image(image=image, generator=generator)
650
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
651
 
652
+ if latents is None:
653
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
654
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
655
+ latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise)
656
  else:
657
+ noise = latents.to(device)
658
+ latents = noise
 
 
 
659
 
660
+ outputs = (latents,)
 
 
 
 
 
 
 
661
 
662
+ if return_noise:
663
+ outputs += (noise,)
664
 
665
+ if return_image_latents:
666
+ outputs += (image_latents,)
667
+
668
+ return outputs
669
+
670
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
671
+ if isinstance(generator, list):
672
+ image_latents = [
673
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
674
+ for i in range(image.shape[0])
675
+ ]
676
+ image_latents = torch.cat(image_latents, dim=0)
677
  else:
678
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
679
 
680
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
 
 
681
 
682
+ return image_latents
 
 
683
 
 
 
684
  def prepare_mask_latents(
685
+ self,
686
+ mask,
687
+ masked_image,
688
+ batch_size,
689
+ num_images_per_prompt,
690
+ height,
691
+ width,
692
+ dtype,
693
+ device,
694
+ generator,
695
+ do_classifier_free_guidance,
696
  ):
697
  # resize the mask to latents shape as we concatenate the mask to the latents
698
  # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
 
701
  mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
702
  )
703
  mask = mask.to(device=device, dtype=dtype)
704
+
705
  batch_size = batch_size * num_images_per_prompt
706
 
707
  masked_image = masked_image.to(device=device, dtype=dtype)
708
 
709
+ if masked_image.shape[1] == 16:
710
  masked_image_latents = masked_image
711
  else:
712
  masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
713
 
714
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
715
+
716
  # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
717
  if mask.shape[0] < batch_size:
718
  if not batch_size % mask.shape[0] == 0:
 
731
  )
732
  masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
733
 
734
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
735
+ masked_image_latents = (
736
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
737
+ )
738
 
739
  # aligning device to prevent device errors when concating it with the latent model input
740
  masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
741
  return mask, masked_image_latents
742
+
743
  @property
744
  def guidance_scale(self):
745
  return self._guidance_scale
 
770
  prompt: Union[str, List[str]] = None,
771
  prompt_2: Optional[Union[str, List[str]]] = None,
772
  prompt_3: Optional[Union[str, List[str]]] = None,
 
 
773
  image: PipelineImageInput = None,
774
  mask_image: PipelineImageInput = None,
775
  masked_image_latents: PipelineImageInput = None,
776
+ height: int = None,
777
+ width: int = None,
778
+ padding_mask_crop: Optional[int] = None,
779
  strength: float = 0.6,
780
  num_inference_steps: int = 50,
781
  timesteps: List[int] = None,
 
784
  negative_prompt_2: Optional[Union[str, List[str]]] = None,
785
  negative_prompt_3: Optional[Union[str, List[str]]] = None,
786
  num_images_per_prompt: Optional[int] = 1,
 
787
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
788
+ latents: Optional[torch.Tensor] = None,
789
+ prompt_embeds: Optional[torch.Tensor] = None,
790
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
791
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
792
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
793
  output_type: Optional[str] = "pil",
794
  return_dict: bool = True,
795
  clip_skip: Optional[int] = None,
 
810
  prompt_3 (`str` or `List[str]`, *optional*):
811
  The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
812
  will be used instead
813
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
814
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
815
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
816
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
817
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
818
+ latents as `image`, but if passing latents directly it is not encoded again.
819
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
820
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
821
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
822
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
823
+ color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
824
+ H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
825
+ 1)`, or `(H, W)`.
826
+ mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
827
+ `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
828
+ latents tensor will ge generated by `mask_image`.
829
  height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
830
  The height in pixels of the generated image. This is set to 1024 by default for the best results.
831
  width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
832
  The width in pixels of the generated image. This is set to 1024 by default for the best results.
833
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
834
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
835
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
836
+ with the same aspect ration of the image and contains all masked area, and then expand that area based
837
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
838
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
839
+ the image is large and contain information irrelevant for inpainting, such as background.
840
+ strength (`float`, *optional*, defaults to 1.0):
841
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
842
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
843
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
844
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
845
+ essentially ignores `image`.
846
  num_inference_steps (`int`, *optional*, defaults to 50):
847
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
848
  expense of slower inference.
 
868
  `text_encoder_3`. If not defined, `negative_prompt` is used instead
869
  num_images_per_prompt (`int`, *optional*, defaults to 1):
870
  The number of images to generate per prompt.
 
 
 
871
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
872
  One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
873
  to make generation deterministic.
 
914
  `tuple`. When returning a tuple, the first element is a list with the generated images.
915
  """
916
 
917
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
918
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
919
+
920
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
921
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
922
+
923
  # 1. Check inputs. Raise error if not correct
924
  self.check_inputs(
925
  prompt,
 
978
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
979
  pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
980
 
981
+ # 3. Prepare timesteps
 
 
 
982
  timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
983
  timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
984
+ # check that number of inference steps is not < 1 - as this doesn't make sense
985
+ if num_inference_steps < 1:
986
+ raise ValueError(
987
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
988
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
989
+ )
990
  latent_timestep = timesteps[:1].repeat(batch_size * num_inference_steps)
991
 
992
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
993
+ is_strength_max = strength == 1.0
994
+
995
+ # 4. Preprocess mask and image
996
+ if padding_mask_crop is not None:
997
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
998
+ resize_mode = "fill"
999
+ else:
1000
+ crops_coords = None
1001
+ resize_mode = "default"
1002
+
1003
+ original_image = image
1004
+ init_image = self.image_processor.preprocess(
1005
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
1006
+ )
1007
+ init_image = init_image.to(dtype=torch.float32)
1008
+
1009
  # 5. Prepare latent variables
1010
+ num_channels_latents = self.vae.config.latent_channels
1011
+ num_channels_transformer = self.transformer.config.in_channels
1012
+ return_image_latents = num_channels_transformer == 16
1013
+
1014
+ latents_outputs = self.prepare_latents(
1015
+ batch_size * num_images_per_prompt,
1016
+ num_channels_latents,
1017
+ height,
1018
+ width,
1019
+ prompt_embeds.dtype,
1020
+ device,
1021
+ generator,
1022
+ latents,
1023
+ image=init_image,
1024
+ timestep=latent_timestep,
1025
+ is_strength_max=is_strength_max,
1026
+ return_noise=True,
1027
+ return_image_latents=return_image_latents,
1028
+ )
1029
+
1030
+ if return_image_latents:
1031
+ latents, noise, image_latents = latents_outputs
1032
+ else:
1033
+ latents, noise = latents_outputs
1034
+
1035
+ # 6. Prepare mask latent variables
1036
+ mask_condition = self.mask_processor.preprocess(
1037
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
1038
+ )
1039
+
1040
  if masked_image_latents is None:
1041
+ masked_image = init_image * (mask_condition < 0.5)
1042
  else:
1043
  masked_image = masked_image_latents
1044
+
1045
  mask, masked_image_latents = self.prepare_mask_latents(
1046
  mask_condition,
1047
  masked_image,
 
1051
  width,
1052
  prompt_embeds.dtype,
1053
  device,
1054
+ generator,
1055
+ self.do_classifier_free_guidance,
1056
  )
1057
 
1058
+ # match the inpainting pipeline and will be updated with input + mask inpainting model later
1059
+ if num_channels_transformer == 33:
1060
+ # default case for runwayml/stable-diffusion-inpainting
1061
+ num_channels_mask = mask.shape[1]
1062
+ num_channels_masked_image = masked_image_latents.shape[1]
1063
+ if (
1064
+ num_channels_latents + num_channels_mask + num_channels_masked_image
1065
+ != self.transformer.config.in_channels
1066
+ ):
1067
+ raise ValueError(
1068
+ f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
1069
+ f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
1070
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
1071
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
1072
+ " `pipeline.transformer` or your `mask_image` or `image` input."
1073
+ )
1074
+ elif num_channels_transformer != 16:
1075
+ raise ValueError(
1076
+ f"The transformer {self.transformer.__class__} should have 16 input channels or 33 input channels, not {self.transformer.config.in_channels}."
1077
+ )
1078
+
1079
+ # 7. Denoising loop
1080
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1081
  self._num_timesteps = len(timesteps)
1082
  with self.progress_bar(total=num_inference_steps) as progress_bar:
 
1088
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1089
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1090
  timestep = t.expand(latent_model_input.shape[0])
1091
+
1092
+ if num_channels_transformer == 33:
1093
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
1094
+
1095
  noise_pred = self.transformer(
1096
  hidden_states=latent_model_input,
1097
  timestep=timestep,
 
1108
  # compute the previous noisy sample x_t -> x_t-1
1109
  latents_dtype = latents.dtype
1110
  latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1111
+ if num_channels_transformer == 16:
1112
+ init_latents_proper = image_latents
1113
+ if self.do_classifier_free_guidance:
1114
+ init_mask, _ = mask.chunk(2)
1115
+ else:
1116
+ init_mask = mask
1117
+
1118
+ if i < len(timesteps) - 1:
1119
+ noise_timestep = timesteps[i + 1]
1120
+ init_latents_proper = self.scheduler.scale_noise(
1121
+ init_latents_proper, torch.tensor([noise_timestep]), noise
1122
+ )
1123
+
1124
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1125
 
1126
  if latents.dtype != latents_dtype:
1127
  if torch.backends.mps.is_available():
 
1140
  negative_pooled_prompt_embeds = callback_outputs.pop(
1141
  "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1142
  )
1143
+ mask = callback_outputs.pop("mask", mask)
1144
+ masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
 
 
 
 
 
 
 
1145
 
1146
  # call the callback, if provided
1147
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
 
1150
  if XLA_AVAILABLE:
1151
  xm.mark_step()
1152
 
1153
+ if not output_type == "latent":
1154
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1155
+ 0
1156
+ ]
1157
+ else:
1158
  image = latents
1159
 
1160
+ do_denormalize = [True] * image.shape[0]
1161
+
1162
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1163
 
1164
+ if padding_mask_crop is not None:
1165
+ image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
1166
 
1167
  # Offload all models
1168
  self.maybe_free_model_hooks()
 
1170
  if not return_dict:
1171
  return (image,)
1172
 
1173
+ return StableDiffusion3PipelineOutput(images=image)