import math import torch import torch.nn as nn import torch.nn.functional as F import einops import numpy as np import models from modules.common_ckpt import Linear, Conv2d, AttnBlock, ResBlock, LayerNorm2d #from modules.common_ckpt import AttnBlock, from einops import rearrange import torch.fft as fft from modules.speed_util import checkpoint def batched_linear_mm(x, wb): # x: (B, N, D1); wb: (B, D1 + 1, D2) or (D1 + 1, D2) one = torch.ones(*x.shape[:-1], 1, device=x.device) return torch.matmul(torch.cat([x, one], dim=-1), wb) def make_coord_grid(shape, range, device=None): """ Args: shape: tuple range: [minv, maxv] or [[minv_1, maxv_1], ..., [minv_d, maxv_d]] for each dim Returns: grid: shape (*shape, ) """ l_lst = [] for i, s in enumerate(shape): l = (0.5 + torch.arange(s, device=device)) / s if isinstance(range[0], list) or isinstance(range[0], tuple): minv, maxv = range[i] else: minv, maxv = range l = minv + (maxv - minv) * l l_lst.append(l) grid = torch.meshgrid(*l_lst, indexing='ij') grid = torch.stack(grid, dim=-1) return grid def init_wb(shape): weight = torch.empty(shape[1], shape[0] - 1) nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) bias = torch.empty(shape[1], 1) fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(bias, -bound, bound) return torch.cat([weight, bias], dim=1).t().detach() def init_wb_rewrite(shape): weight = torch.empty(shape[1], shape[0] - 1) torch.nn.init.xavier_uniform_(weight) bias = torch.empty(shape[1], 1) torch.nn.init.xavier_uniform_(bias) return torch.cat([weight, bias], dim=1).t().detach() class HypoMlp(nn.Module): def __init__(self, depth, in_dim, out_dim, hidden_dim, use_pe, pe_dim, out_bias=0, pe_sigma=1024): super().__init__() self.use_pe = use_pe self.pe_dim = pe_dim self.pe_sigma = pe_sigma self.depth = depth self.param_shapes = dict() if use_pe: last_dim = in_dim * pe_dim else: last_dim = in_dim for i in range(depth): # for each layer the weight cur_dim = hidden_dim if i < depth - 1 else out_dim self.param_shapes[f'wb{i}'] = (last_dim + 1, cur_dim) last_dim = cur_dim self.relu = nn.ReLU() self.params = None self.out_bias = out_bias def set_params(self, params): self.params = params def convert_posenc(self, x): w = torch.exp(torch.linspace(0, np.log(self.pe_sigma), self.pe_dim // 2, device=x.device)) x = torch.matmul(x.unsqueeze(-1), w.unsqueeze(0)).view(*x.shape[:-1], -1) x = torch.cat([torch.cos(np.pi * x), torch.sin(np.pi * x)], dim=-1) return x def forward(self, x): B, query_shape = x.shape[0], x.shape[1: -1] x = x.view(B, -1, x.shape[-1]) if self.use_pe: x = self.convert_posenc(x) #print('in line 79 after pos embedding', x.shape) for i in range(self.depth): x = batched_linear_mm(x, self.params[f'wb{i}']) if i < self.depth - 1: x = self.relu(x) else: x = x + self.out_bias x = x.view(B, *query_shape, -1) return x class Attention(nn.Module): def __init__(self, dim, n_head, head_dim, dropout=0.): super().__init__() self.n_head = n_head inner_dim = n_head * head_dim self.to_q = nn.Sequential( nn.SiLU(), Linear(dim, inner_dim )) self.to_kv = nn.Sequential( nn.SiLU(), Linear(dim, inner_dim * 2)) self.scale = head_dim ** -0.5 # self.to_out = nn.Sequential( # Linear(inner_dim, dim), # nn.Dropout(dropout), # ) def forward(self, fr, to=None): if to is None: to = fr q = self.to_q(fr) k, v = self.to_kv(to).chunk(2, dim=-1) q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.n_head), [q, k, v]) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale attn = F.softmax(dots, dim=-1) # b h n n out = torch.matmul(attn, v) out = einops.rearrange(out, 'b h n d -> b n (h d)') return out class FeedForward(nn.Module): def __init__(self, dim, ff_dim, dropout=0.): super().__init__() self.net = nn.Sequential( Linear(dim, ff_dim), nn.GELU(), #GlobalResponseNorm(ff_dim), nn.Dropout(dropout), Linear(ff_dim, dim) ) def forward(self, x): return self.net(x) class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.norm = nn.LayerNorm(dim) self.fn = fn def forward(self, x): return self.fn(self.norm(x)) #TransInr(ind=2048, ch=256, n_head=16, head_dim=16, n_groups=64, f_dim=256, time_dim=self.c_r, t_conds = []) class TransformerEncoder(nn.Module): def __init__(self, dim, depth, n_head, head_dim, ff_dim, dropout=0.): super().__init__() self.layers = nn.ModuleList() for _ in range(depth): self.layers.append(nn.ModuleList([ PreNorm(dim, Attention(dim, n_head, head_dim, dropout=dropout)), PreNorm(dim, FeedForward(dim, ff_dim, dropout=dropout)), ])) def forward(self, x): for norm_attn, norm_ff in self.layers: x = x + norm_attn(x) x = x + norm_ff(x) return x class ImgrecTokenizer(nn.Module): def __init__(self, input_size=32*32, patch_size=1, dim=768, padding=0, img_channels=16): super().__init__() if isinstance(patch_size, int): patch_size = (patch_size, patch_size) if isinstance(padding, int): padding = (padding, padding) self.patch_size = patch_size self.padding = padding self.prefc = nn.Linear(patch_size[0] * patch_size[1] * img_channels, dim) self.posemb = nn.Parameter(torch.randn(input_size, dim)) def forward(self, x): #print(x.shape) p = self.patch_size x = F.unfold(x, p, stride=p, padding=self.padding) # (B, C * p * p, L) #print('in line 185 after unfoding', x.shape) x = x.permute(0, 2, 1).contiguous() ttt = self.prefc(x) x = self.prefc(x) + self.posemb[:x.shape[1]].unsqueeze(0) return x class SpatialAttention(nn.Module): def __init__(self, kernel_size=7): super(SpatialAttention, self).__init__() self.conv1 = Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False) self.sigmoid = nn.Sigmoid() def forward(self, x): avg_out = torch.mean(x, dim=1, keepdim=True) max_out, _ = torch.max(x, dim=1, keepdim=True) x = torch.cat([avg_out, max_out], dim=1) x = self.conv1(x) return self.sigmoid(x) class TimestepBlock_res(nn.Module): def __init__(self, c, c_timestep, conds=['sca']): super().__init__() self.mapper = Linear(c_timestep, c * 2) self.conds = conds for cname in conds: setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) def forward(self, x, t): #print(x.shape, t.shape, self.conds, 'in line 269') t = t.chunk(len(self.conds) + 1, dim=1) a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) for i, c in enumerate(self.conds): ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) a, b = a + ac, b + bc return x * (1 + a) + b def zero_module(module): """ Zero out the parameters of a module and return it. """ for p in module.parameters(): p.detach().zero_() return module class ScaleNormalize_res(nn.Module): def __init__(self, c, scale_c, conds=['sca']): super().__init__() self.c_r = scale_c self.mapping = TimestepBlock_res(c, scale_c, conds=conds) self.t_conds = conds self.alpha = nn.Conv2d(c, c, kernel_size=1) self.gamma = nn.Conv2d(c, c, kernel_size=1) self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) def gen_r_embedding(self, r, max_positions=10000): r = r * max_positions half_dim = self.c_r // 2 emb = math.log(max_positions) / (half_dim - 1) emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() emb = r[:, None] * emb[None, :] emb = torch.cat([emb.sin(), emb.cos()], dim=1) if self.c_r % 2 == 1: # zero pad emb = nn.functional.pad(emb, (0, 1), mode='constant') return emb def forward(self, x, std_size=24*24): scale_val = math.sqrt(math.log(x.shape[-2] * x.shape[-1], std_size)) scale_val = torch.ones(x.shape[0]).to(x.device)*scale_val scale_val_f = self.gen_r_embedding(scale_val) for c in self.t_conds: t_cond = torch.zeros_like(scale_val) scale_val_f = torch.cat([scale_val_f, self.gen_r_embedding(t_cond)], dim=1) f = self.mapping(x, scale_val_f) return f + x class TransInr_withnorm(nn.Module): def __init__(self, ind=2048, ch=16, n_head=12, head_dim=64, n_groups=64, f_dim=768, time_dim=2048, t_conds=[]): super().__init__() self.input_layer= nn.Conv2d(ind, ch, 1) self.tokenizer = ImgrecTokenizer(dim=ch, img_channels=ch) #self.hyponet = HypoMlp(depth=12, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128) #self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=12, n_head=n_head, head_dim=f_dim // n_head, ff_dim=3*f_dim, ) self.hyponet = HypoMlp(depth=2, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128) self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=1, n_head=n_head, head_dim=f_dim // n_head, ff_dim=f_dim) #self.transformer_encoder = TransInr( ch=ch, n_head=16, head_dim=16, n_groups=64, f_dim=ch, time_dim=time_dim, t_conds = []) self.base_params = nn.ParameterDict() n_wtokens = 0 self.wtoken_postfc = nn.ModuleDict() self.wtoken_rng = dict() for name, shape in self.hyponet.param_shapes.items(): self.base_params[name] = nn.Parameter(init_wb(shape)) g = min(n_groups, shape[1]) assert shape[1] % g == 0 self.wtoken_postfc[name] = nn.Sequential( nn.LayerNorm(f_dim), nn.Linear(f_dim, shape[0] - 1), ) self.wtoken_rng[name] = (n_wtokens, n_wtokens + g) n_wtokens += g self.wtokens = nn.Parameter(torch.randn(n_wtokens, f_dim)) self.output_layer= nn.Conv2d(ch, ind, 1) self.mapp_t = TimestepBlock_res( ind, time_dim, conds = t_conds) self.hr_norm = ScaleNormalize_res(ind, 64, conds=[]) self.normalize_final = nn.Sequential( LayerNorm2d(ind, elementwise_affine=False, eps=1e-6), ) self.toout = nn.Sequential( Linear( ind*2, ind // 4), nn.GELU(), Linear( ind // 4, ind) ) self.apply(self._init_weights) mask = torch.zeros((1, 1, 32, 32)) h, w = 32, 32 center_h, center_w = h // 2, w // 2 low_freq_h, low_freq_w = h // 4, w // 4 mask[:, :, center_h-low_freq_h:center_h+low_freq_h, center_w-low_freq_w:center_w+low_freq_w] = 1 self.mask = mask def _init_weights(self, m): if isinstance(m, (nn.Conv2d, nn.Linear)): torch.nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.constant_(m.bias, 0) #nn.init.constant_(self.last.weight, 0) def adain(self, feature_a, feature_b): norm_mean = torch.mean(feature_a, dim=(2, 3), keepdim=True) norm_std = torch.std(feature_a, dim=(2, 3), keepdim=True) #feature_a = F.interpolate(feature_a, feature_b.shape[2:]) feature_b = (feature_b - feature_b.mean(dim=(2, 3), keepdim=True)) / (1e-8 + feature_b.std(dim=(2, 3), keepdim=True)) * norm_std + norm_mean return feature_b def forward(self, target_shape, target, dtokens, t_emb): #print(target.shape, dtokens.shape, 'in line 290') hlr, wlr = dtokens.shape[2:] original = dtokens dtokens = self.input_layer(dtokens) dtokens = self.tokenizer(dtokens) B = dtokens.shape[0] wtokens = einops.repeat(self.wtokens, 'n d -> b n d', b=B) #print(wtokens.shape, dtokens.shape) trans_out = self.transformer_encoder(torch.cat([dtokens, wtokens], dim=1)) trans_out = trans_out[:, -len(self.wtokens):, :] params = dict() for name, shape in self.hyponet.param_shapes.items(): wb = einops.repeat(self.base_params[name], 'n m -> b n m', b=B) w, b = wb[:, :-1, :], wb[:, -1:, :] l, r = self.wtoken_rng[name] x = self.wtoken_postfc[name](trans_out[:, l: r, :]) x = x.transpose(-1, -2) # (B, shape[0] - 1, g) w = F.normalize(w * x.repeat(1, 1, w.shape[2] // x.shape[2]), dim=1) wb = torch.cat([w, b], dim=1) params[name] = wb coord = make_coord_grid(target_shape[2:], (-1, 1), device=dtokens.device) coord = einops.repeat(coord, 'h w d -> b h w d', b=dtokens.shape[0]) self.hyponet.set_params(params) ori_up = F.interpolate(original.float(), target_shape[2:]) hr_rec = self.output_layer(rearrange(self.hyponet(coord), 'b h w c -> b c h w')) + ori_up #print(hr_rec.shape, target.shape, torch.cat((hr_rec, target), dim=1).permute(0, 2, 3, 1).shape, 'in line 537') output = self.toout(torch.cat((hr_rec, target), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) #print(output.shape, 'in line 540') #output = self.last(output.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)* 0.3 output = self.mapp_t(output, t_emb) output = self.normalize_final(output) output = self.hr_norm(output) #output = self.last(output.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) #output = self.mapp_t(output, t_emb) #output = self.weight(output) * output return output class LayerNorm2d(nn.LayerNorm): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, x): return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) class GlobalResponseNorm(nn.Module): "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" def __init__(self, dim): super().__init__() self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) def forward(self, x): Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) return self.gamma * (x * Nx) + self.beta + x if __name__ == '__main__': #ef __init__(self, ch, n_head, head_dim, n_groups): trans_inr = TransInr(16, 24, 32, 64).cuda() input = torch.randn((1, 16, 24, 24)).cuda() source = torch.randn((1, 16, 16, 16)).cuda() t = torch.randn((1, 128)).cuda() output, hr = trans_inr(input, t, source) total_up = sum([ param.nelement() for param in trans_inr.parameters()]) print(output.shape, hr.shape, total_up /1e6 )