SEMat / modeling /decoder /detail_capture.py
XiaRho's picture
Init
8b4c6c7 verified
raw
history blame
5.68 kB
import torch
from torch import nn
from torch.nn import functional as F
class Basic_Conv3x3(nn.Module):
    """
    Conv3x3 -> BatchNorm2d -> ReLU building block.

    Args:
        in_chans: number of input channels.
        out_chans: number of output channels.
        stride: stride of the 3x3 convolution (default 2).
        padding: zero padding applied by the convolution (default 1).
    """

    def __init__(
        self,
        in_chans,
        out_chans,
        stride=2,
        padding=1,
    ):
        super().__init__()
        # The convolution is created without a bias term; it is followed
        # directly by a batch norm layer.
        self.conv = nn.Conv2d(
            in_channels=in_chans,
            out_channels=out_chans,
            kernel_size=3,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(num_features=out_chans)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run ``x`` through the convolution, the batch norm and the ReLU, in that order."""
        return self.relu(self.bn(self.conv(x)))
class ConvStream(nn.Module):
    """
    Simple ConvStream: a chain of Basic_Conv3x3 layers that extracts detail features.

    Args:
        in_chans: number of channels of the tensor passed to ``forward``.
        out_chans: output channels of each successive conv layer. Any
            sequence of ints (list, tuple, ...) is accepted; the caller's
            object is never modified.

    Attributes:
        conv_chans: the list ``[in_chans, *out_chans]``.
        convs: ``nn.ModuleList`` holding one Basic_Conv3x3 per entry of
            ``out_chans`` (each built with Basic_Conv3x3's default stride).
    """

    def __init__(
        self,
        in_chans = 4,
        out_chans = [48, 96, 192],
    ):
        super().__init__()
        # Build a fresh list rather than calling ``out_chans.copy()``: this
        # works for tuples and other sequences too, and still guarantees the
        # (shared) default list is never mutated.
        self.conv_chans = [in_chans, *out_chans]
        # Consecutive pairs of conv_chans give the (in, out) channels of each layer.
        self.convs = nn.ModuleList([
            Basic_Conv3x3(c_in, c_out)
            for c_in, c_out in zip(self.conv_chans[:-1], self.conv_chans[1:])
        ])

    def forward(self, x):
        """
        Apply the conv layers one after another and collect every stage.

        Returns:
            dict mapping ``'D0'`` to the unmodified input and ``'D<k>'`` to
            the output of the k-th conv layer (1-based).
        """
        out_dict = {'D0': x}
        for level, conv in enumerate(self.convs, start=1):
            x = conv(x)
            out_dict[f'D{level}'] = x
        return out_dict
class Fusion_Block(nn.Module):
    """
    Fuse an upsampled feature map with a detail feature map.

    Args:
        in_chans: channel count of the concatenated tensor, i.e. the
            channels of ``D`` plus the channels of ``x``.
        out_chans: channel count produced by the block.
    """

    def __init__(
        self,
        in_chans,
        out_chans,
    ):
        super().__init__()
        # Stride 1 with padding 1 on a 3x3 kernel keeps the spatial size.
        self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1)

    def forward(self, x, D):
        """
        Upsample ``x`` by a factor of 2 (bilinear), concatenate it after ``D``
        along the channel dimension and apply the conv block.

        ``D`` therefore has to have twice the spatial size of ``x``.
        """
        upsampled = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        fused = torch.cat((D, upsampled), dim=1)
        return self.conv(fused)
class Matting_Head(nn.Module):
    """
    Small prediction head: Conv3x3 -> BatchNorm2d -> ReLU -> Conv1x1.

    Args:
        in_chans: channel count of the input feature map.
        mid_chans: channel count between the two convolutions.

    The head outputs a single-channel map and applies no activation to it.
    """

    def __init__(
        self,
        in_chans = 32,
        mid_chans = 16,
    ):
        super().__init__()
        # Layer order (and thus the Sequential indices 0..3) is kept fixed so
        # parameter names in the state dict stay the same.
        layers = [
            nn.Conv2d(in_chans, mid_chans, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(mid_chans),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_chans, 1, kernel_size=1, stride=1, padding=0),
        ]
        self.matting_convs = nn.Sequential(*layers)

    def forward(self, x):
        """Return the single-channel output of the head for ``x``."""
        return self.matting_convs(x)
class Detail_Capture(nn.Module):
    """
    Simple and Lightweight Detail Capture Module for ViT Matting.

    A ConvStream turns ``images`` into detail features ``D0 .. Dn``. A chain
    of Fusion_Blocks then repeatedly upsamples the running feature map by 2
    and fuses it with ``Dn, ..., D0`` (coarsest first). A Matting_Head
    followed by a sigmoid produces the final prediction.

    Args:
        in_chans: two-element sequence. ``in_chans[0]`` is the channel count
            of ``features`` passed to ``forward``; ``in_chans[1]`` is the
            feature channel count the fusion block at index 2 is sized for.
        img_chans: channel count of ``images``.
        convstream_out: output channels of the ConvStream layers.
        fusion_out: output channels of the fusion blocks; must have exactly
            one more entry than ``convstream_out``.

    All sequence arguments may be lists or tuples and are never modified.
    """

    def __init__(
        self,
        in_chans = [384, 1],
        img_chans=4,
        convstream_out = [48, 96, 192],
        fusion_out = [256, 128, 64, 32],
    ):
        super().__init__()
        assert len(fusion_out) == len(convstream_out) + 1, \
            'fusion_out must have exactly one more entry than convstream_out'

        # Hand ConvStream a real list so tuple configs work regardless of how
        # ConvStream copies its argument.
        self.convstream = ConvStream(in_chans=img_chans, out_chans=list(convstream_out))
        self.conv_chans = self.convstream.conv_chans  # defaults: [4, 48, 96, 192]

        # Fresh list instead of ``fusion_out.copy()`` so any sequence is accepted.
        self.fus_channs = [in_chans[0], *fusion_out]  # defaults: [384, 256, 128, 64, 32]

        self.fusion_blks = nn.ModuleList()
        for i, out_channels in enumerate(self.fus_channs[1:]):
            # Block 2 is sized from in_chans[1]; every other block is sized
            # from the width of the feature map produced by the previous stage.
            # NOTE(review): forward() feeds block 2 the output of block 1
            # (fus_channs[2] channels), so forward() only runs when
            # in_chans[1] == fusion_out[1]. The defaults (1 vs 128) do not
            # satisfy this -- confirm the intended wiring before relying on
            # the default configuration.
            feat_chans = in_chans[1] if i == 2 else self.fus_channs[i]
            # Detail features are consumed coarsest-first, hence the negative index.
            detail_chans = self.conv_chans[-(i + 1)]
            # defaults: 384 + 192 = 576, 256 + 96 = 352, 1 + 48 = 49, 64 + 4 = 68
            self.fusion_blks.append(
                Fusion_Block(
                    in_chans = feat_chans + detail_chans,
                    out_chans = out_channels,  # defaults: 256, 128, 64, 32
                )
            )

        self.matting_head = Matting_Head(  # defaults: 32 --> 1
            in_chans = fusion_out[-1],
        )

    def forward(self, features, images):
        """
        Fuse ``features`` with the detail features of ``images``.

        Returns:
            dict with the single key ``'phas'``: the sigmoid of the matting
            head output.
        """
        # With the default ConvStream, e.g. a [1, 4, 672, 992] input gives
        # D0: [1, 4, 672, 992], D1: [1, 48, 336, 496],
        # D2: [1, 96, 168, 248], D3: [1, 192, 84, 124].
        detail_features = self.convstream(images)
        num_blks = len(self.fusion_blks)
        for i, blk in enumerate(self.fusion_blks):
            # Block i pairs with detail level num_blks - 1 - i (coarsest first, D0 last).
            features = blk(features, detail_features[f'D{num_blks - i - 1}'])
        phas = torch.sigmoid(self.matting_head(features))
        return {'phas': phas}
class Ori_Detail_Capture(nn.Module):
    """
    Simple and Lightweight Detail Capture Module for ViT Matting.

    A ConvStream turns ``images`` into detail features ``D0 .. Dn``. A chain
    of Fusion_Blocks then repeatedly upsamples the running feature map by 2
    and fuses it with ``Dn, ..., D0`` (coarsest first). A Matting_Head
    followed by a sigmoid produces the final prediction.

    Args:
        in_chans: channel count of ``features`` passed to ``forward``.
        img_chans: channel count of ``images``.
        convstream_out: output channels of the ConvStream layers.
        fusion_out: output channels of the fusion blocks; must have exactly
            one more entry than ``convstream_out``.

    ``convstream_out`` and ``fusion_out`` may be lists or tuples and are
    never modified.
    """

    def __init__(
        self,
        in_chans = 384,
        img_chans=4,
        convstream_out = [48, 96, 192],
        fusion_out = [256, 128, 64, 32],
    ):
        super().__init__()
        assert len(fusion_out) == len(convstream_out) + 1, \
            'fusion_out must have exactly one more entry than convstream_out'

        # ``convstream_out`` used to be validated above but never forwarded,
        # so a non-default value was silently ignored. It is now passed on
        # (as a real list, so tuple configs work); with the default value the
        # resulting layers are unchanged.
        self.convstream = ConvStream(in_chans = img_chans, out_chans = list(convstream_out))
        self.conv_chans = self.convstream.conv_chans

        # Fresh list instead of ``fusion_out.copy()`` so any sequence is accepted.
        self.fus_channs = [in_chans, *fusion_out]

        self.fusion_blks = nn.ModuleList()
        for i, out_channels in enumerate(self.fus_channs[1:]):
            # Each block receives the previous stage's feature map plus one
            # detail level, taken coarsest-first (hence the negative index).
            self.fusion_blks.append(
                Fusion_Block(
                    in_chans = self.fus_channs[i] + self.conv_chans[-(i + 1)],
                    out_chans = out_channels,
                )
            )

        self.matting_head = Matting_Head(
            in_chans = fusion_out[-1],
        )

    def forward(self, features, images):
        """
        Fuse ``features`` with the detail features of ``images``.

        Returns:
            dict with the single key ``'phas'``: the sigmoid of the matting
            head output.
        """
        detail_features = self.convstream(images)
        num_blks = len(self.fusion_blks)
        for i, blk in enumerate(self.fusion_blks):
            # Block i pairs with detail level num_blks - 1 - i (coarsest first, D0 last).
            features = blk(features, detail_features[f'D{num_blks - i - 1}'])
        phas = torch.sigmoid(self.matting_head(features))
        return {'phas': phas}