Commit c95c83d by shachargluska (1 parent: 157e2ba)

Add SD1.5 example to README.md

Thank you for your work!
I was curious to see the decoder at work, so I wrote a small example of running it with the SD1.5 adapter.
I hope that's alright with you.


![image.png](https://cdn-uploads.huggingface.co/production/uploads/668cd89caf57d6e4b5442719/8AJ8fVwwTFfFIca-wUyVx.png)

Files changed (1)
  1. README.md +44 -0
README.md CHANGED
@@ -20,6 +20,50 @@ on real images. Plus it is MIT licensed so you can do whatever you want with it.
  ### Compare
  Check out the comparison at [imgsli](https://imgsli.com/Mjc2MjA3).
 
+ ### Use with SD1.5 (Diffusers)
+ ```py
+ import torch
+ from diffusers import AutoencoderKL, StableDiffusionPipeline
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+
+ model_id = "runwayml/stable-diffusion-v1-5"
+ decoder_id = "ostris/vae-kl-f8-d16"
+ adapter_id = "ostris/16ch-VAE-Adapters"
+ adapter_ckpt = "16ch-VAE-Adapter-SD15-alpha.safetensors"
+ dtype = torch.float16
+
+ vae = AutoencoderKL.from_pretrained(decoder_id, torch_dtype=dtype)  # 16-channel VAE
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, vae=vae, torch_dtype=dtype)
+
+ ckpt_file = hf_hub_download(adapter_id, adapter_ckpt)  # download the adapter checkpoint
+ ckpt = load_file(ckpt_file)
+
+ lora_state_dict = {k: v for k, v in ckpt.items() if "lora" in k}
+ unet_state_dict = {k.replace("unet_", ""): v for k, v in ckpt.items() if "unet_" in k}
+
+ pipe.unet.conv_in = torch.nn.Conv2d(16, 320, 3, 1, 1)  # accept 16 latent channels
+ pipe.unet.conv_out = torch.nn.Conv2d(320, 16, 3, 1, 1)  # emit 16 latent channels
+ pipe.unet.load_state_dict(unet_state_dict, strict=False)  # strict=False: load only the keys the adapter provides
+ pipe.unet.conv_in.to(dtype)
+ pipe.unet.conv_out.to(dtype)
+ pipe.unet.config.in_channels = 16
+ pipe.unet.config.out_channels = 16
+
+ pipe.load_lora_weights(lora_state_dict)
+ pipe.fuse_lora()
+
+ pipe = pipe.to("cuda")
+ prompt = "a photo of an astronaut riding a horse on mars"
+ negative_prompt = (
+ "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, "
+ "extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, "
+ "cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face"
+ )
+ image = pipe(prompt, negative_prompt=negative_prompt).images[0]
+
+ image.save("astronaut_rides_horse.png")
+ ```
 
  ### What do I do with this?
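
A note on the checkpoint-splitting step in the example: the filter `"unet_" in k` matches the substring anywhere in a key, and `k.replace("unet_", "")` removes every occurrence, so a LoRA key that happens to contain `unet_` would also land in `unet_state_dict` under a rewritten name. With `strict=False` such extra keys are simply ignored by `load_state_dict`, so the README example should still work. The sketch below is not part of this commit; it only illustrates a stricter split under assumptions. `split_adapter_checkpoint` is a hypothetical helper, and the key names in `toy_ckpt` are made up for illustration, since the adapter file's real key layout is not shown here.

```python
def split_adapter_checkpoint(ckpt, unet_prefix="unet_"):
    """Partition a flat checkpoint dict into LoRA entries and plain UNet entries.

    Assumption: LoRA keys contain "lora", and plain UNet keys start with
    `unet_prefix`. The real adapter file may use a different layout.
    """
    lora, unet = {}, {}
    for key, value in ckpt.items():
        if "lora" in key:
            # Keep LoRA keys untouched so a LoRA loader sees the original names.
            lora[key] = value
        elif key.startswith(unet_prefix):
            # Strip only the leading prefix, not every occurrence of it.
            unet[key[len(unet_prefix):]] = value
    return lora, unet


# Hypothetical key names; strings stand in for tensors.
toy_ckpt = {
    "unet_conv_in.weight": "w_in",
    "unet_conv_out.weight": "w_out",
    "lora_unet_down_blocks_0.lora_down.weight": "l0",
}

lora_sd, unet_sd = split_adapter_checkpoint(toy_ckpt)
print(sorted(unet_sd))  # ['conv_in.weight', 'conv_out.weight']
print(sorted(lora_sd))  # ['lora_unet_down_blocks_0.lora_down.weight']
```

If the real checkpoint follows this layout, the two returned dicts could be passed to `pipe.unet.load_state_dict(..., strict=False)` and `pipe.load_lora_weights(...)` exactly as in the README example.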