pcuenq HF staff committed on
Commit
e28c01f
1 Parent(s): 5664f93
Files changed (2)
  1. app.py +328 -0
  2. modules.py +178 -0
app.py ADDED
@@ -0,0 +1,328 @@
import gradio as gr
import open_clip
import torch
from PIL import Image
from open_clip import tokenizer
from rudalle import get_vae
from einops import rearrange
from modules import DenoiseUNet

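# Demo settings: checkpoint path, images generated per prompt, renoising steps
# (passed to `sample` as renoise_steps within its T=12 schedule), and the
# classifier-free guidance scale.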
model_id = "./model_600000.pt"
device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 4
steps = 11
scale = 5


def to_pil(images):
    images = images.permute(0, 2, 3, 1).cpu().numpy()
    images = (images * 255).round().astype("uint8")
    images = [Image.fromarray(image) for image in images]
    return images

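# Gumbel-max sampling helpers: adding Gumbel(0, 1) noise to temperature-scaled
# logits and taking the argmax is equivalent to sampling from
# softmax(logits / temperature).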
def log(t, eps=1e-20):
    return torch.log(t + eps)

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature=1., dim=-1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)

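# Iterative token sampler. Starting (by default) from a grid of random VQGAN
# token indices, each step predicts logits for every position conditioned on
# the CLIP embedding `c` and the current noise level, optionally applies
# classifier-free guidance and typical filtering, redraws tokens with
# Gumbel-max sampling, and re-noises the grid for the next step. Returns the
# intermediate prediction of every step.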
def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0], typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=-1, renoise_steps=11, renoise_mode='start'):
    with torch.inference_mode():
        r_range = torch.linspace(0, 1, T+1)[:-1][:, None].expand(-1, c.size(0)).to(c.device)
        temperatures = torch.linspace(temp_range[0], temp_range[1], T)
        preds = []
        if x is None:
            x = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)
        elif mask is not None:
            noise = torch.randint(0, model.num_labels, size=(c.size(0), *size), device=c.device)
            x = noise * mask + (1-mask) * x
        init_x = x.clone()
        for i in range(starting_t, T):
            if renoise_mode == 'prev':
                prev_x = x.clone()
            r, temp = r_range[i], temperatures[i]
            logits = model(x, c, r)
            if classifier_free_scale >= 0:
                logits_uncond = model(x, torch.zeros_like(c), r)
                logits = torch.lerp(logits_uncond, logits, classifier_free_scale)
            x = logits
            x_flat = x.permute(0, 2, 3, 1).reshape(-1, x.size(1))
            if typical_filtering:
                x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
                x_flat_norm_p = torch.exp(x_flat_norm)
                entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)

                c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
                c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
                x_flat_cumsum = x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)

                last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
                sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(1, last_ind.view(-1, 1))
                if typical_min_tokens > 1:
                    sorted_indices_to_remove[..., :typical_min_tokens] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, x_flat_indices, sorted_indices_to_remove)
                x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
            # x_flat = torch.multinomial(x_flat.div(temp).softmax(-1), num_samples=1)[:, 0]
            x_flat = gumbel_sample(x_flat, temperature=temp)
            x = x_flat.view(x.size(0), *x.shape[2:])
            if mask is not None:
                x = x * mask + (1-mask) * init_x
            if i < renoise_steps:
                if renoise_mode == 'start':
                    x, _ = model.add_noise(x, r_range[i+1], random_x=init_x)
                elif renoise_mode == 'prev':
                    x, _ = model.add_noise(x, r_range[i+1], random_x=prev_x)
                else:  # 'rand'
                    x, _ = model.add_noise(x, r_range[i+1])
            preds.append(x.detach())
    return preds

# Model loading

vqmodel = get_vae().to(device)
vqmodel.eval().requires_grad_(False)

clip_model, _, _ = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k')
clip_model = clip_model.to(device).eval().requires_grad_(False)

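# VQGAN helpers built on the rudalle VAE loaded above: `encode` maps images in
# [0, 1] to a grid of discrete token indices, and `decode` maps token indices
# back to an RGB image in [0, 1] via the codebook embeddings.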
def encode(x):
    return vqmodel.model.encode((2 * x - 1))[-1][-1]

def decode(img_seq, shape=(32,32)):
    img_seq = img_seq.view(img_seq.shape[0], -1)
    b, n = img_seq.shape
    one_hot_indices = torch.nn.functional.one_hot(img_seq, num_classes=vqmodel.num_tokens).float()
    z = (one_hot_indices @ vqmodel.model.quantize.embed.weight)
    z = rearrange(z, 'b (h w) c -> b c h w', h=shape[0], w=shape[1])
    img = vqmodel.model.decode(z)
    img = (img.clamp(-1., 1.) + 1) * 0.5
    return img

state_dict = torch.load(model_id, map_location=device)
model = DenoiseUNet(num_labels=8192).to(device)
model.load_state_dict(state_dict)
model.eval().requires_grad_(False)  # freeze, like the VQGAN and CLIP models above

# -----

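# End-to-end generation for the Gradio UI: tokenize the prompt, encode it with
# CLIP, sample a 32x32 grid of VQGAN tokens, decode the last prediction, and
# convert the batch to PIL images.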
def infer(prompt):
    latent_shape = (32, 32)
    tokenized_text = tokenizer.tokenize([prompt] * batch_size).to(device)
    with torch.inference_mode():
        with torch.autocast(device_type="cuda"):
            clip_embeddings = clip_model.encode_text(tokenized_text)
            images = sample(
                model, clip_embeddings, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],
                typical_filtering=True, typical_mass=0.2, typical_min_tokens=1,
                classifier_free_scale=scale, renoise_steps=steps, renoise_mode="start"
            )
            images = decode(images[-1], latent_shape)
    return to_pil(images)

css = """
        .gradio-container {
            font-family: 'IBM Plex Sans', sans-serif;
        }
        .gr-button {
            color: white;
            border-color: black;
            background: black;
        }
        input[type='range'] {
            accent-color: black;
        }
        .dark input[type='range'] {
            accent-color: #dfdfdf;
        }
        .container {
            max-width: 730px;
            margin: auto;
            padding-top: 1.5rem;
        }
        #gallery {
            min-height: 22rem;
            margin-bottom: 15px;
            margin-left: auto;
            margin-right: auto;
            border-bottom-right-radius: .5rem !important;
            border-bottom-left-radius: .5rem !important;
        }
        #gallery>div>.h-full {
            min-height: 20rem;
        }
        .details:hover {
            text-decoration: underline;
        }
        .gr-button {
            white-space: nowrap;
        }
        .gr-button:focus {
            border-color: rgb(147 197 253 / var(--tw-border-opacity));
            outline: none;
            box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
            --tw-border-opacity: 1;
            --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
            --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
            --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
            --tw-ring-opacity: .5;
        }
        .footer {
            margin-bottom: 45px;
            margin-top: 35px;
            text-align: center;
            border-bottom: 1px solid #e5e5e5;
        }
        .footer>p {
            font-size: .8rem;
            display: inline-block;
            padding: 0 10px;
            transform: translateY(10px);
            background: white;
        }
        .dark .footer {
            border-color: #303030;
        }
        .dark .footer>p {
            background: #0b0f19;
        }
        .acknowledgments h4 {
            margin: 1.25em 0 .25em 0;
            font-weight: bold;
            font-size: 115%;
        }
        .animate-spin {
            animation: spin 1s linear infinite;
        }
        @keyframes spin {
            from {
                transform: rotate(0deg);
            }
            to {
                transform: rotate(360deg);
            }
        }
        #share-btn-container {
            display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
        }
        #share-btn {
            all: initial; color: #ffffff; font-weight: 600; cursor: pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
        }
        #share-btn * {
            all: unset;
        }
        .gr-form {
            flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
        }
        #prompt-container {
            gap: 0;
        }
"""

block = gr.Blocks(css=css)

with block:
    gr.HTML(
        """
            <div style="text-align: center; max-width: 650px; margin: 0 auto;">
              <div
                style="
                  display: inline-flex;
                  align-items: center;
                  gap: 0.8rem;
                  font-size: 1.75rem;
                "
              >
                <svg
                  width="0.65em"
                  height="0.65em"
                  viewBox="0 0 115 115"
                  fill="none"
                  xmlns="http://www.w3.org/2000/svg"
                >
                  <rect width="23" height="23" fill="white"></rect>
                  <rect y="69" width="23" height="23" fill="white"></rect>
                  <rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="46" width="23" height="23" fill="white"></rect>
                  <rect x="46" y="69" width="23" height="23" fill="white"></rect>
                  <rect x="69" width="23" height="23" fill="black"></rect>
                  <rect x="69" y="69" width="23" height="23" fill="black"></rect>
                  <rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
                  <rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="115" y="46" width="23" height="23" fill="white"></rect>
                  <rect x="115" y="115" width="23" height="23" fill="white"></rect>
                  <rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
                  <rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="92" y="69" width="23" height="23" fill="white"></rect>
                  <rect x="69" y="46" width="23" height="23" fill="white"></rect>
                  <rect x="69" y="115" width="23" height="23" fill="white"></rect>
                  <rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
                  <rect x="46" y="46" width="23" height="23" fill="black"></rect>
                  <rect x="46" y="115" width="23" height="23" fill="black"></rect>
                  <rect x="46" y="69" width="23" height="23" fill="black"></rect>
                  <rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
                  <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="23" y="69" width="23" height="23" fill="black"></rect>
                </svg>
                <h1 style="font-weight: 900; margin-bottom: 7px;">
                  Paella Demo
                </h1>
              </div>
              <p style="margin-bottom: 10px; font-size: 94%">
                Paella is a novel text-to-image model that uses a compressed quantized latent space, based on an f8 VQGAN, and a masked training objective to achieve fast generation in ~10 inference steps.
              </p>
            </div>
        """
    )
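    # For reference, the "masked training objective" mentioned in the
    # description above boils down to the following sketch (illustration only,
    # not part of this demo; `x0` would be a batch of ground-truth VQGAN token
    # grids, `c` the CLIP embeddings, and `r` a per-sample noise level drawn
    # uniformly from [0, 1]):
    #
    #     noised, _ = model.add_noise(x0, r)                    # randomize a subset of tokens
    #     logits = model(noised, c, r)                          # predict the original tokens
    #     loss = torch.nn.functional.cross_entropy(logits, x0)  # per-token classification loss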
    with gr.Group():
        with gr.Box():
            with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                    full_width=False,
                )

    gallery = gr.Gallery(
        label="Generated images", show_label=False, elem_id="gallery"
    ).style(grid=[2], height="auto")

    text.submit(infer, inputs=text, outputs=gallery)
    btn.click(infer, inputs=text, outputs=gallery)

    gr.HTML(
        """
            <div class="footer">
            </div>
            <div class="acknowledgments">
                <p><h4>Resources</h4>
                <a href="https://arxiv.org/abs/2211.07292" style="text-decoration: underline;">Paper</a>, <a href="https://github.com/dome272/Paella" style="text-decoration: underline;">official implementation</a>.
                </p>
                <p><h4>LICENSE</h4>
                <a href="https://github.com/dome272/Paella/blob/main/LICENSE" style="text-decoration: underline;">MIT</a>.
                </p>
                <p><h4>Biases and content acknowledgment</h4>
                As impressive as turning text into images is, be aware that this model may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography and violence. The model was trained on 600 million images from the improved <a href="https://laion.ai/blog/laion-5b/" style="text-decoration: underline;" target="_blank">LAION-5B aesthetic</a> dataset, which scraped non-curated image-text pairs from the internet (the exception being the removal of illegal content), and is meant for research purposes.
                </p>
            </div>
        """
    )

block.launch()
modules.py ADDED
@@ -0,0 +1,178 @@
import math
import numpy as np
import torch
import torch.nn as nn


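# LayerNorm whose output can be modulated by an external conditioning signal
# `w`: when `w` is provided, out = gamma * w * LN(x) + beta * w.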
class ModulatedLayerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-6, channels_first=True):
        super().__init__()
        self.ln = nn.LayerNorm(num_features, eps=eps)
        self.gamma = nn.Parameter(torch.randn(1, 1, 1))
        self.beta = nn.Parameter(torch.randn(1, 1, 1))
        self.channels_first = channels_first

    def forward(self, x, w=None):
        x = x.permute(0, 2, 3, 1) if self.channels_first else x
        if w is None:
            x = self.ln(x)
        else:
            x = self.gamma * w * self.ln(x) + self.beta * w
        x = x.permute(0, 3, 1, 2) if self.channels_first else x
        return x


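# ConvNeXt-style residual block: depthwise 3x3 convolution, a LayerNorm
# modulated by the conditioning signal, and a channelwise MLP, with support for
# an optional skip input, layer scaling, and an optional `scaler` module
# applied to the output.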
class ResBlock(nn.Module):
    def __init__(self, c, c_hidden, c_cond=0, c_skip=0, scaler=None, layer_scale_init_value=1e-6):
        super().__init__()
        self.depthwise = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(c, c, kernel_size=3, groups=c)
        )
        self.ln = ModulatedLayerNorm(c, channels_first=False)
        self.channelwise = nn.Sequential(
            nn.Linear(c + c_skip, c_hidden),
            nn.GELU(),
            nn.Linear(c_hidden, c),
        )
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(c), requires_grad=True) if layer_scale_init_value > 0 else None
        self.scaler = scaler
        if c_cond > 0:
            self.cond_mapper = nn.Linear(c_cond, c)

    def forward(self, x, s=None, skip=None):
        res = x
        x = self.depthwise(x)
        if s is not None:
            if s.size(2) == s.size(3) == 1:
                s = s.expand(-1, -1, x.size(2), x.size(3))
            elif s.size(2) != x.size(2) or s.size(3) != x.size(3):
                s = nn.functional.interpolate(s, size=x.shape[-2:], mode='bilinear')
            s = self.cond_mapper(s.permute(0, 2, 3, 1))
            # s = self.cond_mapper(s.permute(0, 2, 3, 1))
            # if s.size(1) == s.size(2) == 1:
            #     s = s.expand(-1, x.size(2), x.size(3), -1)
        x = self.ln(x.permute(0, 2, 3, 1), s)
        if skip is not None:
            x = torch.cat([x, skip.permute(0, 2, 3, 1)], dim=-1)
        x = self.channelwise(x)
        x = self.gamma * x if self.gamma is not None else x
        x = res + x.permute(0, 3, 1, 2)
        if self.scaler is not None:
            x = self.scaler(x)
        return x


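# Conditional U-Net over discrete token grids: tokens are embedded, processed
# by down/up levels of conditional ResBlocks (modulated by the concatenated
# CLIP embedding and noise-level embedding), and projected back to per-token
# logits over the `num_labels` codebook entries.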
class DenoiseUNet(nn.Module):
    def __init__(self, num_labels, c_hidden=1280, c_clip=1024, c_r=64, down_levels=[4, 8, 16], up_levels=[16, 8, 4]):
        super().__init__()
        self.num_labels = num_labels
        self.c_r = c_r
        self.down_levels = down_levels
        self.up_levels = up_levels
        c_levels = [c_hidden // (2 ** i) for i in reversed(range(len(down_levels)))]
        self.embedding = nn.Embedding(num_labels, c_levels[0])

        # DOWN BLOCKS
        self.down_blocks = nn.ModuleList()
        for i, num_blocks in enumerate(down_levels):
            blocks = []
            if i > 0:
                blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            for _ in range(num_blocks):
                block = ResBlock(c_levels[i], c_levels[i] * 4, c_clip + c_r)
                block.channelwise[-1].weight.data *= np.sqrt(1 / sum(down_levels))
                blocks.append(block)
            self.down_blocks.append(nn.ModuleList(blocks))

        # UP BLOCKS
        self.up_blocks = nn.ModuleList()
        for i, num_blocks in enumerate(up_levels):
            blocks = []
            for j in range(num_blocks):
                block = ResBlock(c_levels[len(c_levels) - 1 - i], c_levels[len(c_levels) - 1 - i] * 4, c_clip + c_r,
                                 c_levels[len(c_levels) - 1 - i] if (j == 0 and i > 0) else 0)
                block.channelwise[-1].weight.data *= np.sqrt(1 / sum(up_levels))
                blocks.append(block)
            if i < len(up_levels) - 1:
                blocks.append(
                    nn.ConvTranspose2d(c_levels[len(c_levels) - 1 - i], c_levels[len(c_levels) - 2 - i], kernel_size=4, stride=2, padding=1))
            self.up_blocks.append(nn.ModuleList(blocks))

        self.clf = nn.Conv2d(c_levels[0], num_labels, kernel_size=1)

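    # Noise schedule and discrete noising: gamma(r) = cos(r * pi / 2), and
    # `add_noise` independently replaces each token with a random one with
    # probability gamma(r), so r=0 means fully noised and r=1 leaves x unchanged.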
    def gamma(self, r):
        return (r * torch.pi / 2).cos()

    def add_noise(self, x, r, random_x=None):
        r = self.gamma(r)[:, None, None]
        mask = torch.bernoulli(r * torch.ones_like(x))
        mask = mask.round().long()
        if random_x is None:
            random_x = torch.randint_like(x, 0, self.num_labels)
        x = x * (1 - mask) + random_x * mask
        return x, mask

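    # Sinusoidal embedding of the (gamma-warped) noise level r, analogous to
    # transformer positional encodings, producing a c_r-dimensional vector.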
    def gen_r_embedding(self, r, max_positions=10000):
        dtype = r.dtype
        r = self.gamma(r) * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb.to(dtype)

    def _down_encode_(self, x, s):
        level_outputs = []
        for i, blocks in enumerate(self.down_blocks):
            for block in blocks:
                if isinstance(block, ResBlock):
                    # s_level = s[:, 0]
                    # s = s[:, 1:]
                    x = block(x, s)
                else:
                    x = block(x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, s):
        x = level_outputs[0]
        for i, blocks in enumerate(self.up_blocks):
            for j, block in enumerate(blocks):
                if isinstance(block, ResBlock):
                    # s_level = s[:, 0]
                    # s = s[:, 1:]
                    if i > 0 and j == 0:
                        x = block(x, s, level_outputs[i])
                    else:
                        x = block(x, s)
                else:
                    x = block(x)
        return x

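    # x: (B, H, W) token indices; c: (B, c_clip) CLIP embedding (or a spatial
    # (B, c_clip, H', W') conditioning map); returns (B, num_labels, H, W) logits.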
    def forward(self, x, c, r):  # r is a uniform value between 0 and 1
        r_embed = self.gen_r_embedding(r)
        x = self.embedding(x).permute(0, 3, 1, 2)
        if len(c.shape) == 2:
            s = torch.cat([c, r_embed], dim=-1)[:, :, None, None]
        else:
            r_embed = r_embed[:, :, None, None].expand(-1, -1, c.size(2), c.size(3))
            s = torch.cat([c, r_embed], dim=1)
        level_outputs = self._down_encode_(x, s)
        x = self._up_decode(level_outputs, s)
        x = self.clf(x)
        return x


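# Quick smoke test with random inputs (expects a CUDA device).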
if __name__ == '__main__':
    device = "cuda"
    model = DenoiseUNet(1024).to(device)
    print(sum([p.numel() for p in model.parameters()]))
    x = torch.randint(0, 1024, (1, 32, 32)).long().to(device)
    c = torch.randn((1, 1024)).to(device)
    r = torch.rand(1).to(device)
    model(x, c, r)