pcuenq (HF staff) committed
Commit cab8a49
1 Parent(s): 3875a6e

Add copy of github repo

Paella/src/modules.py ADDED
@@ -0,0 +1,283 @@
+import math
+import torch
+import numpy as np
+from torch import nn
+
+
+class Attention2D(nn.Module):
+    def __init__(self, c, nhead, dropout=0.0):
+        super().__init__()
+        self.attn = torch.nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
+
+    def forward(self, x, kv, self_attn=False):
+        orig_shape = x.shape
+        x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)
+        if self_attn:
+            kv = torch.cat([x, kv], dim=1)
+        x = self.attn(x, kv, kv, need_weights=False)[0]
+        x = x.permute(0, 2, 1).view(*orig_shape)
+        return x
+
+
+class LayerNorm2d(nn.LayerNorm):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x):
+        return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+
+class GlobalResponseNorm(nn.Module):
+    "Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
+    def __init__(self, dim):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
+        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+
+    def forward(self, x):
+        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+        return self.gamma * (x * Nx) + self.beta + x
+
+
+class ResBlock(nn.Module):
+    def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0):
+        super().__init__()
+        self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+        self.channelwise = nn.Sequential(
+            nn.Linear(c, c * 4),
+            nn.GELU(),
+            GlobalResponseNorm(c * 4),
+            nn.Dropout(dropout),
+            nn.Linear(c * 4, c)
+        )
+
+    def forward(self, x, x_skip=None):
+        x_res = x
+        if x_skip is not None:
+            x = torch.cat([x, x_skip], dim=1)
+        x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1)
+        x = self.channelwise(x).permute(0, 3, 1, 2)
+        return x + x_res
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
+        super().__init__()
+        self.self_attn = self_attn
+        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+        self.attention = Attention2D(c, nhead, dropout)
+        self.kv_mapper = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(c_cond, c)
+        )
+
+    def forward(self, x, kv):
+        kv = self.kv_mapper(kv)
+        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
+        return x
+
+
+class FeedForwardBlock(nn.Module):
+    def __init__(self, c, dropout=0.0):
+        super().__init__()
+        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+        self.channelwise = nn.Sequential(
+            nn.Linear(c, c * 4),
+            nn.GELU(),
+            GlobalResponseNorm(c * 4),
+            nn.Dropout(dropout),
+            nn.Linear(c * 4, c)
+        )
+
+    def forward(self, x):
+        x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        return x
+
+
+class TimestepBlock(nn.Module):
+    def __init__(self, c, c_timestep):
+        super().__init__()
+        self.mapper = nn.Linear(c_timestep, c * 2)
+
+    def forward(self, x, t):
+        a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
+        return x * (1 + a) + b
+
+
+class Paella(nn.Module):
+    def __init__(self, c_in=256, c_out=256, num_labels=8192, c_r=64, patch_size=2, c_cond=1024,
+                 c_hidden=[640, 1280, 1280], nhead=[-1, 16, 16], blocks=[6, 16, 6], level_config=['CT', 'CTA', 'CTA'],
+                 clip_embd=1024, byt5_embd=1536, clip_seq_len=4, kernel_size=3, dropout=0.1, self_attn=True):
+        super().__init__()
+        self.c_r = c_r
+        self.c_cond = c_cond
+        self.num_labels = num_labels
+        if not isinstance(dropout, list):
+            dropout = [dropout] * len(c_hidden)
+
+        # CONDITIONING
+        self.byt5_mapper = nn.Linear(byt5_embd, c_cond)
+        self.clip_mapper = nn.Linear(clip_embd, c_cond * clip_seq_len)
+        self.clip_image_mapper = nn.Linear(clip_embd, c_cond * clip_seq_len)
+        self.seq_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)
+
+        self.in_mapper = nn.Sequential(
+            nn.Embedding(num_labels, c_in),
+            nn.LayerNorm(c_in, elementwise_affine=False, eps=1e-6)
+        )
+        self.embedding = nn.Sequential(
+            nn.PixelUnshuffle(patch_size),
+            nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1),
+            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6)
+        )
+
+        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0):
+            if block_type == 'C':
+                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
+            elif block_type == 'A':
+                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
+            elif block_type == 'F':
+                return FeedForwardBlock(c_hidden, dropout=dropout)
+            elif block_type == 'T':
+                return TimestepBlock(c_hidden, c_r)
+            else:
+                raise Exception(f'Block type {block_type} not supported')
+
+        # DOWN BLOCKS
+        self.down_blocks = nn.ModuleList()
+        for i in range(len(c_hidden)):
+            down_block = nn.ModuleList()
+            if i > 0:
+                down_block.append(nn.Sequential(
+                    LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
+                    nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2),
+                ))
+            for _ in range(blocks[i]):
+                for block_type in level_config[i]:
+                    down_block.append(get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i]))
+            self.down_blocks.append(down_block)
+
+        # UP BLOCKS
+        self.up_blocks = nn.ModuleList()
+        for i in reversed(range(len(c_hidden))):
+            up_block = nn.ModuleList()
+            for j in range(blocks[i]):
+                for k, block_type in enumerate(level_config[i]):
+                    up_block.append(get_block(block_type, c_hidden[i], nhead[i],
+                                              c_skip=c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0,
+                                              dropout=dropout[i]))
+            if i > 0:
+                up_block.append(nn.Sequential(
+                    LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
+                    nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2),
+                ))
+            self.up_blocks.append(up_block)
+
+        # OUTPUT
+        self.clf = nn.Sequential(
+            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
+            nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1),
+            nn.PixelShuffle(patch_size),
+        )
+        self.out_mapper = nn.Sequential(
+            LayerNorm2d(c_out, elementwise_affine=False, eps=1e-6),
+            nn.Conv2d(c_out, num_labels, kernel_size=1, bias=False)
+        )
+
+        # --- WEIGHT INIT ---
+        self.apply(self._init_weights)  # General init
+        nn.init.normal_(self.byt5_mapper.weight, std=0.02)
+        nn.init.normal_(self.clip_mapper.weight, std=0.02)
+        nn.init.normal_(self.clip_image_mapper.weight, std=0.02)
+        torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)
+        nn.init.constant_(self.clf[1].weight, 0)
+        nn.init.normal_(self.in_mapper[0].weight, std=np.sqrt(1 / num_labels))
+        self.out_mapper[-1].weight.data = self.in_mapper[0].weight.data[:, :, None, None].clone()
+
+        for level_block in self.down_blocks + self.up_blocks:
+            for block in level_block:
+                if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
+                    block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks))
+                elif isinstance(block, TimestepBlock):
+                    nn.init.constant_(block.mapper.weight, 0)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv2d, nn.Linear)):
+            torch.nn.init.xavier_uniform_(m.weight)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def gen_r_embedding(self, r, max_positions=10000):
+        r = r * max_positions
+        half_dim = self.c_r // 2
+        emb = math.log(max_positions) / (half_dim - 1)
+        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
+        emb = r[:, None] * emb[None, :]
+        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
+        if self.c_r % 2 == 1:
+            emb = nn.functional.pad(emb, (0, 1), mode='constant')
+        return emb
+
+    def gen_c_embeddings(self, byt5, clip, clip_image):
+        seq = self.byt5_mapper(byt5)
+        if clip is not None:
+            clip = self.clip_mapper(clip).view(clip.size(0), -1, self.c_cond)
+            seq = torch.cat([seq, clip], dim=1)
+        if clip_image is not None:
+            clip_image = self.clip_image_mapper(clip_image).view(clip_image.size(0), -1, self.c_cond)
+            seq = torch.cat([seq, clip_image], dim=1)
+        seq = self.seq_norm(seq)
+        return seq
+
+    def _down_encode(self, x, r_embed, c_embed):
+        level_outputs = []
+        for down_block in self.down_blocks:
+            for block in down_block:
+                if isinstance(block, ResBlock):
+                    x = block(x)
+                elif isinstance(block, AttnBlock):
+                    x = block(x, c_embed)
+                elif isinstance(block, TimestepBlock):
+                    x = block(x, r_embed)
+                else:
+                    x = block(x)
+            level_outputs.insert(0, x)
+        return level_outputs
+
+    def _up_decode(self, level_outputs, r_embed, c_embed):
+        x = level_outputs[0]
+        for i, up_block in enumerate(self.up_blocks):
+            for j, block in enumerate(up_block):
+                if isinstance(block, ResBlock):
+                    x = block(x, level_outputs[i] if j == 0 and i > 0 else None)
+                elif isinstance(block, AttnBlock):
+                    x = block(x, c_embed)
+                elif isinstance(block, TimestepBlock):
+                    x = block(x, r_embed)
+                else:
+                    x = block(x)
+        return x
+
+    def forward(self, x, r, byt5, clip=None, clip_image=None, x_cat=None):
+        if x_cat is not None:
+            x = torch.cat([x, x_cat], dim=1)
+        # Process the conditioning embeddings
+        r_embed = self.gen_r_embedding(r)
+        c_embed = self.gen_c_embeddings(byt5, clip, clip_image)
+
+        # Model Blocks
+        x = self.embedding(self.in_mapper(x).permute(0, 3, 1, 2))
+        level_outputs = self._down_encode(x, r_embed, c_embed)
+        x = self._up_decode(level_outputs, r_embed, c_embed)
+        x = self.out_mapper(self.clf(x))
+        return x
+
+    def add_noise(self, x, t, mask=None, random_x=None):
+        if mask is None:
+            mask = (torch.rand_like(x.float()) <= t[:, None, None]).long()
+        if random_x is None:
+            random_x = torch.randint_like(x, 0, self.num_labels)
+        x = x * (1 - mask) + random_x * mask
+        return x, mask
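
The file above defines the token-space denoiser. Below is a minimal usage sketch, assuming it is run next to Paella/src/modules.py with the default configuration; the shapes follow the forward() path above (in_mapper is an nn.Embedding, so the input is a LongTensor grid of codebook indices).

# Hedged usage sketch (assumptions: defaults num_labels=8192, byt5_embd=1536; dummy tensors only).
import torch
from modules import Paella

model = Paella()
tokens = torch.randint(0, 8192, (1, 32, 32))        # B x H x W grid of VQ token indices
t = torch.rand(1)                                    # noise level in [0, 1], one value per sample
byt5_embeddings = torch.randn(1, 77, 1536)           # B x seq_len x byt5_embd text conditioning

noised, mask = model.add_noise(tokens, t)            # replaces roughly a t-fraction of tokens at random
logits = model(noised, t, byt5_embeddings)           # B x num_labels x H x W
print(logits.shape)                                  # torch.Size([1, 8192, 32, 32])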
Paella/src/train.py ADDED
@@ -0,0 +1,80 @@
+import os
+import torch
+import numpy as np
+from tqdm import tqdm
+from modules import Paella
+from torch import nn, optim
+from warmup_scheduler import GradualWarmupScheduler
+from utils import get_dataloader, load_conditional_models
+
+steps = 100_000
+warmup_updates = 10000
+batch_size = 16
+checkpoint_frequency = 2000
+lr = 1e-4
+
+train_device = "cuda"
+dataset_path = ""
+byt5_model_name = "google/byt5-xl"
+vqmodel_path = ""
+run_name = "Paella-ByT5-XL-v1"
+output_path = "output"
+checkpoint_path = f"{run_name}.pt"
+
+
+def train():
+    os.makedirs(output_path, exist_ok=True)
+    device = torch.device(train_device)
+
+    dataloader = get_dataloader(dataset_path, batch_size=batch_size)
+    checkpoint = torch.load(checkpoint_path, map_location=device) if os.path.exists(checkpoint_path) else None
+
+    model = Paella(byt5_embd=2560).to(device)
+    vqgan, (byt5_tokenizer, byt5) = load_conditional_models(byt5_model_name, vqmodel_path, device)
+    optimizer = optim.AdamW(model.parameters(), lr=lr)
+    scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_updates)
+    criterion = nn.CrossEntropyLoss(label_smoothing=0.1, reduction='none')
+
+    start_iter = 1
+    if checkpoint is not None:
+        model.load_state_dict(checkpoint['state_dict'])
+        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        scheduler.last_epoch = checkpoint['scheduler_last_step']
+        start_iter = checkpoint['scheduler_last_step'] + 1
+        del checkpoint
+
+    pbar = tqdm(range(start_iter, steps + 1))
+    model.train()
+    for i, (images, captions) in enumerate(dataloader):
+        images = images.to(device)
+
+        with torch.no_grad():
+            if np.random.rand() < 0.05:  # drop the caption occasionally to learn the unconditional case
+                byt5_captions = [''] * len(captions)
+            else:
+                byt5_captions = captions
+            byt5_tokens = byt5_tokenizer(byt5_captions, padding="longest", return_tensors="pt", max_length=768, truncation=True).input_ids.to(device)
+            byt_embeddings = byt5(input_ids=byt5_tokens).last_hidden_state
+
+            t = (1 - torch.rand(images.size(0), device=device))
+            latents = vqgan.encode(images)[2]
+            noised_latents, _ = model.add_noise(latents, t)
+
+        pred = model(noised_latents, t, byt_embeddings)
+        loss = criterion(pred, latents).mean()  # reduction='none' keeps per-token losses, so reduce to a scalar before backward
+
+        loss.backward()
+        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+        optimizer.step()  # apply the optimizer update before advancing the warmup schedule
+        scheduler.step()
+        optimizer.zero_grad()
+
+        acc = (pred.argmax(1) == latents).float().mean()
+
+        pbar.set_postfix({'bs': images.size(0), 'loss': loss.item(), 'acc': acc.item(), 'grad_norm': grad_norm.item(), 'lr': optimizer.param_groups[0]['lr'], 'total_steps': scheduler.last_epoch})
+
+        if i % checkpoint_frequency == 0:
+            torch.save({'state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_last_step': scheduler.last_epoch, 'iter': i}, checkpoint_path)
+
+
+if __name__ == '__main__':
+    train()
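
The criterion above is built with reduction='none', which points at per-token weighting before the mean; the sketch below shows how the mask returned by add_noise could be combined with the get_loss_weight helper defined in Paella/utils/modules.py. This is a hedged assumption about the intended loop: the Paella class imported here (from src/modules.py) does not define that helper.

# Hypothetical per-token loss weighting inside the training loop above.
noised_latents, mask = model.add_noise(latents, t)          # keep the mask instead of discarding it
pred = model(noised_latents, t, byt_embeddings)
loss_weight = model.get_loss_weight(t, mask)                # assumes the helper from Paella/utils/modules.py
loss = (criterion(pred, latents) * loss_weight).mean()      # down-weights tokens that were never noised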
Paella/src/utils.py ADDED
@@ -0,0 +1,55 @@
+import torch
+import torchvision
+from vqgan import VQModel
+from torch.utils.data import Dataset, DataLoader
+from transformers import T5EncoderModel, AutoTokenizer
+
+transforms = torchvision.transforms.Compose([
+    torchvision.transforms.ToTensor(),
+    torchvision.transforms.Resize(256),
+    torchvision.transforms.RandomCrop(256),
+])
+
+
+class YOUR_DATASET(Dataset):
+    def __init__(self, dataset_path):
+        pass
+
+
+def get_dataloader(dataset_path, batch_size):
+    dataset = YOUR_DATASET(dataset_path)
+    return DataLoader(dataset, batch_size=batch_size, num_workers=8, pin_memory=True)
+
+
+def load_conditional_models(byt5_model_name, vqgan_path, device):
+    vqgan = VQModel().to(device)
+    vqgan.load_state_dict(torch.load(vqgan_path, map_location=device)['state_dict'])
+    vqgan.eval().requires_grad_(False)
+
+    byt5 = T5EncoderModel.from_pretrained(byt5_model_name).to(device).eval().requires_grad_(False)
+    byt5_tokenizer = AutoTokenizer.from_pretrained(byt5_model_name)
+
+    return vqgan, (byt5_tokenizer, byt5)
+
+
+def sample(model, model_inputs, latent_shape, unconditional_inputs=None, steps=12, renoise_steps=11, temperature=(1.0, 0.2), cfg=8.0, t_start=1.0, t_end=0.0, device="cuda"):
+    with torch.inference_mode():
+        sampled = torch.randint(0, model.num_labels, size=latent_shape, device=device)
+        init_noise = sampled.clone()
+        t_list = torch.linspace(t_start, t_end, steps + 1)
+        temperatures = torch.linspace(temperature[0], temperature[1], steps)
+        for i, t in enumerate(t_list[:steps]):
+            t = torch.ones(latent_shape[0], device=device) * t
+
+            logits = model(sampled, t, **model_inputs)
+            if cfg:
+                logits = logits * cfg + model(sampled, t, **unconditional_inputs) * (1 - cfg)
+            scores = logits.div(temperatures[i]).softmax(dim=1)
+
+            sampled = scores.permute(0, 2, 3, 1).reshape(-1, logits.size(1))
+            sampled = torch.multinomial(sampled, 1)[:, 0].view(logits.size(0), *logits.shape[2:])
+
+            if i < renoise_steps:
+                t_next = torch.ones(latent_shape[0], device=device) * t_list[i + 1]
+                sampled = model.add_noise(sampled, t_next, random_x=init_noise)[0]
+        return sampled
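
A hedged end-to-end inference sketch for sample(), mirroring the conditioning setup in train.py; the checkpoint paths, the prompt, the 64x64 latent grid and the empty-caption unconditional pass are assumptions.

# Hedged inference sketch (assumes a trained Paella checkpoint, a VQGAN checkpoint, and a CUDA device).
from modules import Paella

vqgan, (byt5_tokenizer, byt5) = load_conditional_models("google/byt5-xl", "vqgan.pt", "cuda")
model = Paella(byt5_embd=2560).to("cuda").eval()             # trained weights would be loaded here

cond_ids = byt5_tokenizer(["a plate of paella"], return_tensors="pt").input_ids.to("cuda")
uncond_ids = byt5_tokenizer([""], return_tensors="pt").input_ids.to("cuda")
cond = {"byt5": byt5(input_ids=cond_ids).last_hidden_state}
uncond = {"byt5": byt5(input_ids=uncond_ids).last_hidden_state}

latents = sample(model, cond, latent_shape=(1, 64, 64), unconditional_inputs=uncond)
images = vqgan.decode_indices(latents)                       # back to pixel space, roughly B x 3 x 256 x 256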
Paella/src/vqgan.py ADDED
@@ -0,0 +1,140 @@
+import torch
+from torch import nn
+from torchtools.nn import VectorQuantize
+
+
+class ResBlock(nn.Module):
+    def __init__(self, c, c_hidden):
+        super().__init__()
+        # depthwise/attention
+        self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
+        self.depthwise = nn.Sequential(
+            nn.ReplicationPad2d(1),
+            nn.Conv2d(c, c, kernel_size=3, groups=c)
+        )
+
+        self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
+        self.channelwise = nn.Sequential(
+            nn.Linear(c, c_hidden),
+            nn.GELU(),
+            nn.Linear(c_hidden, c),
+        )
+
+        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
+
+        def _basic_init(module):
+            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+                torch.nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+
+        self.apply(_basic_init)
+
+    def _norm(self, x, norm):
+        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+    def forward(self, x):
+        mods = self.gammas
+        x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
+        x = x + self.depthwise(x_temp) * mods[2]
+        x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
+        x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
+        return x
+
+
+class VQModel(nn.Module):
+    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
+                 scale_factor=0.3764):  # 1.0
+        super().__init__()
+        self.c_latent = c_latent
+        self.scale_factor = scale_factor
+        c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
+
+        # Encoder blocks
+        self.in_block = nn.Sequential(
+            nn.PixelUnshuffle(2),
+            nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
+        )
+        down_blocks = []
+        for i in range(levels):
+            if i > 0:
+                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
+            block = ResBlock(c_levels[i], c_levels[i] * 4)
+            down_blocks.append(block)
+        down_blocks.append(nn.Sequential(
+            nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
+            nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
+        ))
+        self.down_blocks = nn.Sequential(*down_blocks)
+
+        self.codebook_size = codebook_size
+        self.vquantizer = VectorQuantize(c_latent, k=codebook_size)
+
+        # Decoder blocks
+        up_blocks = [nn.Sequential(
+            nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
+        )]
+        for i in range(levels):
+            for j in range(bottleneck_blocks if i == 0 else 1):
+                block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
+                up_blocks.append(block)
+            if i < levels - 1:
+                up_blocks.append(
+                    nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
+                                       padding=1))
+        self.up_blocks = nn.Sequential(*up_blocks)
+        self.out_block = nn.Sequential(
+            nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
+            nn.PixelShuffle(2),
+        )
+
+    def encode(self, x):
+        x = self.in_block(x)
+        x = self.down_blocks(x)
+        qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
+        return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
+
+    def decode(self, x):
+        x = x * self.scale_factor
+        x = self.up_blocks(x)
+        x = self.out_block(x)
+        return x
+
+    def decode_indices(self, x):
+        x = self.vquantizer.idx2vq(x, dim=1)
+        x = self.up_blocks(x)
+        x = self.out_block(x)
+        return x
+
+    def forward(self, x, quantize=False):
+        qe, x, _, vq_loss = self.encode(x)  # encode() takes only x here, so the quantize flag is not forwarded
+        x = self.decode(qe)
+        return x, vq_loss
+
+
+class Discriminator(nn.Module):
+    def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
+        super().__init__()
+        d = max(depth - 3, 3)
+        layers = [
+            nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
+            nn.LeakyReLU(0.2),
+        ]
+        for i in range(depth - 1):
+            c_in = c_hidden // (2 ** max((d - i), 0))
+            c_out = c_hidden // (2 ** max((d - 1 - i), 0))
+            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
+            layers.append(nn.InstanceNorm2d(c_out))
+            layers.append(nn.LeakyReLU(0.2))
+        self.encoder = nn.Sequential(*layers)
+        self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
+        self.logits = nn.Sigmoid()
+
+    def forward(self, x, cond=None):
+        x = self.encoder(x)
+        if cond is not None:
+            cond = cond.view(cond.size(0), cond.size(1), 1, 1).expand(-1, -1, x.size(-2), x.size(-1))
+            x = torch.cat([x, cond], dim=1)
+        x = self.shuffle(x)
+        x = self.logits(x)
+        return x
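
A hedged round-trip sketch for VQModel; it spells out which of encode()'s four return values the rest of the repo consumes (train.py takes the indices, and decoding during sampling goes through decode_indices). Pretrained weights and torchtools are assumed available.

# Hedged round-trip sketch: encode() returns (quantized, pre-quantized, indices, vq_loss).
import torch
from vqgan import VQModel

vqgan = VQModel().eval().requires_grad_(False)                # pretrained weights assumed in practice
images = torch.randn(1, 3, 256, 256)                          # normalized RGB batch

qe, embeddings, indices, vq_loss = vqgan.encode(images)
recon_from_qe = vqgan.decode(qe)                              # decode the rescaled quantized latents
recon_from_idx = vqgan.decode_indices(indices)                # decode straight from codebook indices
print(indices.shape, recon_from_idx.shape)                    # e.g. (1, 64, 64) and (1, 3, 256, 256)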
Paella/utils/alter_attention.py ADDED
@@ -0,0 +1,53 @@
+import torch
+from torch import nn
+
+
+class CustomMultiheadAttention(nn.MultiheadAttention):
+    def forward(self, *args, attn_weights=None, **kwargs):
+        q, k, v = args[:3]
+        need_weights = kwargs.get('need_weights', False)
+
+        w = self.in_proj_weight.chunk(3, dim=0)
+        b = self.in_proj_bias.chunk(3, dim=0)
+
+        if not self.batch_first:
+            # (seq, batch, dim) -> (batch, seq, dim); transpose, since permute() needs all dims
+            q, k, v = q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1)
+
+        q = nn.functional.linear(q, w[0], bias=b[0]).view(q.size(0), q.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
+        k = nn.functional.linear(k, w[1], bias=b[1]).view(k.size(0), k.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
+        v = nn.functional.linear(v, w[2], bias=b[2]).view(v.size(0), v.size(1), self.num_heads, -1).permute(0, 2, 1, 3)
+
+        scores = (q @ k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
+        attention = scores.softmax(dim=-1)
+
+        if attn_weights is not None:
+            # Expand the given weights into the trailing rows/columns of a full attention-sized mask
+            weights = torch.ones((attention.shape[2], attention.shape[3])).to(q.device)
+            attn_weights = attn_weights.expand(attention.shape[2], attn_weights.shape[0])
+            weights[-attn_weights.shape[0]:, -attn_weights.shape[1]:] = attn_weights
+            attn_weights = weights.clone()
+            attention = attention * attn_weights
+
+        x = attention @ v
+        x = x.permute(0, 2, 1, 3).reshape(x.size(0), x.size(2), -1)
+        x = self.out_proj(x)
+
+        if not self.batch_first:
+            x = x.transpose(0, 1)  # back to (seq, batch, dim)
+
+        return (x, attention if need_weights else None)
+
+
+def replace_attention_layers(model):
+    for n, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_attention_layers(module)
+
+        if isinstance(module, nn.MultiheadAttention):
+            new_module = CustomMultiheadAttention(module.embed_dim, module.num_heads, dropout=module.dropout, bias=True, batch_first=module.batch_first)
+            new_module.load_state_dict(module.state_dict())
+            setattr(model, n, new_module)
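
A hedged sketch of how this appears intended to plug into the kwargs-forwarding Paella in Paella/utils/modules.py: replace_attention_layers() swaps every nn.MultiheadAttention for the custom class, after which attn_weights can be threaded through a forward pass. The per-conditioning-token interpretation of the weight vector is an assumption.

# Hedged usage sketch; assumes it runs next to Paella/utils/modules.py and alter_attention.py.
import torch
from modules import Paella                       # the variant below that forwards **kwargs into attention
from alter_attention import replace_attention_layers

model = Paella()
replace_attention_layers(model)                  # every nn.MultiheadAttention -> CustomMultiheadAttention

tokens = torch.randint(0, 8192, (1, 32, 32))
t = torch.rand(1)
byt5_embeddings = torch.randn(1, 77, 1536)
attn_weights = torch.ones(77)                    # assumed: one multiplier per conditioning token
logits = model(tokens, t, byt5_embeddings, attn_weights=attn_weights)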
Paella/utils/modules.py ADDED
@@ -0,0 +1,291 @@
+import torch
+from torch import nn
+import numpy as np
+import math
+
+
+class Attention2D(nn.Module):
+    def __init__(self, c, nhead, dropout=0.0):
+        super().__init__()
+        self.attn = torch.nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
+
+    def forward(self, x, kv, self_attn=False, **kwargs):
+        orig_shape = x.shape
+        x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)  # Bx4xHxW -> Bx(HxW)x4
+        if self_attn:
+            kv = torch.cat([x, kv], dim=1)
+        x = self.attn(x, kv, kv, need_weights=False, **kwargs)[0]
+        x = x.permute(0, 2, 1).view(*orig_shape)
+        return x
+
+
+class LayerNorm2d(nn.LayerNorm):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x):
+        return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+
+class GlobalResponseNorm(nn.Module):
+    "Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
+    def __init__(self, dim):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
+        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+
+    def forward(self, x):
+        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+        return self.gamma * (x * Nx) + self.beta + x
+
+
+class ResBlock(nn.Module):
+    def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0):
+        super().__init__()
+        self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+        self.channelwise = nn.Sequential(
+            nn.Linear(c, c * 4),
+            nn.GELU(),
+            GlobalResponseNorm(c * 4),
+            nn.Dropout(dropout),
+            nn.Linear(c * 4, c)
+        )
+
+    def forward(self, x, x_skip=None):
+        x_res = x
+        if x_skip is not None:
+            x = torch.cat([x, x_skip], dim=1)
+        x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1)
+        x = self.channelwise(x).permute(0, 3, 1, 2)
+        return x + x_res
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
+        super().__init__()
+        self.self_attn = self_attn
+        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+        self.attention = Attention2D(c, nhead, dropout)
+        self.kv_mapper = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(c_cond, c)
+        )
+
+    def forward(self, x, kv, **kwargs):
+        kv = self.kv_mapper(kv)
+        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, **kwargs)
+        return x
+
+
+class FeedForwardBlock(nn.Module):
+    def __init__(self, c, dropout=0.0):
+        super().__init__()
+        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+        self.channelwise = nn.Sequential(
+            nn.Linear(c, c * 4),
+            nn.GELU(),
+            GlobalResponseNorm(c * 4),
+            nn.Dropout(dropout),
+            nn.Linear(c * 4, c)
+        )
+
+    def forward(self, x):
+        x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        return x
+
+
+class TimestepBlock(nn.Module):
+    def __init__(self, c, c_timestep):
+        super().__init__()
+        self.mapper = nn.Linear(c_timestep, c * 2)
+
+    def forward(self, x, t):
+        a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
+        return x * (1 + a) + b
+
+
+class Paella(nn.Module):
+    def __init__(self, c_in=256, c_out=256, num_labels=8192, c_r=64, patch_size=2, c_cond=1024,
+                 c_hidden=[640, 1280, 1280], nhead=[-1, 16, 16], blocks=[6, 16, 6], level_config=['CT', 'CTA', 'CTA'],
+                 clip_embd=1024, byt5_embd=1536, clip_seq_len=4, kernel_size=3, dropout=0.1, self_attn=True):
+        super().__init__()
+        self.c_r = c_r
+        self.c_cond = c_cond
+        self.num_labels = num_labels
+        if not isinstance(dropout, list):
+            dropout = [dropout] * len(c_hidden)
+
+        # CONDITIONING
+        self.byt5_mapper = nn.Linear(byt5_embd, c_cond)
+        self.clip_mapper = nn.Linear(clip_embd, c_cond * clip_seq_len)
+        self.clip_image_mapper = nn.Linear(clip_embd, c_cond * clip_seq_len)
+        self.seq_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)
+
+        self.in_mapper = nn.Sequential(
+            nn.Embedding(num_labels, c_in),
+            nn.LayerNorm(c_in, elementwise_affine=False, eps=1e-6)
+        )
+        self.embedding = nn.Sequential(
+            nn.PixelUnshuffle(patch_size),
+            nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1),
+            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6)
+        )
+
+        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0):
+            if block_type == 'C':
+                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
+            elif block_type == 'A':
+                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
+            elif block_type == 'F':
+                return FeedForwardBlock(c_hidden, dropout=dropout)
+            elif block_type == 'T':
+                return TimestepBlock(c_hidden, c_r)
+            else:
+                raise Exception(f'Block type {block_type} not supported')
+
+        # DOWN BLOCKS
+        self.down_blocks = nn.ModuleList()
+        for i in range(len(c_hidden)):
+            down_block = nn.ModuleList()
+            if i > 0:
+                down_block.append(nn.Sequential(
+                    LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
+                    nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2),
+                ))
+            for _ in range(blocks[i]):
+                for block_type in level_config[i]:
+                    down_block.append(get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i]))
+            self.down_blocks.append(down_block)
+
+        # UP BLOCKS
+        self.up_blocks = nn.ModuleList()
+        for i in reversed(range(len(c_hidden))):
+            up_block = nn.ModuleList()
+            for j in range(blocks[i]):
+                for k, block_type in enumerate(level_config[i]):
+                    up_block.append(get_block(block_type, c_hidden[i], nhead[i],
+                                              c_skip=c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0,
+                                              dropout=dropout[i]))
+            if i > 0:
+                up_block.append(nn.Sequential(
+                    LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
+                    nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2),
+                ))
+            self.up_blocks.append(up_block)
+
+        # OUTPUT
+        self.clf = nn.Sequential(
+            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
+            nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1),
+            nn.PixelShuffle(patch_size),
+        )
+        self.out_mapper = nn.Sequential(
+            LayerNorm2d(c_out, elementwise_affine=False, eps=1e-6),
+            nn.Conv2d(c_out, num_labels, kernel_size=1, bias=False)
+        )
+
+        # --- WEIGHT INIT ---
+        self.apply(self._init_weights)
+        nn.init.normal_(self.byt5_mapper.weight, std=0.02)
+        nn.init.normal_(self.clip_mapper.weight, std=0.02)
+        nn.init.normal_(self.clip_image_mapper.weight, std=0.02)
+        torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)  # inputs
+        nn.init.constant_(self.clf[1].weight, 0)  # outputs
+        nn.init.normal_(self.in_mapper[0].weight, std=np.sqrt(1 / num_labels))  # out mapper
+        self.out_mapper[-1].weight.data = self.in_mapper[0].weight.data[:, :, None, None].clone()
+
+        for level_block in self.down_blocks + self.up_blocks:
+            for block in level_block:
+                if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
+                    block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks))
+                elif isinstance(block, TimestepBlock):
+                    nn.init.constant_(block.mapper.weight, 0)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv2d, nn.Linear)):
+            torch.nn.init.xavier_uniform_(m.weight)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def gen_r_embedding(self, r, max_positions=10000):
+        r = r * max_positions
+        half_dim = self.c_r // 2
+        emb = math.log(max_positions) / (half_dim - 1)
+        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
+        emb = r[:, None] * emb[None, :]
+        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
+        if self.c_r % 2 == 1:  # zero pad
+            emb = nn.functional.pad(emb, (0, 1), mode='constant')
+        return emb
+
+    def gen_c_embeddings(self, byt5, clip, clip_image):
+        seq = self.byt5_mapper(byt5)
+        if clip is not None:
+            clip = self.clip_mapper(clip).view(clip.size(0), -1, self.c_cond)
+            seq = torch.cat([seq, clip], dim=1)
+        if clip_image is not None:
+            if isinstance(clip_image, list):
+                for ci in clip_image:
+                    ci = self.clip_image_mapper(ci).view(ci.size(0), -1, self.c_cond)
+                    seq = torch.cat([seq, ci], dim=1)
+            else:
+                clip_image = self.clip_image_mapper(clip_image).view(clip_image.size(0), -1, self.c_cond)
+                seq = torch.cat([seq, clip_image], dim=1)
+        seq = self.seq_norm(seq)
+        return seq
+
+    def _down_encode(self, x, r_embed, c_embed, **kwargs):
+        level_outputs = []
+        for down_block in self.down_blocks:
+            for block in down_block:
+                if isinstance(block, ResBlock):
+                    x = block(x)
+                elif isinstance(block, AttnBlock):
+                    x = block(x, c_embed, **kwargs)
+                elif isinstance(block, TimestepBlock):
+                    x = block(x, r_embed)
+                else:
+                    x = block(x)
+            level_outputs.insert(0, x)
+        return level_outputs
+
+    def _up_decode(self, level_outputs, r_embed, c_embed, **kwargs):
+        x = level_outputs[0]
+        for i, up_block in enumerate(self.up_blocks):
+            for j, block in enumerate(up_block):
+                if isinstance(block, ResBlock):
+                    x = block(x, level_outputs[i] if j == 0 and i > 0 else None)
+                elif isinstance(block, AttnBlock):
+                    x = block(x, c_embed, **kwargs)
+                elif isinstance(block, TimestepBlock):
+                    x = block(x, r_embed)
+                else:
+                    x = block(x)
+        return x
+
+    def forward(self, x, r, byt5, clip=None, clip_image=None, x_cat=None, **kwargs):
+        if x_cat is not None:
+            x = torch.cat([x, x_cat], dim=1)
+        # Process the conditioning embeddings
+        r_embed = self.gen_r_embedding(r)
+        c_embed = self.gen_c_embeddings(byt5, clip, clip_image)
+
+        # Model Blocks
+        x = self.embedding(self.in_mapper(x).permute(0, 3, 1, 2))
+        level_outputs = self._down_encode(x, r_embed, c_embed, **kwargs)
+        x = self._up_decode(level_outputs, r_embed, c_embed, **kwargs)
+        x = self.out_mapper(self.clf(x))
+        return x
+
+    def add_noise(self, x, t, mask=None, random_x=None):
+        if mask is None:
+            mask = (torch.rand_like(x.float()) <= t[:, None, None]).long()
+        if random_x is None:
+            random_x = torch.randint_like(x, 0, self.num_labels)
+        x = x * (1 - mask) + random_x * mask
+        return x, mask
+
+    def get_loss_weight(self, t, mask, min_val=0.3):
+        return 1 - (1 - mask) * ((1 - t) * (1 - min_val))[:, None, None]
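
Relative to Paella/src/modules.py, this copy adds the **kwargs plumbing used by alter_attention.py, list support for clip_image in gen_c_embeddings, and the get_loss_weight helper. Below is a hedged sketch of conditioning on several CLIP image embeddings at once; the 1024-dimensional vectors are assumed to come from an external CLIP image encoder.

# Hedged sketch: each CLIP image embedding becomes clip_seq_len conditioning tokens and is concatenated.
import torch
from modules import Paella                       # this file, when run from inside Paella/utils

model = Paella()
tokens = torch.randint(0, 8192, (1, 32, 32))
t = torch.rand(1)
byt5_embeddings = torch.randn(1, 77, 1536)
clip_images = [torch.randn(1, 1024), torch.randn(1, 1024)]    # assumed CLIP image encoder outputs

logits = model(tokens, t, byt5_embeddings, clip_image=clip_images)   # B x num_labels x H x W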