"""PyTorch layer for extracting image features for the film_net interpolator.

The feature extractor implemented here converts an image pyramid into a
pyramid of deep features. The feature pyramid serves a similar purpose as
U-Net architecture's encoder, but we use a special cascaded architecture
described in Multi-view Image Fusion [1].

For completeness, below is a short description of the idea. While the
description is a bit involved, the cascaded feature pyramid can be used just
like any image feature pyramid.

Why cascaded architecture?
==========================
To understand the concept it is worth reviewing a traditional feature pyramid
first: *A traditional feature pyramid* as in U-Net or in many optical flow
networks is built by alternating between convolutions and pooling, starting
from the input image.

It is well known that early features of such an architecture correspond to
low-level concepts such as edges in the image, whereas later layers extract
semantically higher-level concepts such as object classes. In other words,
the meaning of the filters in each resolution level is different. For
problems such as semantic segmentation and many others this is a desirable
property.

However, the asymmetric features preclude sharing weights across resolution
levels in the feature extractor itself and in any subsequent neural networks
that follow. This can be a downside, since optical flow prediction, for
instance, is symmetric across resolution levels. The cascaded feature
architecture addresses this shortcoming.

How is it built?
================
The *cascaded* feature pyramid contains feature vectors that have constant
length and meaning on each resolution level, except for a few of the finest
ones. The advantage of this is that the subsequent optical flow layer can
learn synergistically from many resolutions. This means that coarse-level
prediction can benefit from finer-resolution training examples, which can be
useful with moderately sized datasets to avoid overfitting.

The cascaded feature pyramid is built by extracting shallower subtree
pyramids, each one of them similar to the traditional architecture. Each
subtree pyramid S_i is extracted starting from each resolution level:

    image resolution 0 -> S_0
    image resolution 1 -> S_1
    image resolution 2 -> S_2
    ...

If we denote the features at level j of subtree i as S_i_j, the cascaded
pyramid is constructed by concatenating features as follows (assuming
subtree depth=3):

    lvl
    feat_0 = concat(S_0_0)
    feat_1 = concat(S_1_0, S_0_1)
    feat_2 = concat(S_2_0, S_1_1, S_0_2)
    feat_3 = concat(S_3_0, S_2_1, S_1_2)
    feat_4 = concat(S_4_0, S_3_1, S_2_2)
    feat_5 = concat(S_5_0, S_4_1, S_3_2)
    ...

In the above, all levels except feat_0 and feat_1 have the same number of
features with similar semantic meaning. This enables training a single
optical flow predictor module shared by levels 2, 3, 4, 5, ...

For more details and evaluation see [1].

[1] Multi-view Image Fusion, Trinidad et al. 2019
"""
from typing import List

import torch
from torch import nn
from torch.nn import functional as F

from util import Conv2d


class SubTreeExtractor(nn.Module):
    """Extracts a hierarchical set of features from an image.

    This is a conventional, hierarchical image feature extractor that
    extracts [k, k*2, k*4, ...] filters for the image pyramid, where
    k = `channels`. Each level consists of two `Conv2d` layers and, except
    for the last requested level, is followed by 2x2 average pooling.
    """

    def __init__(self, in_channels=3, channels=64, n_layers=4):
        """Builds the per-level convolution stacks.

        Args:
          in_channels: Number of channels of the input image.
          channels: Number of filters k on the finest level; level i uses
            k * 2**i filters.
          n_layers: Maximum number of pyramid levels this extractor can produce.
        """
        super().__init__()
        convs = []
        for i in range(n_layers):
            out_channels = channels << i  # channels * 2**i
            # NOTE(review): the third Conv2d argument is presumably the kernel
            # size; Conv2d comes from the project-local `util` module.
            convs.append(nn.Sequential(
                Conv2d(in_channels, out_channels, 3),
                Conv2d(out_channels, out_channels, 3),
            ))
            in_channels = out_channels
        self.convs = nn.ModuleList(convs)

    def forward(self, image: torch.Tensor, n: int) -> List[torch.Tensor]:
        """Extracts a pyramid of features from the image.

        Args:
          image: Tensor laid out as BATCH_SIZE x CHANNELS x HEIGHT x WIDTH
            (avg_pool2d pools the last two dims; channels live on dim 1).
          n: Number of pyramid levels to extract. Must be less than or equal
            to `n_layers` given in __init__; only the first `n` conv stacks
            are evaluated.

        Returns:
          The pyramid of features (min(n, n_layers) entries), starting from
          the finest level. Each element contains the output after the last
          convolution on the corresponding pyramid level.
        """
        head = image
        pyramid: List[torch.Tensor] = []
        # Guard with `if` rather than `break`: TorchScript unrolls loops over
        # an nn.ModuleList and does not allow `break` inside them. Skipping the
        # levels beyond `n` avoids running the deeper (wider) conv stacks on
        # un-pooled features whose outputs would never be used.
        for i, layer in enumerate(self.convs):
            if i < n:
                head = layer(head)
                pyramid.append(head)
                # Downsample between levels, but not after the last one.
                if i < n - 1:
                    head = F.avg_pool2d(head, kernel_size=2, stride=2)
        return pyramid


class FeatureExtractor(nn.Module):
    """Extracts features from an image pyramid using a cascaded architecture."""

    def __init__(self, in_channels=3, channels=64, sub_levels=4):
        """Creates the shared subtree extractor.

        Args:
          in_channels: Number of channels of each input image.
          channels: Number of filters on the finest subtree level.
          sub_levels: Depth of each subtree pyramid.
        """
        super().__init__()
        # A single instance is shared by every pyramid level (weight sharing).
        self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels)
        self.sub_levels = sub_levels

    def forward(self, image_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
        """Extracts a cascaded feature pyramid.

        Args:
          image_pyramid: Image pyramid as a list, starting from the finest
            level. Because features from different subtrees are concatenated
            along dim 1, level i + 1 must match the spatial size of level i
            after one 2x2 average pooling (i.e. each level is half the
            resolution of the previous one).

        Returns:
          A pyramid of cascaded features, one tensor per input level
          (an empty list for an empty input).
        """
        num_levels = len(image_pyramid)

        sub_pyramids: List[List[torch.Tensor]] = []
        for i, image in enumerate(image_pyramid):
            # At each level of the image pyramid, create a sub-pyramid of
            # features with `sub_levels` levels, reusing the same
            # SubTreeExtractor so that the weights are shared.
            #
            # The depth is capped so we don't create features that lie beyond
            # the coarsest level of the cascaded pyramid we want to generate.
            capped_sub_levels = min(num_levels - i, self.sub_levels)
            sub_pyramids.append(self.extract_sublevels(image, capped_sub_levels))

        # Build the cascades of features on each level of the feature pyramid.
        # Assuming sub_levels=3, the layout follows the example in the module
        # documentation: level i gathers S_i_0, S_(i-1)_1, ..., S_(i-j)_j.
        feature_pyramid: List[torch.Tensor] = []
        for i in range(num_levels):
            parts: List[torch.Tensor] = [sub_pyramids[i][0]]
            for j in range(1, min(i + 1, self.sub_levels)):
                parts.append(sub_pyramids[i - j][j])
            # A single concatenation avoids re-copying the growing tensor, and
            # a lone part is returned as-is instead of being copied by cat.
            if len(parts) == 1:
                feature_pyramid.append(parts[0])
            else:
                feature_pyramid.append(torch.cat(parts, dim=1))
        return feature_pyramid