|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.autograd import Function
|
|
|
|
from ..utils import ext_loader
|
|
|
|
# Load the compiled ``_ext`` extension, passing the names of the two native
# entry points this module calls (``tin_shift_forward`` / ``tin_shift_backward``).
ext_module = ext_loader.load_ext(
    '_ext', ['tin_shift_forward', 'tin_shift_backward'])
|
|
|
|
|
|
class TINShiftFunction(Function):
    """Autograd function wrapping the compiled TIN-shift kernels.

    This class does three things and delegates the actual shifting to
    ``ext_module.tin_shift_forward`` / ``ext_module.tin_shift_backward``:

    * it validates the shapes of its inputs;
    * it allocates the zero-initialised buffers the kernels write into;
    * it wires the forward and backward passes together for autograd.
    """

    @staticmethod
    def forward(ctx, input, shift):
        """Run the temporal interlace shift kernel on ``input``.

        Args:
            ctx: Autograd context; ``shift`` is saved on it for backward.
            input (torch.Tensor): Feature map, expected shape
                ``[N, num_segments, C, H * W]``. Only dims 0 and 2 are
                inspected here.
            shift (torch.Tensor): Shift tensor. ``shift.size(0)`` must equal
                the batch size ``N``, and ``shift.size(1)`` (called
                ``num_segments`` below) must evenly divide ``C``.

        Returns:
            torch.Tensor: Tensor shaped like ``input``. It is
            zero-initialised and then filled by ``tin_shift_forward``.

        Raises:
            ValueError: If the batch sizes of ``input`` and ``shift`` differ,
                or if ``C`` is not a positive multiple of ``shift.size(1)``.
        """
        # Mismatched batch sizes would otherwise reach the native kernel
        # unchecked.
        if input.size(0) != shift.size(0):
            raise ValueError(
                'The first dim (batch) of `input` and `shift` should be '
                f'the same, but got {input.size(0)} and {shift.size(0)}.')

        C = input.size(2)
        num_segments = shift.size(1)
        # Test for zero first so an empty second dim of ``shift`` raises the
        # documented ValueError rather than ZeroDivisionError.
        if (num_segments == 0 or C // num_segments <= 0
                or C % num_segments != 0):
            raise ValueError('C should be a multiple of num_segments, '
                             f'but got C={C} and num_segments={num_segments}.')

        # NOTE(review): the extension receives the tensors directly. It
        # presumably indexes dense row-major buffers, so hand it contiguous
        # tensors (a no-op when they already are) -- confirm against the
        # kernel.
        input = input.contiguous()
        shift = shift.contiguous()
        ctx.save_for_backward(shift)

        # ``out`` is returned after the call, so the kernel is expected to
        # fill it in place.
        out = torch.zeros_like(input)
        ext_module.tin_shift_forward(input, shift, out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` back to ``input``.

        Args:
            ctx: Autograd context holding the ``shift`` saved in forward.
            grad_output (torch.Tensor): Gradient w.r.t. the forward output.

        Returns:
            tuple: ``(grad_input, grad_shift)``.

            * ``grad_input`` is filled in place by ``tin_shift_backward``.
            * ``grad_shift`` is always zeros, so the shift offsets receive
              no gradient.
        """
        shift, = ctx.saved_tensors
        # NOTE(review): the incoming gradient can be non-contiguous (e.g.
        # after a downstream permute). As in forward, give the kernel a dense
        # buffer -- confirm the kernel requires it.
        grad_output = grad_output.contiguous()
        data_grad_input = torch.zeros_like(grad_output)
        shift_grad_input = torch.zeros_like(shift)
        ext_module.tin_shift_backward(grad_output, shift, data_grad_input)
        return data_grad_input, shift_grad_input
|
|
|
|
|
|
# Functional form of the op: ``tin_shift(input, shift)`` is
# ``TINShiftFunction.apply(input, shift)``, so calls participate in autograd.
tin_shift = TINShiftFunction.apply
|
|
|
|
|
|
class TINShift(nn.Module):
    """Temporal Interlace Shift module.

    A differentiable, temporal-wise frame shifting operation proposed in
    "Temporal Interlacing Network"; see https://arxiv.org/abs/2001.06499
    for details. The code is adapted from
    https://github.com/mit-han-lab/temporal-shift-module.

    The module has no parameters or buffers of its own. ``forward`` simply
    calls the functional ``tin_shift`` op.
    """

    def forward(self, input, shift):
        """Apply the temporal interlace shift to ``input``.

        Args:
            input (torch.Tensor): Feature map with shape
                [N, num_segments, C, H * W].
            shift (torch.Tensor): Shift tensor with shape [N, num_segments].

        Returns:
            torch.Tensor: Feature map after temporal interlace shift, with
            the same shape as ``input``.

        Raises:
            ValueError: If ``C`` is not a positive multiple of
                ``shift.size(1)``.
        """
        shifted = tin_shift(input, shift)
        return shifted
|
|
|