# joytag / Models.py
import json
from pathlib import Path
from typing import Optional
import torch
import torch.backends.cuda
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from transformers.activations import QuickGELUActivation
import math
from einops.layers.torch import Rearrange
import einops
MODEL_CONFIGS = {
# Custom models trained from scratch
# "Standard" definitions:
# name | layers | width | heads
# B | 12 | 768 | 12
# L | 24 | 1024 | 16
# H | 32 | 1280 | 16
# G | 48 | 1664 | 16
# e | 56 | 1792 | 16
# 22 | 48 | 6144 | 48
# B/16, 224, PaLM, GELU
'CustomTest6': {
'class': 'CLIPLikeModel',
'embedding_dim': 768,
'num_attention_heads': 12,
'activation_cls': nn.GELU,
'num_channels': 3,
'patch_size': 16,
'use_palm_alt': True,
'num_layers': 12,
'use_mha_alt': False,
'good_dropout': False,
},
# GAP head + Sinusoidal positional embeddings + 448 image size
'CustomTest18': {
'class': 'CLIPLikeModel',
'embedding_dim': 768,
'num_attention_heads': 12,
'activation_cls': nn.GELU,
'num_channels': 3,
'patch_size': 16,
'use_palm_alt': True,
'num_layers': 12,
'use_mha_alt': False,
'good_dropout': False,
'use_gap_head': True,
'sine_positional_embeddings': True,
},
# SW Model + B/16 + ASL + 448 image size
# cutout_max_pct = 0
# mixup_alpha = 0.8
# noise_level = 2
# random_resize_method = true
# total_labels = 6549
'SWModel1': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': False},
# Sinusoidal positional embeddings
'SWModel2': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
# Sinusoidal positional embeddings + 224 image size + L/14
'SWModel3': {'class': 'ViT', 'num_blocks': 24, 'patch_size': 14, 'd_model': 1024, 'mlp_dim': 1024*4, 'num_heads': 16, 'stochdepth_rate': 0.05, 'layerscale_init': 1e-1, 'use_sine': True},
# Sinusoidal positional embeddings + 224 image size + G/14
'SWModel4': {'class': 'ViT', 'num_blocks': 48, 'patch_size': 14, 'd_model': 1664, 'mlp_dim': 1664*4, 'num_heads': 16, 'stochdepth_rate': 0.05, 'layerscale_init': 1e-1, 'use_sine': True},
# Sinusoidal positional embeddings + focal loss
'SWModel5': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
'SWModel6': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
'SWModel7': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
'SWModel8': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
'SWModel9': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
'SWModel10': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
'SWModel11': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0, 'use_sine': True},
# Trying head_mean_after
'SWModel12': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True, 'head_mean_after': True},
# Fat boy
'SWModel13': {'class': 'ViT', 'num_blocks': 6, 'patch_size': 16, 'd_model': 1536, 'mlp_dim': 1536*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True},
# L/14
'SWModel14': {'class': 'ViT', 'num_blocks': 24, 'patch_size': 14, 'd_model': 1024, 'mlp_dim': 1024*4, 'num_heads': 16, 'stochdepth_rate': 0.05, 'layerscale_init': 1e-1, 'use_sine': True},
'SWModel15': {'class': 'ViT', 'num_blocks': 24, 'patch_size': 14, 'd_model': 1024, 'mlp_dim': 1024*4, 'num_heads': 16, 'stochdepth_rate': 0.05, 'layerscale_init': 1e-5, 'use_sine': True},
'SWModel16': {'class': 'ViT', 'num_blocks': 24, 'patch_size': 14, 'd_model': 1024, 'mlp_dim': 1024*4, 'num_heads': 16, 'stochdepth_rate': 0.10, 'layerscale_init': 1e-1, 'use_sine': True},
'SWModel16f': {'class': 'ViT', 'num_blocks': 24, 'patch_size': 14, 'd_model': 1024, 'mlp_dim': 1024*4, 'num_heads': 16, 'stochdepth_rate': 0.10, 'layerscale_init': 1e-1, 'use_sine': True},
'SWModel22': {'class': 'ViT', 'num_blocks': 24, 'patch_size': 14, 'd_model': 1024, 'mlp_dim': 1024*4, 'num_heads': 16, 'stochdepth_rate': 0.20, 'layerscale_init': 1e-1, 'use_sine': True},
'SWModel25': {'class': 'ViT', 'num_blocks': 24, 'patch_size': 16, 'd_model': 1024, 'mlp_dim': 1024*4, 'num_heads': 16, 'stochdepth_rate': 0.15, 'layerscale_init': 1e-1, 'use_sine': True, 'cnn_stem': 'conv:c=128;ln;relu;conv:c=256;ln;relu;conv:c=512;ln;relu;conv:c=1024;ln;relu;conv:c=1024,s=1,k=1,p=0'},
# CNN stem
'SWModel18': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True, 'cnn_stem': 'conv:c=64;bn;relu;conv:c=128;bn;relu;conv:c=256;bn;relu;conv:c=512;bn;relu;conv:c=768,s=1,k=1'},
'SWModel19': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True, 'cnn_stem': 'conv:c=64;bn;relu;conv:c=128;bn;relu;conv:c=128,s=1;bn;relu;conv:c=256;bn;relu;conv:c=256,s=1;bn;relu;conv:c=512;bn;relu;conv:c=768,s=1,k=1,p=0'},
'SWModel20': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True, 'cnn_stem': 'conv:c=64;ln;relu;conv:c=128;ln;relu;conv:c=256;ln;relu;conv:c=512;ln;relu;conv:c=768,s=1,k=1,p=0'},
'SWModel21': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True, 'cnn_stem': 'conv:c=64;ln;gelu;conv:c=128;ln;gelu;conv:c=256;ln;gelu;conv:c=512;ln;gelu;conv:c=768,s=1,k=1,p=0'},
'SWModel23': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True, 'cnn_stem': 'conv:c=64;ln;relu;conv:c=128;ln;relu;conv:c=256;ln;relu;conv:c=512;ln;relu;conv:c=768,s=1,k=1,p=0'},
'SWModel24': {'class': 'ViT', 'num_blocks': 12, 'patch_size': 16, 'd_model': 768, 'mlp_dim': 768*4, 'num_heads': 12, 'stochdepth_rate': 0.05, 'use_sine': True, 'cnn_stem': 'conv:c=64;ln;relu;conv:c=128;ln;relu;conv:c=256;ln;relu;conv:c=512;ln;relu;conv:c=768,s=1,k=1,p=0'},
# H/14
'SWModel17': {'class': 'ViT', 'num_blocks': 32, 'patch_size': 14, 'd_model': 1280, 'mlp_dim': 1280*4, 'num_heads': 16, 'stochdepth_rate': 0.05, 'layerscale_init': 1e-1, 'use_sine': True},
'SWModel26': {'class': 'ViT', 'num_blocks': 32, 'patch_size': 14, 'd_model': 1280, 'mlp_dim': 1280*4, 'num_heads': 16, 'stochdepth_rate': 0.15, 'layerscale_init': 1e-1, 'use_sine': True},
}
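
# Illustrative sketch: a MODEL_CONFIGS entry only describes the architecture; n_tags, image_size,
# and loss_type are supplied separately (the comments above mention 6549 labels and a 448 image
# size for the SW models, but the exact values for any given checkpoint are an assumption here,
# as is the specific ASL loss variant string). VisionModel.from_config (defined below) strips the
# 'class' key and forwards the rest to the matching subclass constructor, here ViT(**kwargs).
def _example_config_to_model() -> 'VisionModel':
	config = dict(MODEL_CONFIGS['SWModel1'], n_tags=6549, image_size=448, loss_type='asl2')
	return VisionModel.from_config(config)
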
class VisionModel(nn.Module):
image_size: int
n_tags: int
def __init__(self, image_size: int, n_tags: int):
super().__init__()
self.image_size = image_size
self.n_tags = n_tags
@staticmethod
def load_model(path: Path | str, device: str | None = None) -> 'VisionModel':
"""
Load a model from a directory.
:param path: The directory containing the model.
:return: The model, the image size, and the number of tags.
"""
with open(Path(path) / 'config.json', 'r') as f:
config = json.load(f)
if (Path(path) / 'model.safetensors').exists():
from safetensors.torch import load_file
resume = load_file(Path(path) / 'model.safetensors', device='cpu')
else:
resume = torch.load(Path(path) / 'model.pt', map_location=torch.device('cpu'))['model']
model_classes = VisionModel.__subclasses__()
model_cls = next(cls for cls in model_classes if cls.__name__ == config['class'])
model = model_cls(**{k: v for k, v in config.items() if k != 'class'})
model.load(resume)
if device is not None:
model = model.to(device)
return model
@staticmethod
def from_config(config: dict) -> 'VisionModel':
model_classes = VisionModel.__subclasses__()
model_cls = next(cls for cls in model_classes if cls.__name__ == config['class'])
return model_cls(**{k: v for k, v in config.items() if k != 'class'})
def get_optimized_parameters(self, lr: float):
raise NotImplementedError
def save(self):
raise NotImplementedError
def load(self, state_dict):
raise NotImplementedError
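
# Illustrative sketch: loading a saved checkpoint directory. The path is hypothetical; the
# directory must contain config.json plus model.safetensors or model.pt, as load_model expects.
def _example_load_checkpoint(checkpoint_dir: str = 'path/to/checkpoint') -> 'VisionModel':
	model = VisionModel.load_model(checkpoint_dir, device='cpu')
	model.eval()
	return model
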
def basic_calculate_loss(preds: dict[str, torch.Tensor], batch: dict, pos_weight: torch.Tensor | None, loss_type: str):
def asl_helper(preds, target):
p = F.softmax(preds, dim=1)
xs_pos = p.clamp(min=1e-6)
xs_neg = (1 - p).clamp(min=1e-6)
los_pos = torch.log(torch.gather(xs_pos, 1, target.unsqueeze(1))).sum()
los_neg = torch.log(xs_neg)
los_neg = los_neg.sum() - torch.gather(los_neg, 1, target.unsqueeze(1)).sum()
loss = los_pos + los_neg
return -loss
if loss_type == "ce":
loss = F.binary_cross_entropy_with_logits(preds['tags'], batch['tags'])
elif loss_type == "weighted":
loss = F.binary_cross_entropy_with_logits(preds['tags'], batch['tags'], pos_weight=pos_weight)
elif loss_type == "focal":
gamma = 2
p = torch.sigmoid(preds['tags'])
ce_loss = F.binary_cross_entropy_with_logits(preds['tags'], batch['tags'], reduction='none')
p_t = p * batch['tags'] + (1 - p) * (1 - batch['tags'])
loss = ce_loss * ((1 - p_t) ** gamma)
loss = loss.mean()
elif loss_type == "focal2":
gamma = 2
p = torch.sigmoid(preds['tags'])
ce_loss = F.binary_cross_entropy_with_logits(preds['tags'], batch['tags'], reduction='none')
p_t = p * batch['tags'] + (1 - p) * (1 - batch['tags'])
loss = ce_loss * ((1 - p_t) ** gamma) * 256
loss = loss.mean()
elif loss_type == "asl":
p = torch.sigmoid(preds['tags'])
xs_pos = p
xs_neg = 1 - p
los_pos = batch['tags'] * torch.log(xs_pos.clamp(min=1e-6))
los_neg = (1 - batch['tags']) * torch.log(xs_neg.clamp(min=1e-6))
loss = los_pos + los_neg
loss = -loss.sum()
# Rating
loss = loss + asl_helper(preds['rating'], batch['rating'])
# Score
loss = loss + asl_helper(preds['score'], batch['score'])
elif loss_type == "asl2":
p = torch.sigmoid(preds['tags'])
xs_pos = p
xs_neg = 1 - p
los_pos = batch['tags'] * torch.log(xs_pos.clamp(min=1e-6))
los_neg = (1 - batch['tags']) * torch.log(xs_neg.clamp(min=1e-6))
loss = -los_pos - los_neg
loss = loss.sum()
elif loss_type == "asl3":
p = torch.sigmoid(preds['tags'])
xs_pos = p
xs_neg = 1 - p
los_pos = batch['tags'] * torch.log(xs_pos.clamp(min=1e-6))
los_neg = (1 - batch['tags']) * torch.log(xs_neg.clamp(min=1e-6))
loss = -los_pos - los_neg
loss = loss.mean()
elif loss_type == "asl4":
p = torch.sigmoid(preds['tags'])
xs_pos = p
xs_neg = 1 - p
los_pos = batch['tags'] * torch.log(xs_pos.clamp(min=1e-6))
los_neg = (1 - batch['tags']) * torch.log(xs_neg.clamp(min=1e-6))
loss = -los_pos - los_neg
loss = loss.mean() * 128
elif loss_type == "asl5":
loss = F.binary_cross_entropy_with_logits(preds['tags'], batch['tags'], pos_weight=pos_weight) * 128
elif loss_type == "asl6":
loss = F.binary_cross_entropy_with_logits(preds['tags'], batch['tags'], pos_weight=pos_weight) * 256
elif loss_type == "asl7":
loss = F.binary_cross_entropy_with_logits(preds['tags'], batch['tags'], pos_weight=pos_weight) * 2
else:
raise ValueError(f"Invalid loss type: {loss_type}")
return loss
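
# Illustrative sketch: basic_calculate_loss expects raw logits under preds['tags'] and a multi-hot
# float target under batch['tags'], both of shape (batch_size, n_tags). The shapes and the 'focal'
# loss_type below are arbitrary example values; the 'asl' variant additionally needs 'rating' and
# 'score' entries in both dicts.
def _example_loss() -> torch.Tensor:
	preds = {'tags': torch.randn(4, 6549)}
	batch = {'tags': torch.randint(0, 2, (4, 6549)).float()}
	return basic_calculate_loss(preds, batch, pos_weight=None, loss_type='focal')
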
class CLIPMlp(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int, activation_cls):
super().__init__()
self.activation_fn = activation_cls()
self.fc1 = nn.Linear(hidden_size, intermediate_size)
self.fc2 = nn.Linear(intermediate_size, hidden_size)
def forward(self, hidden_states: torch.Tensor):
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class FastCLIPAttention2(nn.Module):
"""Fast Attention module for CLIP-like. This is NOT a drop-in replacement for CLIPAttention, since it adds additional flexibility. Mainly uses xformers."""
def __init__(self, hidden_size: int, out_dim: int, num_attention_heads: int, out_seq_len: Optional[int] = None, norm_qk: bool = False):
super().__init__()
self.out_seq_len = out_seq_len
self.embed_dim = hidden_size
self.out_dim = out_dim
self.norm_qk = norm_qk
self.num_heads = num_attention_heads
self.head_dim = hidden_size // num_attention_heads
assert self.head_dim * num_attention_heads == self.embed_dim, "embed_dim must be divisible by num_attention_heads"
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.kv_proj = nn.Linear(self.embed_dim, self.embed_dim * 2)
self.out_proj = nn.Linear(self.embed_dim, self.out_dim)
if self.norm_qk:
self.query_norm = nn.LayerNorm(self.embed_dim)
self.key_norm = nn.LayerNorm(self.embed_dim)
#def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
# return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()
def forward(self, query_states: torch.Tensor, kv_states: torch.Tensor) -> torch.Tensor:
bsz, src_len, embed_dim = kv_states.size()
if self.out_seq_len is not None:
tgt_len = self.out_seq_len
else:
tgt_len = src_len
kv_states = self.kv_proj(kv_states) # (bsz, src_len, embed_dim * 2)
q_states = self.q_proj(query_states[:, :tgt_len]) # (bsz, tgt_len, embed_dim)
# NOTE: It is not clear if LayerNorm should be applied to the embed_dim, or to the head_dim
if self.norm_qk:
q_states = self.query_norm(q_states).type(q_states.dtype)
k_states = self.key_norm(kv_states[:, :, :embed_dim]).type(kv_states.dtype)
v_states = kv_states[:, :, embed_dim:]
else:
k_states = kv_states[:, :, :embed_dim]
v_states = kv_states[:, :, embed_dim:]
q_states = q_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) # (bsz, num_heads, tgt_len, head_dim)
k_states = k_states.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2) # (bsz, num_heads, src_len, head_dim)
v_states = v_states.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2) # (bsz, num_heads, src_len, head_dim)
		# scaled_dot_product_attention handles query scaling, the attention matmul, and the softmax internally
with torch.backends.cuda.sdp_kernel(enable_math=False):
x = F.scaled_dot_product_attention(q_states, k_states, v_states) # (bsz, num_heads, tgt_len, head_dim)
x = x.transpose(1, 2).contiguous().view(bsz, tgt_len, embed_dim) # (bsz, tgt_len, embed_dim)
# Projection
x = self.out_proj(x) # (bsz, tgt_len, out_dim)
return x
class SkipInit(nn.Module):
def __init__(self, hidden_size: int, channel_wise: bool, init_scale: float):
super().__init__()
self.hidden_size = hidden_size
self.channel_wise = channel_wise
self.init_scale = init_scale
if self.channel_wise:
self.scale = nn.Parameter(torch.ones(hidden_size) * init_scale)
else:
self.scale = nn.Parameter(torch.tensor(init_scale))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * self.scale
class FastCLIPEncoderLayer(nn.Module):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
out_seq_len: Optional[int],
activation_cls = QuickGELUActivation,
use_palm_alt: bool = False,
norm_qk: bool = False,
skip_init: Optional[float] = None,
stochastic_depth: Optional[float] = None,
):
super().__init__()
self.use_palm_alt = use_palm_alt
self.stochastic_depth = stochastic_depth
self.self_attn = FastCLIPAttention2(
hidden_size=hidden_size,
out_dim=hidden_size,
num_attention_heads=num_attention_heads,
out_seq_len=out_seq_len,
norm_qk=norm_qk,
)
self.mlp = CLIPMlp(hidden_size, 4 * hidden_size, activation_cls)
self.layer_norm1 = nn.LayerNorm(hidden_size)
if not use_palm_alt:
self.layer_norm2 = nn.LayerNorm(hidden_size)
if skip_init is not None:
self.attn_skip_init = SkipInit(hidden_size, channel_wise=True, init_scale=skip_init)
self.mlp_skip_init = SkipInit(hidden_size, channel_wise=True, init_scale=skip_init)
else:
self.attn_skip_init = nn.Identity()
self.mlp_skip_init = nn.Identity()
def forward(self, hidden_states: torch.Tensor):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
if not self.use_palm_alt:
hidden_states = self.self_attn(query_states=hidden_states, kv_states=hidden_states)
hidden_states = self.attn_skip_init(hidden_states)
hidden_states = hidden_states + residual[:, :hidden_states.size(1)]
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.mlp_skip_init(hidden_states)
hidden_states = hidden_states + residual
else:
			# An alternative formulation inspired by the PaLM paper.
			# By computing the attention and MLP branches in parallel, their input linear projections could in principle be fused.
			# We don't fuse them here yet, but that reportedly improves efficiency without hurting performance.
attn = self.self_attn(query_states=hidden_states, kv_states=hidden_states)
attn = self.attn_skip_init(attn)
mlp = self.mlp(hidden_states[:, :attn.size(1)])
mlp = self.mlp_skip_init(mlp)
if self.stochastic_depth is not None:
attn = torchvision.ops.stochastic_depth(attn, self.stochastic_depth, mode='row', training=self.training)
mlp = torchvision.ops.stochastic_depth(mlp, self.stochastic_depth, mode='row', training=self.training)
hidden_states = residual[:, :attn.size(1)] + attn + mlp
return hidden_states
def sinusoidal_position_embedding(width: int, height: int, depth: int, dtype, device, temperature = 10000):
"""
Sinusoidal position embedding. Returns a flat tensor of shape (h * w, d).
"""
assert depth % 4 == 0, "Embedding dimension must be divisible by 4."
y, x = torch.meshgrid(torch.arange(height, device=device), torch.arange(width, device=device), indexing="ij")
omega = torch.arange(depth // 4, device=device) / (depth // 4 - 1)
omega = 1. / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
embedding = torch.cat([x.sin(), x.cos(), y.sin(), y.cos()], dim=1)
return embedding.type(dtype)
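
# Illustrative sketch: for a 448x448 image with 16x16 patches the grid is 28x28, so the call below
# returns a tensor of shape (28 * 28, 768) = (784, 768).
def _example_sine_embedding() -> torch.Tensor:
	return sinusoidal_position_embedding(width=28, height=28, depth=768, dtype=torch.float32, device='cpu')
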
class CLIPEmbeddingLayer(nn.Module):
def __init__(self, hidden_size: int, num_channels: int, image_size: int, patch_size: int, patch_dropout: float = 0.0, good_dropout: bool = False, dpn: bool = False, sine_positional_embeddings: bool = False):
super().__init__()
assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size."
seq_len = (image_size // patch_size) ** 2
self.patch_dropout = patch_dropout
self.hidden_size = hidden_size
self.good_dropout = good_dropout
self.dpn = dpn
self.sine_positional_embeddings = sine_positional_embeddings
self.patch_size = patch_size
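		# NOTE: patch_embeddings below is not used in forward(); the projection actually applied is to_patch_embeddings.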
self.patch_embeddings = nn.Conv2d(
in_channels=num_channels,
out_channels=hidden_size,
kernel_size=patch_size,
stride=patch_size,
bias=False,
)
if not self.sine_positional_embeddings:
self.positional_embeddings = nn.Embedding(seq_len, hidden_size)
self.register_buffer("position_ids", torch.arange(seq_len))
if self.dpn:
self.to_patch_embeddings = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
nn.LayerNorm(3 * patch_size * patch_size),
nn.Linear(3 * patch_size * patch_size, hidden_size),
nn.LayerNorm(hidden_size),
)
else:
self.to_patch_embeddings = nn.Conv2d(
in_channels=num_channels,
out_channels=hidden_size,
kernel_size=patch_size,
stride=patch_size,
bias=False,
)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
B, C, H, W = pixel_values.shape
assert H % self.patch_size == 0, f"Input image height ({H}) needs to be divisible by the patch size ({self.patch_size})."
assert W % self.patch_size == 0, f"Input image width ({W}) needs to be divisible by the patch size ({self.patch_size})."
if self.dpn:
patches = self.to_patch_embeddings(pixel_values)
else:
patches = self.to_patch_embeddings(pixel_values)
patches = patches.flatten(2).transpose(1, 2)
seq_len = patches.shape[1]
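		# Despite the name, patch_dropout here is the number of patches to KEEP after dropout.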
patch_dropout = int(math.ceil((1.0 - self.patch_dropout) * seq_len))
if self.sine_positional_embeddings:
position_embeddings = sinusoidal_position_embedding(W // self.patch_size, H // self.patch_size, self.hidden_size, pixel_values.dtype, pixel_values.device)
else:
position_embeddings = self.positional_embeddings(self.position_ids)
if patch_dropout == seq_len or not self.training:
embeddings = patches + position_embeddings
elif self.good_dropout:
			# Pick random patches to drop out.
			# The "good_dropout" variant draws an independent random permutation for each batch item; it is slightly slower and needs more code.
			# The method below is a nice trick for generating a batch of random permutations (see the standalone sketch after this class).
			# Torch (as of 1.13) doesn't have a built-in function for this, and a for loop over torch.randperm is slow.
			# In benchmarks, generating the mask and gathering was only about 50% slower than the non-"good_dropout" variant,
			# and the time spent here is a small fraction of the time spent on the embedding convolution.
# Generate a matrix of random numbers between 0 and 1 of shape (B, seq_len)
patch_mask = torch.rand(B, seq_len, device=patches.device)
# For each batch tensor, use argsort to convert the random numbers into a permutation of the patch indices
patch_mask = torch.argsort(patch_mask, dim=1)
# Truncate
patch_mask = patch_mask[:, :patch_dropout]
embeddings = patches.gather(1, patch_mask.unsqueeze(-1).expand(-1, -1, self.hidden_size)) + position_embeddings[patch_mask]
else:
# The non-"good_dropout" variant uses a single random permutation for all batch items, but is faster and uses less code
indices = torch.randperm(seq_len, device=pixel_values.device)[:patch_dropout]
embeddings = patches[:, indices, :] + position_embeddings[indices.expand(1, -1)]
return embeddings
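
# Illustrative sketch of the "good_dropout" permutation trick used above: argsort over a matrix of
# uniform random numbers yields an independent random permutation per row (batch item), and keeping
# the first k columns selects k random patches per image. The sizes below are arbitrary.
def _example_batched_permutations(batch_size: int = 2, seq_len: int = 8, keep: int = 5) -> torch.Tensor:
	patch_mask = torch.argsort(torch.rand(batch_size, seq_len), dim=1)  # one permutation per row
	return patch_mask[:, :keep]  # indices of the patches to keep
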
class MHAPoolingHead(nn.Module):
def __init__(self, hidden_size: int, num_attention_heads: int, activation_cls, out_dim: int, alt_style: bool, norm_qk: bool):
super().__init__()
		self.out_dim = out_dim if not alt_style else hidden_size
		self.probe = nn.Parameter(torch.randn(hidden_size))
		# NOTE: layer_norm and pooling_head are not used in forward(); removing them would change the state_dict, so they are left in place.
		self.layer_norm = nn.LayerNorm(hidden_size)
		self.pooling_head = nn.Linear(hidden_size, 1)
self.self_attn = FastCLIPAttention2(
hidden_size=hidden_size,
out_dim=self.out_dim,
num_attention_heads=num_attention_heads,
out_seq_len=1,
norm_qk=norm_qk,
)
self.mlp = CLIPMlp(self.out_dim, 4 * self.out_dim, activation_cls)
self.layer_norm1 = nn.LayerNorm(hidden_size)
self.layer_norm2 = nn.LayerNorm(self.out_dim)
if alt_style:
self.final_proj = nn.Linear(hidden_size, out_dim)
else:
self.final_proj = nn.Identity()
def forward(self, hidden_states: torch.Tensor):
hidden_states = self.layer_norm1(hidden_states)
query_states = self.probe.unsqueeze(0).unsqueeze(0).expand(hidden_states.size(0), 1, -1)
hidden_states = self.self_attn(query_states=query_states, kv_states=hidden_states)
		# No residual connection across the attention, since its output dimension (out_dim) can differ from hidden_size; the MLP below still uses one.
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = hidden_states + residual
hidden_states = self.final_proj(hidden_states)
return hidden_states.squeeze(1)
class GAPHead(nn.Module):
def __init__(self, hidden_size: int, out_dim: int):
super().__init__()
self.norm = nn.LayerNorm(hidden_size)
self.proj = nn.Linear(hidden_size, out_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.mean(dim=1)
x = self.norm(x)
x = self.proj(x)
return x
class CLIPLikeModel(VisionModel):
def __init__(
self,
n_tags: int,
embedding_dim: int,
num_attention_heads: int,
activation_cls,
num_channels: int,
image_size: int,
patch_size: int,
patch_dropout: float,
use_palm_alt: bool,
num_layers: int,
use_mha_alt: bool,
loss_type: str,
good_dropout: bool=False,
dpn: bool=False,
sine_positional_embeddings: bool=False,
norm_qk: bool = False,
no_wd_bias: bool = False,
use_gap_head: bool = False,
skip_init: Optional[float] = None,
stochastic_depth: Optional[float] = None,
):
super().__init__(image_size, n_tags)
out_dim = n_tags
self.n_tags = n_tags
self.loss_type = loss_type
self.no_wd_bias = no_wd_bias
stochastic_depth_space = torch.linspace(0, stochastic_depth, num_layers) if stochastic_depth is not None else None
self.embedding_layer = CLIPEmbeddingLayer(embedding_dim, num_channels, image_size, patch_size, patch_dropout, good_dropout, dpn, sine_positional_embeddings)
self.pre_layer_norm = nn.LayerNorm(embedding_dim)
self.encoder_layers = nn.ModuleList([FastCLIPEncoderLayer(
hidden_size=embedding_dim,
num_attention_heads=num_attention_heads,
out_seq_len=None,
activation_cls=activation_cls,
use_palm_alt=use_palm_alt,
norm_qk=norm_qk,
skip_init=skip_init,
stochastic_depth=stochastic_depth_space[i].item() if stochastic_depth_space is not None else None,
) for i in range(num_layers)])
if use_gap_head:
self.pooling_head = GAPHead(embedding_dim, out_dim)
else:
self.pooling_head = MHAPoolingHead(embedding_dim, num_attention_heads, activation_cls, out_dim, use_mha_alt, norm_qk=norm_qk)
def forward(self, batch):
hidden_states = self.embedding_layer(batch['image'])
hidden_states = self.pre_layer_norm(hidden_states)
for layer in self.encoder_layers:
hidden_states = layer(hidden_states)
preds = self.pooling_head(hidden_states)
result = {
'tags': preds,
}
return result
def calculate_loss(self, preds, batch, pos_weight):
return basic_calculate_loss(preds, batch, pos_weight, self.loss_type)
def get_optimized_parameters(self, lr: float):
if self.no_wd_bias:
return self.get_optimized_parameters_no_wd_bias()
else:
return self.parameters()
def get_optimized_parameters_no_wd_bias(self):
decay = []
no_decay = []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
if len(param.shape) == 1 or name.endswith(".bias"):
no_decay.append(param)
print(f'No decay: {name}')
else:
decay.append(param)
return [
{'params': decay},
{'params': no_decay, 'weight_decay': 0.},
]
def save(self):
return self.state_dict()
def load(self, state_dict):
self.load_state_dict(state_dict)
class MaskedAutoEncoderViT(nn.Module):
def __init__(
self,
n_tags: int,
embedding_dim: int,
num_attention_heads: int,
activation_cls,
num_channels: int,
image_size: int,
patch_size: int,
num_layers: int,
loss_type: str,
sine_positional_embeddings: bool=False,
decoder_embedding_dim: int = 512,
decoder_num_attention_heads: int = 8,
decoder_num_layers: int = 6,
decoder_force_projection: bool = False,
masking_ratio: float = 0.75,
mae_loss_weight: float = 1.0,
mae_normalize_targets: bool = False,
mae_post_norm: bool = False,
):
super().__init__()
self.n_tags = n_tags
self.seq_len = (image_size // patch_size) ** 2
self.embedding_dim = embedding_dim
self.decoder_embedding_dim = decoder_embedding_dim
self.sine_positional_embeddings = sine_positional_embeddings
self.image_size = image_size
self.patch_size = patch_size
self.masking_ratio = masking_ratio
self.loss_type = loss_type
self.mae_loss_weight = mae_loss_weight
self.mae_normalize_targets = mae_normalize_targets
if not self.sine_positional_embeddings:
self.positional_embeddings = nn.Embedding(self.seq_len, embedding_dim)
self.decoder_positional_embeddings = nn.Embedding(self.seq_len, decoder_embedding_dim)
self.register_buffer("position_ids", torch.arange(self.seq_len))
self.to_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
self.patch_embedder = nn.Linear(num_channels * patch_size * patch_size, embedding_dim)
# Encoder
self.pre_layer_norm = nn.LayerNorm(embedding_dim)
self.encoder_layers = nn.ModuleList([FastCLIPEncoderLayer(
hidden_size=embedding_dim,
num_attention_heads=num_attention_heads,
out_seq_len=None,
activation_cls=activation_cls,
use_palm_alt=True,
norm_qk=False,
skip_init=None,
) for _ in range(num_layers)])
# Head for classification
self.pooling_head = GAPHead(embedding_dim, n_tags)
# Decoder
if embedding_dim != decoder_embedding_dim or decoder_force_projection:
self.encoder_to_decoder_proj = nn.Linear(embedding_dim, decoder_embedding_dim)
else:
self.encoder_to_decoder_proj = nn.Identity()
self.decoder_pre_layer_norm = nn.LayerNorm(decoder_embedding_dim)
self.decoder_layers = nn.ModuleList([FastCLIPEncoderLayer(
hidden_size=decoder_embedding_dim,
num_attention_heads=decoder_num_attention_heads,
out_seq_len=None,
activation_cls=activation_cls,
use_palm_alt=True,
norm_qk=False,
skip_init=None,
) for _ in range(decoder_num_layers)])
if mae_post_norm:
self.decoder_to_pixel_values = nn.Sequential(
nn.LayerNorm(decoder_embedding_dim),
nn.Linear(decoder_embedding_dim, num_channels * patch_size * patch_size)
)
else:
self.decoder_to_pixel_values = nn.Linear(decoder_embedding_dim, num_channels * patch_size * patch_size)
self.mask_token = nn.Parameter(torch.zeros(decoder_embedding_dim))
torch.nn.init.normal_(self.mask_token, std=0.02)
def forward(self, batch):
pixel_values = batch['image']
device = pixel_values.device
B, C, H, W = pixel_values.shape
assert H % self.patch_size == 0, f"Input image height ({H}) needs to be divisible by the patch size ({self.patch_size})."
assert W % self.patch_size == 0, f"Input image width ({W}) needs to be divisible by the patch size ({self.patch_size})."
# Convert image to patches (B, seq_len, C * patch_size * patch_size)
patches = self.to_patches(pixel_values)
seq_len = patches.shape[1]
num_masked = int(self.masking_ratio * seq_len)
# For each batch tensor, use argsort to convert the random numbers into a permutation of the patch indices
# From this we can get the masked and unmasked indices
patch_mask = torch.rand(B, seq_len, device=device)
patch_mask = torch.argsort(patch_mask, dim=1)
masked_indices, unmasked_indices = patch_mask[:, :num_masked], patch_mask[:, num_masked:]
batch_range = torch.arange(B, device=device)[:, None]
# Masked and unmasked patches
unmasked_patches = patches[batch_range, unmasked_indices]
masked_patches = patches[batch_range, masked_indices]
# Embed unmasked patches for the encoder (B, seq_len, embedding_dim)
tokens = self.patch_embedder(unmasked_patches)
if self.sine_positional_embeddings:
position_embeddings = sinusoidal_position_embedding(W // self.patch_size, H // self.patch_size, self.embedding_dim, pixel_values.dtype, device)
decoder_position_embeddings = sinusoidal_position_embedding(W // self.patch_size, H // self.patch_size, self.decoder_embedding_dim, pixel_values.dtype, device)
else:
position_embeddings = self.positional_embeddings(self.position_ids)
decoder_position_embeddings = self.decoder_positional_embeddings(self.position_ids)
# Add position embeddings
tokens = tokens + position_embeddings[unmasked_indices]
# Run the encoder
encoded_tokens = self.pre_layer_norm(tokens)
for layer in self.encoder_layers:
encoded_tokens = layer(encoded_tokens)
# Label predictions
if self.training:
preds = self.pooling_head(encoded_tokens)
else:
# During inference, classify using the entire image
# But we'll do the usual for the MAE part, just so we can see how MAE is performing during validation
tokens = self.patch_embedder(patches)
tokens = tokens + position_embeddings
tokens = self.pre_layer_norm(tokens)
for layer in self.encoder_layers:
tokens = layer(tokens)
preds = self.pooling_head(tokens)
# Projection for the decoder and position embeddings
decoder_tokens = self.encoder_to_decoder_proj(encoded_tokens)
decoder_tokens = decoder_tokens + decoder_position_embeddings[unmasked_indices]
# Fill in the masked patches
mask_tokens = einops.repeat(self.mask_token, 'd -> b n d', b = B, n = num_masked)
mask_tokens = mask_tokens + decoder_position_embeddings[masked_indices]
decoder_tokens = torch.cat([decoder_tokens, mask_tokens], dim=1)
# Run the decoder
decoded_tokens = self.decoder_pre_layer_norm(decoder_tokens)
for layer in self.decoder_layers:
decoded_tokens = layer(decoded_tokens)
# Only predict the masked patches
# All the masked patches are at the end of the sequence
decoded_tokens = decoded_tokens[:, -num_masked:]
pred_pixel_values = self.decoder_to_pixel_values(decoded_tokens)
# Calculate the mae loss
if self.mae_normalize_targets:
# Normalize each patch by its mean and variance. The ViCHA paper says this provides better results
means = masked_patches.mean(dim=-1, keepdim=True)
vars = masked_patches.var(dim=-1, keepdim=True)
target = (masked_patches - means) / (vars + 1e-6)**0.5
mae_loss = F.mse_loss(pred_pixel_values, target)
else:
mae_loss = F.mse_loss(pred_pixel_values, masked_patches)
mae_loss = mae_loss * self.mae_loss_weight
return {
'tags': preds,
'mae_loss': mae_loss,
}
def calculate_loss(self, preds, batch, pos_weight):
return basic_calculate_loss(preds, batch, pos_weight, self.loss_type) + preds['mae_loss']
def get_optimized_parameters(self, lr: float):
return self.parameters()
def save(self):
return self.state_dict()
def load(self, state_dict):
self.load_state_dict(state_dict)
class StochDepth(nn.Module):
def __init__(self, drop_rate: float, scale_by_keep: bool = False):
super().__init__()
self.drop_rate = drop_rate
self.scale_by_keep = scale_by_keep
def forward(self, x):
if not self.training:
return x
batch_size = x.shape[0]
r = torch.rand((batch_size, 1, 1), device=x.device)
keep_prob = 1 - self.drop_rate
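		# floor(keep_prob + U[0, 1)) equals 1 with probability keep_prob and 0 otherwise,
		# so each batch row is kept independently with probability keep_prob.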
binary_tensor = torch.floor(keep_prob + r)
if self.scale_by_keep:
x = x / keep_prob
return x * binary_tensor
class SkipInitChannelwise(nn.Module):
def __init__(self, channels, init_val=1e-6):
super().__init__()
self.channels = channels
self.init_val = init_val
self.skip = nn.Parameter(torch.ones(channels) * init_val)
def forward(self, x):
return x * self.skip
class PosEmbedding(nn.Module):
def __init__(self, d_model: int, max_len: int, use_sine: bool, patch_size: int):
super().__init__()
self.d_model = d_model
self.max_len = max_len
self.use_sine = use_sine
self.patch_size = patch_size
if not self.use_sine:
self.embedding = nn.Embedding(max_len, d_model)
nn.init.trunc_normal_(self.embedding.weight, std=0.02)
self.register_buffer("position_ids", torch.arange(max_len))
def forward(self, x, width: int, height: int):
if self.use_sine:
position_embeddings = sinusoidal_position_embedding(width // self.patch_size, height // self.patch_size, self.d_model, x.dtype, x.device)
else:
position_embeddings = self.embedding(self.position_ids)
return x + position_embeddings
class MLPBlock(nn.Module):
def __init__(self, d_model: int, d_ff: int, stochdepth_rate: float):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.activation = nn.GELU()
if stochdepth_rate > 0:
self.stochdepth = StochDepth(stochdepth_rate, scale_by_keep=True)
else:
self.stochdepth = None
def forward(self, x):
x = self.linear1(x)
x = self.activation(x)
if self.stochdepth is not None:
x = self.stochdepth(x)
x = self.linear2(x)
return x
class ViTBlock(nn.Module):
def __init__(self, num_heads: int, d_model: int, d_ff: int, layerscale_init: float, stochdepth_rate: float):
super().__init__()
self.num_heads = num_heads
self.d_model = d_model
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# MHA
self.norm1 = nn.LayerNorm(d_model)
self.qkv_proj = nn.Linear(d_model, d_model * 3)
self.out_proj = nn.Linear(d_model, d_model)
self.skip_init1 = SkipInitChannelwise(channels=d_model, init_val=layerscale_init)
self.stochdepth1 = StochDepth(stochdepth_rate, scale_by_keep=True) if stochdepth_rate > 0 else None
# MLP
self.norm2 = nn.LayerNorm(d_model)
self.mlp = MLPBlock(d_model, d_ff, stochdepth_rate)
self.skip_init2 = SkipInitChannelwise(channels=d_model, init_val=layerscale_init)
self.stochdepth2 = StochDepth(stochdepth_rate, scale_by_keep=True) if stochdepth_rate > 0 else None
def forward(self, x):
bsz, src_len, embed_dim = x.shape
out = x
out = self.norm1(out)
# MHA
qkv_states = self.qkv_proj(out).split(self.d_model, dim=-1)
q_states = qkv_states[0].view(bsz, src_len, self.num_heads, embed_dim // self.num_heads).transpose(1, 2) # (bsz, num_heads, src_len, embed_dim // num_heads)
k_states = qkv_states[1].view(bsz, src_len, self.num_heads, embed_dim // self.num_heads).transpose(1, 2) # (bsz, num_heads, src_len, embed_dim // num_heads)
v_states = qkv_states[2].view(bsz, src_len, self.num_heads, embed_dim // self.num_heads).transpose(1, 2) # (bsz, num_heads, src_len, embed_dim // num_heads)
with torch.backends.cuda.sdp_kernel(enable_math=False):
out = F.scaled_dot_product_attention(q_states, k_states, v_states) # (bsz, num_heads, tgt_len, head_dim)
out = out.transpose(1, 2).contiguous().view(bsz, src_len, embed_dim) # (bsz, tgt_len, embed_dim)
out = self.out_proj(out)
out = self.skip_init1(out)
if self.stochdepth1 is not None:
out = self.stochdepth1(out)
x = out + x
out = self.norm2(x)
out = self.mlp(out)
out = self.skip_init2(out)
if self.stochdepth2 is not None:
out = self.stochdepth2(out)
out = out + x
return out
def CaiT_LayerScale_init(network_depth):
if network_depth <= 18:
return 1e-1
elif network_depth <= 24:
return 1e-5
else:
return 1e-6
class CNNLayerNorm(nn.Module):
def __init__(self, d_model: int):
super().__init__()
self.norm = nn.LayerNorm(d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.transpose(1, 3)
x = self.norm(x)
x = x.transpose(1, 3)
return x
class CNNStem(nn.Module):
def __init__(self, config: str):
super().__init__()
self.config = config
layers = []
channels = 3
for line in config.split(";"):
ty, line = line.split(":") if ":" in line else (line, "")
options = line.split(",")
options = [o.split("=") for o in options] if line else []
options = {k: v for k, v in options}
if ty == 'conv':
layers.append(nn.Conv2d(
in_channels=channels,
out_channels=int(options['c']),
kernel_size=int(options['k'] if 'k' in options else 3),
stride=int(options['s'] if 's' in options else 2),
bias=True,
padding=int(options['p'] if 'p' in options else 1),
))
channels = int(options['c'])
elif ty == 'bn':
layers.append(nn.BatchNorm2d(channels))
elif ty == 'ln':
layers.append(CNNLayerNorm(channels))
elif ty == 'relu':
layers.append(nn.ReLU())
elif ty == 'gelu':
layers.append(nn.GELU())
self.conv = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
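
# Illustrative sketch: how a cnn_stem config string from MODEL_CONFIGS expands. Each ';'-separated
# token adds a layer; 'conv' defaults to kernel_size=3, stride=2, padding=1 unless overridden.
# In the example below (the SWModel20 stem) the four stride-2 convs downsample by 16x, matching the
# 16-pixel patch size, and the final 1x1 conv maps the channels up to the transformer width.
def _example_cnn_stem() -> CNNStem:
	return CNNStem('conv:c=64;ln;relu;conv:c=128;ln;relu;conv:c=256;ln;relu;conv:c=512;ln;relu;conv:c=768,s=1,k=1,p=0')
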
class ViT(VisionModel):
def __init__(self,
n_tags: int,
image_size: int,
num_blocks: int,
patch_size: int,
d_model: int,
mlp_dim: int,
num_heads: int,
stochdepth_rate: float,
use_sine: bool,
loss_type: str,
layerscale_init: Optional[float] = None,
head_mean_after: bool = False,
cnn_stem: str | None = None,
patch_dropout: float = 0.0,
):
super().__init__(image_size, n_tags)
#assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
out_dim = n_tags
self.n_tags = n_tags
self.loss_type = loss_type
self.patch_size = patch_size
self.head_mean_after = head_mean_after
self.patch_dropout = patch_dropout
layerscale_init = CaiT_LayerScale_init(num_blocks) if layerscale_init is None else layerscale_init
self.patch_embeddings = nn.Conv2d(
in_channels=3,
out_channels=d_model,
kernel_size=patch_size,
stride=patch_size,
bias=True,
) if cnn_stem is None else CNNStem(cnn_stem)
self.pos_embedding = PosEmbedding(d_model, (image_size // patch_size) ** 2, use_sine=use_sine, patch_size=patch_size)
self.blocks = nn.ModuleList([
ViTBlock(num_heads, d_model, mlp_dim, layerscale_init, stochdepth_rate)
for _ in range(num_blocks)
])
self.norm = nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, out_dim)
def forward(self, batch, return_embeddings=False, return_loss: bool = False, pos_weight = None):
B, C, H, W = batch['image'].shape
assert H % self.patch_size == 0, f"Input image height ({H}) needs to be divisible by the patch size ({self.patch_size})."
assert W % self.patch_size == 0, f"Input image width ({W}) needs to be divisible by the patch size ({self.patch_size})."
x = self.patch_embeddings(batch['image']) # (bsz, d_model, patch_num, patch_num)
x = x.flatten(2).transpose(1, 2) # (bsz, patch_num ** 2, d_model)
x = self.pos_embedding(x, W, H) # (bsz, patch_num ** 2, d_model)
# Patch dropout
seq_len = x.shape[1]
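		# As in CLIPEmbeddingLayer, patch_dropout counts the patches that are kept, not dropped.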
patch_dropout = int(math.ceil((1.0 - self.patch_dropout) * seq_len))
if patch_dropout != seq_len:
# Generate a matrix of random numbers between 0 and 1 of shape (B, seq_len)
patch_mask = torch.rand(B, seq_len, device=x.device)
# For each batch tensor, use argsort to convert the random numbers into a permutation of the patch indices
patch_mask = torch.argsort(patch_mask, dim=1)
# Truncate
patch_mask = patch_mask[:, :patch_dropout]
x = x.gather(1, patch_mask.unsqueeze(-1).expand(-1, -1, x.shape[-1]))
#indices = torch.randperm(seq_len, device=x.device)[:patch_dropout]
#x = x[:, indices, :]
# Transformer
for block in self.blocks:
x = block(x)
# Head
result = {}
x = self.norm(x)
if self.head_mean_after:
x = self.head(x)
x = x.mean(dim=1)
else:
x = x.mean(dim=1)
if return_embeddings:
result['embeddings'] = x
x = self.head(x)
result['tags'] = x
if return_loss:
result['loss'] = self.calculate_loss(result, batch, pos_weight)
return result
def calculate_loss(self, preds, batch, pos_weight):
return basic_calculate_loss(preds, batch, pos_weight, self.loss_type)
def get_optimized_parameters(self, lr: float):
return self.parameters()
def save(self):
return self.state_dict()
def load(self, state_dict):
if 'head.weight' in state_dict and 'head.bias' in state_dict and state_dict['head.weight'].shape[0] == (self.n_tags + 9):
# Support old models which included 3 rating and 6 score dimensions
state_dict['head.weight'] = state_dict['head.weight'][:self.n_tags]
state_dict['head.bias'] = state_dict['head.bias'][:self.n_tags]
self.load_state_dict(state_dict)
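
# Illustrative sketch: running a forward pass with a small ViT built from scratch. The batch is a
# dict with an 'image' tensor whose height and width are divisible by the patch size, and the model
# returns raw logits under 'tags'. All sizes here are arbitrary example values; the attention blocks
# request non-math SDPA kernels, so this is typically run on a CUDA device.
def _example_forward(device: str = 'cuda') -> torch.Tensor:
	model = ViT(
		n_tags=100, image_size=224, num_blocks=2, patch_size=16, d_model=128,
		mlp_dim=512, num_heads=4, stochdepth_rate=0.0, use_sine=True, loss_type='ce',
	).to(device).eval()
	batch = {'image': torch.randn(1, 3, 224, 224, device=device)}
	with torch.no_grad():
		result = model(batch)
	return torch.sigmoid(result['tags'])  # (1, n_tags) per-tag probabilities
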