ljp commited on
Commit
b10061f
1 Parent(s): 9bc8f3c

Create demo.py

Browse files
Files changed (1) hide show
  1. demo.py +53 -0
demo.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers.utils import load_image, check_min_version
2
+ import torch
3
+
4
+ # Local File
5
+ from pipeline_sd3_controlnet_inpainting import StableDiffusion3ControlNetInpaintingPipeline, one_image_and_mask
6
+ from controlnet_sd3 import SD3ControlNetModel
7
+
8
+ check_min_version("0.29.2")
9
+
10
+ # Build model
11
+ controlnet = SD3ControlNetModel.from_pretrained(
12
+ "alimama-creative/SD3-controlnet-inpaint",
13
+ use_safetensors=True,
14
+ )
15
+ pipe = StableDiffusion3ControlNetInpaintingPipeline.from_pretrained(
16
+ "stabilityai/stable-diffusion-3-medium-diffusers",
17
+ controlnet=controlnet,
18
+ torch_dtype=torch.float16,
19
+ )
20
+ pipe.text_encoder.to(torch.float16)
21
+ pipe.controlnet.to(torch.float16)
22
+ pipe.to("cuda")
23
+
24
+ # Load image
25
+ image = load_image(
26
+ "https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting/blob/main/prod.png"
27
+ )
28
+ mask = load_image(
29
+ "https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting/blob/main/mask.jpeg"
30
+ )
31
+
32
+ # Set args
33
+ width = 1024
34
+ height = 1024
35
+ prompt="a woman wearing a white jacket, black hat and black pants is standing in a field, the hat writes SD3"
36
+ generator = torch.Generator(device="cuda").manual_seed(24)
37
+ input_dict = one_image_and_mask(image, mask, size=(width, height), latent_scale=pipe.vae_scale_factor, invert_mask = True)
38
+
39
+ # Inference
40
+ res_image = pipe(
41
+ negative_prompt='deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW',
42
+ prompt=prompt,
43
+ height=height,
44
+ width=width,
45
+ control_image= input_dict['pil_masked_image'], # H, W, C,
46
+ control_mask=input_dict["mask"] > 0.5, # B,1,H,W
47
+ num_inference_steps=28,
48
+ generator=generator,
49
+ controlnet_conditioning_scale=0.95,
50
+ guidance_scale=7,
51
+ ).images[0]
52
+
53
+ res_image.save(f'res.png')