import torch from torch import nn from torch.nn import functional as F class Basic_Conv3x3(nn.Module): """ Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers. """ def __init__( self, in_chans, out_chans, stride=2, padding=1, ): super().__init__() self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False) self.bn = nn.BatchNorm2d(out_chans) self.relu = nn.ReLU(True) def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.relu(x) return x class ConvStream(nn.Module): """ Simple ConvStream containing a series of basic conv3x3 layers to extract detail features. """ def __init__( self, in_chans = 4, out_chans = [48, 96, 192], ): super().__init__() self.convs = nn.ModuleList() self.conv_chans = out_chans.copy() self.conv_chans.insert(0, in_chans) for i in range(len(self.conv_chans)-1): in_chan_ = self.conv_chans[i] out_chan_ = self.conv_chans[i+1] self.convs.append( Basic_Conv3x3(in_chan_, out_chan_) ) def forward(self, x): out_dict = {'D0': x} for i in range(len(self.convs)): x = self.convs[i](x) name_ = 'D'+str(i+1) out_dict[name_] = x return out_dict class Fusion_Block(nn.Module): """ Simple fusion block to fuse feature from ConvStream and Plain Vision Transformer. """ def __init__( self, in_chans, out_chans, ): super().__init__() self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1) def forward(self, x, D): F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) out = torch.cat([D, F_up], dim=1) out = self.conv(out) return out class Matting_Head(nn.Module): """ Simple Matting Head, containing only conv3x3 and conv1x1 layers. """ def __init__( self, in_chans = 32, mid_chans = 16, ): super().__init__() self.matting_convs = nn.Sequential( nn.Conv2d(in_chans, mid_chans, 3, 1, 1), nn.BatchNorm2d(mid_chans), nn.ReLU(True), nn.Conv2d(mid_chans, 1, 1, 1, 0) ) def forward(self, x): x = self.matting_convs(x) return x class Detail_Capture(nn.Module): """ Simple and Lightweight Detail Capture Module for ViT Matting. """ def __init__( self, in_chans = [384, 1], img_chans=4, convstream_out = [48, 96, 192], fusion_out = [256, 128, 64, 32], ): super().__init__() assert len(fusion_out) == len(convstream_out) + 1 self.convstream = ConvStream(in_chans=img_chans, out_chans=convstream_out) self.conv_chans = self.convstream.conv_chans # [4, 48, 96, 192] self.fusion_blks = nn.ModuleList() self.fus_channs = fusion_out.copy() self.fus_channs.insert(0, in_chans[0]) # [384, 256, 128, 64, 32] for i in range(len(self.fus_channs)-1): in_channels = self.fus_channs[i] + self.conv_chans[-(i+1)] if i != 2 else in_chans[1] + self.conv_chans[-(i+1)] # [256 + 192 = 448, 256 + 96 = 352, 128 + 48 = 176, 64 + 4 = 68] out_channels = self.fus_channs[i+1] # [256, 128, 64, 32] self.fusion_blks.append( Fusion_Block( in_chans = in_channels, out_chans = out_channels, ) ) self.matting_head = Matting_Head( # 32 --> 1 in_chans = fusion_out[-1], ) def forward(self, features, images): detail_features = self.convstream(images) # [1, 4, 672, 992] --> D0: [1, 4, 672, 992], D1: [1, 48, 336, 496], D2: [1, 96, 168, 248], D3: [1, 192, 84, 124] for i in range(len(self.fusion_blks)): # D3 d_name_ = 'D'+str(len(self.fusion_blks)-i-1) features = self.fusion_blks[i](features, detail_features[d_name_]) phas = torch.sigmoid(self.matting_head(features)) return {'phas': phas} class Ori_Detail_Capture(nn.Module): """ Simple and Lightweight Detail Capture Module for ViT Matting. """ def __init__( self, in_chans = 384, img_chans=4, convstream_out = [48, 96, 192], fusion_out = [256, 128, 64, 32], ): super().__init__() assert len(fusion_out) == len(convstream_out) + 1 self.convstream = ConvStream(in_chans = img_chans) self.conv_chans = self.convstream.conv_chans self.fusion_blks = nn.ModuleList() self.fus_channs = fusion_out.copy() self.fus_channs.insert(0, in_chans) for i in range(len(self.fus_channs)-1): self.fusion_blks.append( Fusion_Block( in_chans = self.fus_channs[i] + self.conv_chans[-(i+1)], out_chans = self.fus_channs[i+1], ) ) self.matting_head = Matting_Head( in_chans = fusion_out[-1], ) def forward(self, features, images): detail_features = self.convstream(images) for i in range(len(self.fusion_blks)): d_name_ = 'D'+str(len(self.fusion_blks)-i-1) features = self.fusion_blks[i](features, detail_features[d_name_]) phas = torch.sigmoid(self.matting_head(features)) return {'phas': phas}