"""Various utilities used in the film_net frame interpolator model."""
from typing import List, Optional, Tuple

try:
    import cv2
except ImportError:  # OpenCV is only needed by load_image().
    cv2 = None
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


def pad_batch(batch: np.ndarray, align: int) -> Tuple[np.ndarray, List[int]]:
    """Zero-pads a batch so its height and width are multiples of `align`.

    The padding is split between both sides of each axis; when the amount is
    odd, the extra row/column goes to the bottom/right.

    Args:
      batch: array whose axes 1 and 2 are height and width (4 axes in total).
      align: the multiple that height and width are rounded up to.

    Returns:
      A tuple `(padded_batch, crop_region)` where `crop_region` is
      `[top, left, bottom, right]`, the location of the original content
      inside the padded batch.
    """
    height, width = batch.shape[1:3]
    # (-n) % align is the distance from n up to the next multiple of align.
    height_to_pad = (-height) % align
    width_to_pad = (-width) % align

    top = height_to_pad >> 1
    left = width_to_pad >> 1
    crop_region = [top, left, height + top, width + left]

    batch = np.pad(
        batch,
        ((0, 0),
         (top, height_to_pad - top),
         (left, width_to_pad - left),
         (0, 0)),
        mode='constant')
    return batch, crop_region


def load_image(path: str, align: int = 64) -> Tuple[np.ndarray, List[int]]:
    """Loads an image as a padded float32 RGB batch of size one.

    Args:
      path: path of the image file, read with `cv2.imread`.
      align: alignment forwarded to `pad_batch`.

    Returns:
      The `(image_batch, crop_region)` tuple produced by `pad_batch`; pixel
      values are divided by 255.

    Raises:
      ImportError: if OpenCV (`cv2`) is not installed.
    """
    if cv2 is None:
        raise ImportError('load_image requires OpenCV (cv2) to be installed.')
    image = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    image = image.astype(np.float32) / np.float32(255)
    image_batch, crop_region = pad_batch(np.expand_dims(image, axis=0), align)
    return image_batch, crop_region


def build_image_pyramid(image: torch.Tensor,
                        pyramid_levels: int = 3) -> List[torch.Tensor]:
    """Builds an image pyramid from a given image.

    The original image is included in the pyramid and the rest are generated
    by successively halving the resolution with 2x2 average pooling.

    Args:
      image: the input image.
      pyramid_levels: number of levels in the returned pyramid.

    Returns:
      A list of images starting from the finest with `pyramid_levels` items.
    """
    pyramid = []
    for i in range(pyramid_levels):
        pyramid.append(image)
        # The coarsest level needs no further downsampling.
        if i < pyramid_levels - 1:
            image = F.avg_pool2d(image, 2, 2)
    return pyramid


