import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from detectron2.layers import Conv2d
import fvcore.nn.weight_init as weight_init
from typing import Any, Optional, Tuple, Type
from modeling.semantic_enhanced_matting.modeling.image_encoder import Attention
from modeling.semantic_enhanced_matting.modeling.transformer import Attention as DownAttention
from modeling.semantic_enhanced_matting.feature_fusion import PositionEmbeddingRandom as ImagePositionEmbedding
from modeling.semantic_enhanced_matting.modeling.common import MLPBlock


class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
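
# Usage sketch (illustrative values): LayerNorm2d normalizes an NCHW tensor over
# its channel dimension, e.g.
#   ln = LayerNorm2d(num_channels=1024)
#   y = ln(torch.randn(2, 1024, 64, 64))  # output keeps the [2, 1024, 64, 64] shape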


class ConditionConv(nn.Module):
    """
    The standard bottleneck residual block without the last activation layer.
    It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bottleneck_channels,
        norm=LayerNorm2d,
        act_layer=nn.GELU,
        conv_kernels=3,
        conv_paddings=1,
        condtition_channels=1024
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layer.
            norm (callable): normalization layer class applied after each conv layer.
            act_layer (callable): activation for all conv layers.
            condtition_channels (int): dimension of the condition vector.
        """
        super().__init__()
        self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.norm1 = norm(bottleneck_channels)
        self.act1 = act_layer()
        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            conv_kernels,
            padding=conv_paddings,
            bias=False,
        )
        self.norm2 = norm(bottleneck_channels)
        self.act2 = act_layer()
        self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.norm3 = norm(out_channels)
        self.init_weight()
        self.condition_embedding = nn.Sequential(
            act_layer(),
            nn.Linear(condtition_channels, bottleneck_channels, bias=True)
        )

    def init_weight(self):
        for layer in [self.conv1, self.conv2, self.conv3]:
            weight_init.c2_msra_fill(layer)
        for layer in [self.norm1, self.norm2]:
            layer.weight.data.fill_(1.0)
            layer.bias.data.zero_()
        # zero init last norm layer.
        self.norm3.weight.data.zero_()
        self.norm3.bias.data.zero_()

    # def embed_bbox_and_instance(self, bbox, instance):
    #     assert isinstance(instance, bool)

    def forward(self, x, condition):
        # x: [B, 64, 64, 1024]
        out = x.permute(0, 3, 1, 2)
        out = self.act1(self.norm1(self.conv1(out)))
        out = self.conv2(out) + self.condition_embedding(condition)[:, :, None, None]
        out = self.act2(self.norm2(out))
        out = self.norm3(self.conv3(out))
        out = x + out.permute(0, 2, 3, 1)
        return out
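
# Usage sketch (illustrative values; the residual add requires in_channels ==
# out_channels == the channel dim of the NHWC input):
#   block = ConditionConv(in_channels=1024, out_channels=1024, bottleneck_channels=128)
#   feat = torch.randn(2, 64, 64, 1024)        # NHWC feature map
#   cond = torch.randn(2, 1024)                # condtition_channels-dim condition vector
#   y = block(feat, cond)                      # [2, 64, 64, 1024]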


class ConditionAdd(nn.Module):
    def __init__(
        self,
        act_layer=nn.GELU,
        condtition_channels=1024
    ):
        super().__init__()
        self.condition_embedding = nn.Sequential(
            act_layer(),
            nn.Linear(condtition_channels, condtition_channels, bias=True)
        )

    def forward(self, x, condition):
        # x: [B, 64, 64, 1024]
        condition = self.condition_embedding(condition)[:, None, None, :]
        return x + condition
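
# Usage sketch (illustrative values): the condition vector is projected and
# broadcast-added over the spatial grid.
#   add = ConditionAdd(condtition_channels=1024)
#   y = add(torch.randn(2, 64, 64, 1024), torch.randn(2, 1024))  # [2, 64, 64, 1024]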


class ConditionEmbedding(nn.Module):
    def __init__(
        self,
        condition_num=5,
        pos_embedding_dim=128,
        embedding_scale=1.0,
        embedding_max_period=10000,
        embedding_flip_sin_to_cos=True,
        embedding_downscale_freq_shift=1.0,
        time_embed_dim=1024,
        split_embed=False
    ):
        super().__init__()
        self.condition_num = condition_num
        self.pos_embedding_dim = pos_embedding_dim
        self.embedding_scale = embedding_scale
        self.embedding_max_period = embedding_max_period
        self.embedding_flip_sin_to_cos = embedding_flip_sin_to_cos
        self.embedding_downscale_freq_shift = embedding_downscale_freq_shift
        self.split_embed = split_embed
        if self.split_embed:
            self.linear_1 = nn.Linear(pos_embedding_dim, time_embed_dim, True)
        else:
            self.linear_1 = nn.Linear(condition_num * pos_embedding_dim, time_embed_dim, True)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, True)

    def proj_embedding(self, condition):
        sample = self.linear_1(condition)
        sample = self.act(sample)
        sample = self.linear_2(sample)
        return sample

    def position_embedding(self, condition):
        # [B, 5] --> [B, 5, 128] --> [B, 5 * 128]
        assert condition.shape[-1] == self.condition_num
        half_dim = self.pos_embedding_dim // 2
        exponent = -math.log(self.embedding_max_period) * torch.arange(
            start=0, end=half_dim, dtype=torch.float32, device=condition.device
        )
        exponent = exponent / (half_dim - self.embedding_downscale_freq_shift)
        emb = torch.exp(exponent)
        emb = condition[:, :, None].float() * emb[None, None, :]  # [B, 5, 1] * [1, 1, 64] --> [B, 5, 64]
        # scale embeddings
        emb = self.embedding_scale * emb
        # concat sine and cosine embeddings
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)  # [B, 5, 64] --> [B, 5, 128]
        # flip sine and cosine embeddings
        if self.embedding_flip_sin_to_cos:
            emb = torch.cat([emb[:, :, half_dim:], emb[:, :, :half_dim]], dim=-1)
        # zero pad
        # if self.pos_embedding_dim % 2 == 1:
        #     emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
        if self.split_embed:
            emb = emb.reshape(-1, emb.shape[-1])
        else:
            emb = emb.reshape(emb.shape[0], -1)
        return emb

    def forward(self, condition):
        condition = self.position_embedding(condition)
        condition = self.proj_embedding(condition)
        return condition.float()
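
# Usage sketch (illustrative values): with the defaults, a [B, 5] condition is
# sinusoidally embedded to [B, 5 * 128] and then projected to [B, 1024].
#   emb = ConditionEmbedding()
#   y = emb(torch.tensor([[100., 200., 300., 400., 512.]]))  # [1, 5] --> [1, 1024]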


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.positional_encoding_gaussian_matrix = nn.Parameter(scale * torch.randn((2, num_pos_feats // 2)))
        # self.register_buffer(
        #     "positional_encoding_gaussian_matrix",
        #     scale * torch.randn((2, num_pos_feats)),
        # )
        point_embeddings = [nn.Embedding(1, num_pos_feats) for i in range(2)]
        self.point_embeddings = nn.ModuleList(point_embeddings)

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        coords = self._pe_encoding(coords.to(torch.float))  # B x N x C
        coords[:, 0, :] += self.point_embeddings[0].weight
        coords[:, 1, :] += self.point_embeddings[1].weight
        return coords
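
# Usage sketch (illustrative values): encodes box-corner coordinates given in
# pixels of a 1024x1024 image; the two learned point embeddings are added to the
# first two points, so N is expected to be 2 (the box corners).
#   pe = PositionEmbeddingRandom(num_pos_feats=64)
#   corners = torch.tensor([[[100., 200.], [300., 400.]]])   # [B, 2, 2]
#   tok = pe(corners, image_size=(1024, 1024))               # [B, 2, 64]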


class CrossSelfAttn(nn.Module):
    """
    Cross-attention from the flattened image features (queries) to the bbox tokens,
    followed by an MLP and self-attention over the concatenated feature and bbox tokens.
    """

    def __init__(self, embedding_dim=1024, num_heads=4, downsample_rate=4) -> None:
        super().__init__()
        self.cross_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.mlp = MLPBlock(embedding_dim, mlp_dim=512)
        self.norm2 = nn.LayerNorm(embedding_dim)
        self.self_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate)
        self.norm3 = nn.LayerNorm(embedding_dim)

    def forward(self, block_feat, bbox_token, feat_pe, bbox_pe):
        B, H, W, C = block_feat.shape
        block_feat = block_feat.reshape(B, H * W, C)
        block_feat = block_feat + self.cross_attn(q=block_feat + feat_pe, k=bbox_token + bbox_pe, v=bbox_token)
        block_feat = self.norm1(block_feat)
        block_feat = block_feat + self.mlp(block_feat)
        block_feat = self.norm2(block_feat)
        concat_token = torch.concat((block_feat + feat_pe, bbox_token + bbox_pe), dim=1)
        block_feat = block_feat + self.self_attn(q=concat_token, k=concat_token, v=concat_token)[:, :-bbox_token.shape[1]]
        block_feat = self.norm3(block_feat)
        output = block_feat.reshape(B, H, W, C)
        return output
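
# Usage sketch (illustrative values; zero positional embeddings are passed just
# to show the expected shapes):
#   attn = CrossSelfAttn(embedding_dim=1024, num_heads=4, downsample_rate=4)
#   feat = torch.randn(1, 64, 64, 1024)
#   bbox_token = torch.randn(1, 2, 1024)
#   y = attn(feat, bbox_token,
#            feat_pe=torch.zeros(1, 64 * 64, 1024),
#            bbox_pe=torch.zeros(1, 2, 1024))                # [1, 64, 64, 1024]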


class BBoxEmbedInteract(nn.Module):
    def __init__(
        self,
        embed_type='fourier',
        interact_type='attn',
        layer_num=3
    ):
        super().__init__()
        assert embed_type in {'fourier', 'position', 'conv'}
        assert interact_type in {'add', 'attn', 'cross-self-attn'}
        self.embed_type = embed_type
        self.interact_type = interact_type
        self.layer_num = layer_num

        if self.embed_type == 'fourier' and self.interact_type == 'add':
            self.embed_layer = ConditionEmbedding(condition_num=4, pos_embedding_dim=256)
        elif self.embed_type == 'fourier':
            self.embed_layer = ConditionEmbedding(condition_num=4, pos_embedding_dim=256, split_embed=True)
        elif self.embed_type == 'conv':
            mask_in_chans = 16
            activation = nn.GELU
            self.embed_layer = nn.Sequential(
                nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
                LayerNorm2d(mask_in_chans // 4),
                activation(),
                nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
                LayerNorm2d(mask_in_chans),
                activation(),
                nn.Conv2d(mask_in_chans, 1024, kernel_size=1),
            )
        else:
            if self.interact_type == 'add':
                self.embed_layer = PositionEmbeddingRandom(num_pos_feats=512)
            else:
                self.embed_layer = PositionEmbeddingRandom(num_pos_feats=1024)

        self.interact_layer = nn.ModuleList()
        for _ in range(self.layer_num):
            if self.interact_type == 'attn':
                self.interact_layer.append(Attention(dim=1024))
            elif self.interact_type == 'add' and self.embed_type != 'conv':
                self.interact_layer.append(nn.Sequential(
                    nn.GELU(),
                    nn.Linear(1024, 1024, bias=True)
                ))
            elif self.interact_type == 'cross-self-attn':
                self.interact_layer.append(CrossSelfAttn(embedding_dim=1024, num_heads=4, downsample_rate=4))

        self.position_layer = ImagePositionEmbedding(num_pos_feats=1024 // 2)

    def forward(self, block_feat, bbox, layer_index):
        # block_feat: [B, 64, 64, 1024], bbox: [B, 1, 4]
        if layer_index == self.layer_num:
            return block_feat
        interact_layer = self.interact_layer[layer_index]

        bbox = bbox + 0.5  # Shift to center of pixel
        if self.embed_type == 'fourier' and self.interact_type == 'add':
            embedding = self.embed_layer(bbox[:, 0])  # [B, 1, 4] --> reshape [B, 4] --> [B, 1024 * 1] --> reshape [B, 1, 1024]
            embedding = embedding.reshape(embedding.shape[0], 1, -1)
        elif self.embed_type == 'fourier':
            embedding = self.embed_layer(bbox[:, 0])  # [B, 1, 4] --> reshape [B, 4] --> [B, 1024 * 4] --> reshape [B, 4, 1024]
            embedding = embedding.reshape(-1, 4, embedding.shape[-1])
        elif self.embed_type == 'conv':
            # rasterize the bbox into a binary mask (1/4 input resolution) as the condition
            bbox_mask = torch.zeros(size=(block_feat.shape[0], 1, 256, 256), device=block_feat.device, dtype=block_feat.dtype)  # [B, 1, 256, 256]
            for i in range(bbox.shape[0]):
                l, u, r, d = bbox[i, 0, :] / 4
                bbox_mask[i, :, int(u + 0.5): int(d + 0.5), int(l + 0.5): int(r + 0.5)] = 1.0  # int(x + 0.5) = round(x)
            embedding = self.embed_layer(bbox_mask)  # [B, 1024, 64, 64]
        elif self.embed_type == 'position':
            embedding = self.embed_layer(bbox.reshape(-1, 2, 2), (1024, 1024))  # [B, 1, 4] --> reshape [B, 2, 2] --> [B, 2, 1024/512]
            if self.interact_type == 'add':
                embedding = embedding.reshape(embedding.shape[0], 1, -1)

        # add position embedding to block_feat
        pe = self.position_layer(size=(64, 64)).reshape(1, 64, 64, 1024)
        block_feat = block_feat + pe

        if self.interact_type == 'attn':
            add_token_num = embedding.shape[1]
            B, H, W, C = block_feat.shape
            block_feat = block_feat.reshape(B, H * W, C)
            concat_token = torch.concat((block_feat, embedding), dim=1)  # [B, 64 * 64 + 2, 1024]
            output_token = interact_layer.forward_token(concat_token)[:, :-add_token_num]
            output = output_token.reshape(B, H, W, C)
        elif self.embed_type == 'conv':
            output = block_feat + embedding.permute(0, 2, 3, 1)
        elif self.interact_type == 'add':
            output = interact_layer(embedding[:, None]) + block_feat
        elif self.interact_type == 'cross-self-attn':
            # CrossSelfAttn requires feat_pe / bbox_pe arguments; zeros are assumed here
            # since the position embedding was already added to block_feat above and the
            # bbox embedding is itself a positional encoding.
            output = interact_layer(block_feat, embedding, feat_pe=0, bbox_pe=0)
        return output
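
# Usage sketch (illustrative values; assumes ImagePositionEmbedding with
# num_pos_feats=512 returns a tensor reshapeable to [1, 64, 64, 1024], as the
# forward above expects). The module appears to be applied repeatedly with an
# increasing layer_index:
#   interact = BBoxEmbedInteract(embed_type='fourier', interact_type='add', layer_num=3)
#   feat = torch.randn(2, 64, 64, 1024)
#   bbox = torch.tensor([[[100., 200., 300., 400.]]] * 2)    # [2, 1, 4], xyxy in 1024x1024 coords
#   for idx in range(3):
#       feat = interact(feat, bbox, layer_index=idx)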


# reuse the position_point_embedding in prompt_encoder
class BBoxInteract(nn.Module):
    def __init__(
        self,
        position_point_embedding,
        point_weight,
        layer_num=3,
    ):
        super().__init__()
        self.position_point_embedding = position_point_embedding
        self.point_weight = point_weight
        for _, p in self.named_parameters():
            p.requires_grad = False
        self.layer_num = layer_num
        self.input_image_size = (1024, 1024)
        self.interact_layer = nn.ModuleList()
        for _ in range(self.layer_num):
            self.interact_layer.append(CrossSelfAttn(embedding_dim=1024, num_heads=4, downsample_rate=4))

    @torch.no_grad()
    def get_bbox_token(self, boxes):
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.position_point_embedding.forward_with_coords(coords, self.input_image_size)
        corner_embedding[:, 0, :] += self.point_weight[2].weight
        corner_embedding[:, 1, :] += self.point_weight[3].weight
        corner_embedding = F.interpolate(corner_embedding[..., None], size=(1024, 1), mode='bilinear', align_corners=False)[..., 0]
        return corner_embedding  # [B, 2, 1024]

    @torch.no_grad()
    def get_position_embedding(self, size=(64, 64)):
        pe = self.position_point_embedding(size=size)
        pe = F.interpolate(pe.permute(1, 2, 0)[..., None], size=(1024, 1), mode='bilinear', align_corners=False)[..., 0][None]
        pe = pe.reshape(1, -1, 1024)
        return pe  # [1, 64 * 64, 1024]

    def forward(self, block_feat, bbox, layer_index):
        # block_feat: [B, 64, 64, 1024], bbox: [B, 1, 4]
        if layer_index == self.layer_num:
            return block_feat
        interact_layer = self.interact_layer[layer_index]
        pe = self.get_position_embedding()
        bbox_token = self.get_bbox_token(bbox)
        output = interact_layer(block_feat, bbox_token, feat_pe=pe, bbox_pe=bbox_token)
        return output
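
# Usage sketch (hypothetical wiring): position_point_embedding and point_weight
# are expected to come from a SAM-style prompt encoder (e.g. its `pe_layer` and
# `point_embeddings` modules, where indices 2/3 hold the box-corner embeddings);
# the `prompt_encoder` attribute names are assumptions, not defined in this file.
#   interact = BBoxInteract(position_point_embedding=prompt_encoder.pe_layer,
#                           point_weight=prompt_encoder.point_embeddings)
#   feat = interact(torch.randn(1, 64, 64, 1024),
#                   bbox=torch.tensor([[[100., 200., 300., 400.]]]),
#                   layer_index=0)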


class InOutBBoxCrossSelfAttn(nn.Module):
    def __init__(self, embedding_dim=1024, num_heads=4, downsample_rate=4) -> None:
        super().__init__()
        self.self_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.mlp = MLPBlock(embedding_dim, mlp_dim=embedding_dim // 2)
        self.norm2 = nn.LayerNorm(embedding_dim)
        self.cross_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate)
        self.norm3 = nn.LayerNorm(embedding_dim)

    def forward(self, in_box_token, out_box_token):
        # self-attn
        short_cut = in_box_token
        in_box_token = self.norm1(in_box_token)
        in_box_token = self.self_attn(q=in_box_token, k=in_box_token, v=in_box_token)
        in_box_token = short_cut + in_box_token
        # mlp
        in_box_token = in_box_token + self.mlp(self.norm2(in_box_token))
        # cross-attn
        short_cut = in_box_token
        in_box_token = self.norm3(in_box_token)
        in_box_token = self.cross_attn(q=in_box_token, k=out_box_token, v=out_box_token)
        in_box_token = short_cut + in_box_token
        return in_box_token
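
# Usage sketch (illustrative token counts): tokens inside the box attend to each
# other, then cross-attend to the tokens outside the box.
#   attn = InOutBBoxCrossSelfAttn(embedding_dim=1024, num_heads=4, downsample_rate=4)
#   in_tok = torch.randn(1, 300, 1024)
#   out_tok = torch.randn(1, 64 * 64 - 300, 1024)
#   y = attn(in_tok, out_tok)                                # [1, 300, 1024]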


class BBoxInteractInOut(nn.Module):
    def __init__(
        self,
        num_heads=4,
        downsample_rate=4,
        layer_num=3,
    ):
        super().__init__()
        self.layer_num = layer_num
        self.input_image_size = (1024, 1024)
        self.interact_layer = nn.ModuleList()
        for _ in range(self.layer_num):
            self.interact_layer.append(InOutBBoxCrossSelfAttn(embedding_dim=1024, num_heads=num_heads, downsample_rate=downsample_rate))

    def forward(self, block_feat, bbox, layer_index):
        # block_feat: [B, 64, 64, 1024], bbox: [B, 1, 4]
        if layer_index == self.layer_num:
            return block_feat
        interact_layer = self.interact_layer[layer_index]
        # split tokens into in-bbox and out-of-bbox sets
        bbox = torch.round(bbox / self.input_image_size[0] * (block_feat.shape[1] - 1)).int()
        for i in range(block_feat.shape[0]):
            in_bbox_mask = torch.zeros((block_feat.shape[1], block_feat.shape[2]), dtype=bool, device=bbox.device)
            in_bbox_mask[bbox[i, 0, 1]: bbox[i, 0, 3], bbox[i, 0, 0]: bbox[i, 0, 2]] = True
            in_bbox_token = block_feat[i: i + 1, in_bbox_mask, :]
            out_bbox_token = block_feat[i: i + 1, ~in_bbox_mask, :]
            block_feat[i, in_bbox_mask, :] = interact_layer(in_bbox_token, out_bbox_token)
        return block_feat
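
# Usage sketch (illustrative values): the bbox is given in 1024x1024 pixel
# coordinates and mapped onto the 64x64 token grid; only in-box tokens are updated.
#   interact = BBoxInteractInOut(num_heads=4, downsample_rate=4, layer_num=3)
#   feat = torch.randn(1, 64, 64, 1024)
#   bbox = torch.tensor([[[100., 200., 300., 400.]]])        # [1, 1, 4]
#   feat = interact(feat, bbox, layer_index=0)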


if __name__ == '__main__':
    # embed = ConditionEmbedding()
    # input = torch.tensor([[100, 200, 300, 400, 512], [100, 200, 300, 400, 1024]])
    # print(input.shape)
    # output = embed(input)  # [B, 5] --> [B, 5 * 128] --> [B, 1024]

    embed = BBoxEmbedInteract(
        embed_type='position',
        interact_type='cross-self-attn'
    )
    input = torch.tensor([[[100, 200, 300, 400]], [[100, 200, 300, 400]]])  # [B, 1, 4]
    print(input.shape)
    output = embed(torch.randn((2, 64, 64, 1024)), input, layer_index=0)  # [2, 64, 64, 1024] --> [2, 64, 64, 1024]