"""The film_net frame interpolator main model code. Basics ====== The film_net is an end-to-end learned neural frame interpolator implemented as a PyTorch model. It has the following inputs and outputs: Inputs: x0: image A. x1: image B. time: desired sub-frame time. Outputs: image: the predicted in-between image at the chosen time in range [0, 1]. Additional outputs include forward and backward warped image pyramids, flow pyramids, etc., that can be visualized for debugging and analysis. Note that many training sets only contain triplets with ground truth at time=0.5. If a model has been trained with such training set, it will only work well for synthesizing frames at time=0.5. Such models can only generate more in-between frames using recursion. Architecture ============ The inference consists of three main stages: 1) feature extraction 2) warping 3) fusion. On high-level, the architecture has similarities to Context-aware Synthesis for Video Frame Interpolation [1], but the exact architecture is closer to Multi-view Image Fusion [2] with some modifications for the frame interpolation use-case. Feature extraction stage employs the cascaded multi-scale architecture described in [2]. The advantage of this architecture is that coarse level flow prediction can be learned from finer resolution image samples. This is especially useful to avoid overfitting with moderately sized datasets. The warping stage uses a residual flow prediction idea that is similar to PWC-Net [3], Multi-view Image Fusion [2] and many others. The fusion stage is similar to U-Net's decoder where the skip connections are connected to warped image and feature pyramids. This is described in [2]. Implementation Conventions ==================== Pyramids -------- Throughtout the model, all image and feature pyramids are stored as python lists with finest level first followed by downscaled versions obtained by successively halving the resolution. The depths of all pyramids are determined by options.pyramid_levels. The only exception to this is internal to the feature extractor, where smaller feature pyramids are temporarily constructed with depth options.sub_levels. Color ranges & gamma -------------------- The model code makes no assumptions on whether the images are in gamma or linearized space or what is the range of RGB color values. So a model can be trained with different choices. This does not mean that all the choices lead to similar results. In practice the model has been proven to work well with RGB scale = [0,1] with gamma-space images (i.e. not linearized). 
[1] Context-aware Synthesis for Video Frame Interpolation, Niklaus and Liu,
    2018
[2] Multi-view Image Fusion, Trinidad et al., 2019
[3] PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume,
    Sun et al., 2018
"""
from typing import Dict, List

import torch
from torch import nn

import util
from feature_extractor import FeatureExtractor
from fusion import Fusion
from pyramid_flow_estimator import PyramidFlowEstimator


class Interpolator(nn.Module):
    def __init__(
        self,
        pyramid_levels=7,
        fusion_pyramid_levels=5,
        specialized_levels=3,
        sub_levels=4,
        filters=64,
        flow_convs=(3, 3, 3, 3),
        flow_filters=(32, 64, 128, 256),
    ):
        super().__init__()
        self.pyramid_levels = pyramid_levels
        self.fusion_pyramid_levels = fusion_pyramid_levels

        self.extract = FeatureExtractor(3, filters, sub_levels)
        self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters)
        self.fuse = Fusion(sub_levels, specialized_levels, filters)

    def shuffle_images(self, x0, x1):
        # Build a finest-first image pyramid for each input image.
        return [
            util.build_image_pyramid(x0, self.pyramid_levels),
            util.build_image_pyramid(x1, self.pyramid_levels)
        ]

    def debug_forward(self, x0, x1, batch_dt) -> Dict[str, List[torch.Tensor]]:
        image_pyramids = self.shuffle_images(x0, x1)

        # Siamese feature pyramids:
        feature_pyramids = [self.extract(image_pyramids[0]),
                            self.extract(image_pyramids[1])]

        # Predict forward flow.
        forward_residual_flow_pyramid = self.predict_flow(feature_pyramids[0],
                                                          feature_pyramids[1])

        # Predict backward flow.
        backward_residual_flow_pyramid = self.predict_flow(feature_pyramids[1],
                                                           feature_pyramids[0])

        # Synthesize the full flow pyramids from the residual flows. Note that
        # we keep up to 'fusion_pyramid_levels' levels as only those are used
        # by the fusion module.
        forward_flow_pyramid = util.flow_pyramid_synthesis(
            forward_residual_flow_pyramid)[:self.fusion_pyramid_levels]
        backward_flow_pyramid = util.flow_pyramid_synthesis(
            backward_residual_flow_pyramid)[:self.fusion_pyramid_levels]

        # We multiply the flows with t and 1-t to warp to the desired
        # fractional time.
        #
        # Note: In film_net we fix time to be 0.5, and recursively invoke the
        # interpolator for multi-frame interpolation. `batch_dt` is expected
        # to be a tensor of shape [B] holding the sub-frame time for each
        # batch element.
        backward_flow = util.multiply_pyramid(backward_flow_pyramid, batch_dt)
        forward_flow = util.multiply_pyramid(forward_flow_pyramid, 1 - batch_dt)

        # Concatenate the images and features that will be warped:
        pyramids_to_warp = [
            util.concatenate_pyramids(
                image_pyramids[0][:self.fusion_pyramid_levels],
                feature_pyramids[0][:self.fusion_pyramid_levels]),
            util.concatenate_pyramids(
                image_pyramids[1][:self.fusion_pyramid_levels],
                feature_pyramids[1][:self.fusion_pyramid_levels])
        ]

        # Warp features and images using the flow. Note that we use backward
        # warping, so the backward flow is used to read from image 0 and the
        # forward flow from image 1.
        forward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[0], backward_flow)
        backward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[1], forward_flow)

        aligned_pyramid = util.concatenate_pyramids(forward_warped_pyramid,
                                                    backward_warped_pyramid)
        aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, backward_flow)
        aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, forward_flow)

        return {
            'image': [self.fuse(aligned_pyramid)],
            'forward_residual_flow_pyramid': forward_residual_flow_pyramid,
            'backward_residual_flow_pyramid': backward_residual_flow_pyramid,
            'forward_flow_pyramid': forward_flow_pyramid,
            'backward_flow_pyramid': backward_flow_pyramid,
        }

    def forward(self, x0, x1, batch_dt) -> torch.Tensor:
        return self.debug_forward(x0, x1, batch_dt)['image'][0]
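
# -----------------------------------------------------------------------------
# For illustration: the warping step above relies on backward warping, i.e.
# each output pixel reads from the source image at a flow-displaced position.
# The sketch below shows one way a single pyramid level could be warped with
# torch.nn.functional.grid_sample. It is an illustrative stand-in, not the
# actual util.pyramid_warp implementation, and it assumes the flow is a
# [B, 2, H, W] tensor in pixel units with (x, y) channel order.

import torch.nn.functional as F  # used only by the example sketch below


def _example_backward_warp(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Samples `image` at positions displaced by `flow` (hypothetical helper)."""
    b, _, h, w = image.shape
    # Base sampling grid in pixel coordinates, (x, y) channel order.
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    grid = torch.stack((xs, ys), dim=0).to(image)      # [2, H, W]
    coords = grid.unsqueeze(0) + flow                  # [B, 2, H, W]
    # Normalize pixel coordinates to [-1, 1] as required by grid_sample.
    coords_x = 2.0 * coords[:, 0] / max(w - 1, 1) - 1.0
    coords_y = 2.0 * coords[:, 1] / max(h - 1, 1) - 1.0
    norm_grid = torch.stack((coords_x, coords_y), dim=-1)  # [B, H, W, 2]
    return F.grid_sample(image, norm_grid, align_corners=True)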
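
# -----------------------------------------------------------------------------
# Illustrative usage: as noted in the module docstring, a model trained only at
# time=0.5 can synthesize additional in-between frames through recursion. The
# helper below is a hypothetical sketch, not part of the film_net model. It
# assumes input height and width divisible by 2 ** (pyramid_levels - 1), and
# the demo runs with untrained (random) weights.


def _recursive_midpoints(model: Interpolator,
                         x0: torch.Tensor,
                         x1: torch.Tensor,
                         depth: int) -> List[torch.Tensor]:
    """Returns 2**depth - 1 in-between frames by repeated t=0.5 interpolation."""
    if depth == 0:
        return []
    # Interpolate the midpoint, then recurse into each half interval.
    batch_dt = x0.new_full((x0.shape[0],), 0.5)
    mid = model(x0, x1, batch_dt)
    return (_recursive_midpoints(model, x0, mid, depth - 1)
            + [mid]
            + _recursive_midpoints(model, mid, x1, depth - 1))


if __name__ == '__main__':
    model = Interpolator().eval()
    a = torch.rand(1, 3, 256, 448)  # image A: RGB in [0, 1], gamma space
    b = torch.rand(1, 3, 256, 448)  # image B
    with torch.no_grad():
        frames = _recursive_midpoints(model, a, b, depth=2)
    print([tuple(f.shape) for f in frames])  # 3 in-between frames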