def warp(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Backward warps the image using the given flow.

    Specifically, the output pixel in batch b, at position x, y will be
    computed as follows:
      (flowed_y, flowed_x) = (y + flow[b, 1, y, x], x + flow[b, 0, y, x])
      output[b, :, y, x] = bilinear_lookup(image, b, flowed_y, flowed_x)

    Note that the flow vectors are expected as [x, y], e.g. x in channel 0
    and y in channel 1. Lookups outside the image use the border pixels.

    Args:
      image: An image with shape BxCxHxW.
      flow: A flow with shape Bx2xHxW, with the two channels denoting the
        relative offset in order: (dx, dy).

    Returns:
      A warped image with the same shape as `image`.
    """
    # After this, channel 0 holds -dy and channel 1 holds -dx.
    flow = -flow.flip(1)

    dtype = flow.dtype
    device = flow.device

    # PyTorch equivalent of tfa_image.dense_image_warp(image, flow).
    # With align_corners=False the pixel centres sit at +-(1 - 1/size) in
    # grid_sample's normalized [-1, 1] coordinates.
    ls1 = 1 - 1 / flow.shape[3]
    ls2 = 1 - 1 / flow.shape[2]

    # Convert pixel offsets to normalized offsets (half the size maps to 1).
    normalized_flow2 = flow.permute(0, 2, 3, 1) / torch.tensor(
        [flow.shape[2] * .5, flow.shape[3] * .5],
        dtype=dtype, device=device)[None, None, None]
    # grid_sample expects the last axis ordered as (x, y).
    normalized_flow2 = torch.stack([
        torch.linspace(-ls1, ls1, flow.shape[3], dtype=dtype,
                       device=device)[None, None, :]
        - normalized_flow2[..., 1],
        torch.linspace(-ls2, ls2, flow.shape[2], dtype=dtype,
                       device=device)[None, :, None]
        - normalized_flow2[..., 0],
    ], dim=3)

    warped = F.grid_sample(image, normalized_flow2, mode='bilinear',
                           padding_mode='border', align_corners=False)
    return warped.reshape(image.shape)


def multiply_pyramid(pyramid: List[torch.Tensor],
                     scalar: torch.Tensor) -> List[torch.Tensor]:
    """Multiplies all image batches in the pyramid by a batch of scalars.

    Args:
      pyramid: Pyramid of image batches with shape BxCxHxW.
      scalar: Batch of scalars. A 1-D tensor of shape (B,) is applied per
        batch element; any other rank gets two trailing singleton axes
        appended and is broadcast as usual.

    Returns:
      An image pyramid with all images multiplied by the scalar.
    """
    if scalar.dim() == 1:
        # A (B,) tensor must line up with the batch axis. Appending only two
        # axes would give (B, 1, 1), which broadcasts against the channel
        # axis instead and fails (or scales the wrong axis) for B > 1.
        scalar = scalar.reshape(-1, 1, 1, 1)
    else:
        scalar = scalar[..., None, None]
    return [image * scalar for image in pyramid]


def flow_pyramid_synthesis(
        residual_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
    """Converts a residual flow pyramid into a flow pyramid.

    The last entry is used as the initial flow. Walking towards the first
    entry, the running flow is doubled, bilinearly resized to the spatial
    size of that level and added to the level's residual.

    Args:
      residual_pyramid: residual flows; the last item seeds the accumulation.

    Returns:
      A flow pyramid in the same order and with the same length as
      `residual_pyramid`.
    """
    flow = residual_pyramid[-1]
    flow_pyramid: List[torch.Tensor] = [flow]
    for residual_flow in residual_pyramid[:-1][::-1]:
        level_size = residual_flow.shape[2:4]
        # The flow is multiplied by 2 before being resized to this level.
        flow = F.interpolate(2 * flow, size=level_size, mode='bilinear')
        flow = residual_flow + flow
        flow_pyramid.insert(0, flow)
    return flow_pyramid


def pyramid_warp(feature_pyramid: List[torch.Tensor],
                 flow_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
    """Warps the feature pyramid using the flow pyramid.

    Args:
      feature_pyramid: feature pyramid starting from the finest level.
      flow_pyramid: flow fields, starting from the finest level.

    Returns:
      Reverse warped feature pyramid.
    """
    warped_feature_pyramid = []
    for features, flow in zip(feature_pyramid, flow_pyramid):
        warped_feature_pyramid.append(warp(features, flow))
    return warped_feature_pyramid


def concatenate_pyramids(pyramid1: List[torch.Tensor],
                         pyramid2: List[torch.Tensor]) -> List[torch.Tensor]:
    """Concatenates each pyramid level together in the channel dimension.

    Levels are paired with `zip`, so the result has as many levels as the
    shorter of the two pyramids.
    """
    result = []
    for features1, features2 in zip(pyramid1, pyramid2):
        result.append(torch.cat([features1, features2], dim=1))
    return result


class Conv2d(nn.Sequential):
    """A 2D convolution with an optional LeakyReLU(0.2) activation.

    The convolution is stored as item 0 of the sequential container. Odd
    kernel sizes use `padding='same'`. Even kernel sizes use no built-in
    padding; instead the input gets one extra column on the right and one
    extra row at the bottom before the convolution, which keeps the spatial
    size for a 2x2 kernel.

    Args:
      in_channels: number of input channels.
      out_channels: number of output channels.
      size: kernel size (a single integer).
      activation: `'relu'` to apply `nn.LeakyReLU(.2)` after the convolution,
        or `None` for no activation.
    """

    def __init__(self, in_channels, out_channels, size,
                 activation: Optional[str] = 'relu'):
        assert activation in (None, 'relu')
        super().__init__(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=size,
                padding='same' if size % 2 else 0)
        )
        self.size = size
        # Note: 'relu' selects a leaky ReLU with negative slope 0.2.
        self.activation = nn.LeakyReLU(.2) if activation == 'relu' else None

    def forward(self, x):
        """Applies the (padded) convolution followed by the activation."""
        if not self.size % 2:
            # Even kernel: pad right and bottom by one pixel (zeros).
            x = F.pad(x, (0, 1, 0, 1))
        y = self[0](x)
        if self.activation is not None:
            y = self.activation(y)
        return y