|
|
|
from functools import partial
|
|
|
|
import torch
|
|
|
|
TORCH_VERSION = torch.__version__
|
|
|
|
|
|
def is_rocm_pytorch() -> bool:
|
|
is_rocm = False
|
|
if TORCH_VERSION != 'parrots':
|
|
try:
|
|
from torch.utils.cpp_extension import ROCM_HOME
|
|
is_rocm = True if ((torch.version.hip is not None) and
|
|
(ROCM_HOME is not None)) else False
|
|
except ImportError:
|
|
pass
|
|
return is_rocm
|
|
|
|
|
|
def _get_cuda_home():
|
|
if TORCH_VERSION == 'parrots':
|
|
from parrots.utils.build_extension import CUDA_HOME
|
|
else:
|
|
if is_rocm_pytorch():
|
|
from torch.utils.cpp_extension import ROCM_HOME
|
|
CUDA_HOME = ROCM_HOME
|
|
else:
|
|
from torch.utils.cpp_extension import CUDA_HOME
|
|
return CUDA_HOME
|
|
|
|
|
|
def get_build_config():
|
|
if TORCH_VERSION == 'parrots':
|
|
from parrots.config import get_build_info
|
|
return get_build_info()
|
|
else:
|
|
return torch.__config__.show()
|
|
|
|
|
|
def _get_conv():
|
|
if TORCH_VERSION == 'parrots':
|
|
from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
|
|
else:
|
|
from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
|
|
return _ConvNd, _ConvTransposeMixin
|
|
|
|
|
|
def _get_dataloader():
|
|
if TORCH_VERSION == 'parrots':
|
|
from torch.utils.data import DataLoader, PoolDataLoader
|
|
else:
|
|
from torch.utils.data import DataLoader
|
|
PoolDataLoader = DataLoader
|
|
return DataLoader, PoolDataLoader
|
|
|
|
|
|
def _get_extension():
|
|
if TORCH_VERSION == 'parrots':
|
|
from parrots.utils.build_extension import BuildExtension, Extension
|
|
CppExtension = partial(Extension, cuda=False)
|
|
CUDAExtension = partial(Extension, cuda=True)
|
|
else:
|
|
from torch.utils.cpp_extension import (BuildExtension, CppExtension,
|
|
CUDAExtension)
|
|
return BuildExtension, CppExtension, CUDAExtension
|
|
|
|
|
|
def _get_pool():
|
|
if TORCH_VERSION == 'parrots':
|
|
from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
|
|
_AdaptiveMaxPoolNd, _AvgPoolNd,
|
|
_MaxPoolNd)
|
|
else:
|
|
from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd,
|
|
_AdaptiveMaxPoolNd, _AvgPoolNd,
|
|
_MaxPoolNd)
|
|
return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd
|
|
|
|
|
|
def _get_norm():
|
|
if TORCH_VERSION == 'parrots':
|
|
from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
|
|
SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
|
|
else:
|
|
from torch.nn.modules.instancenorm import _InstanceNorm
|
|
from torch.nn.modules.batchnorm import _BatchNorm
|
|
SyncBatchNorm_ = torch.nn.SyncBatchNorm
|
|
return _BatchNorm, _InstanceNorm, SyncBatchNorm_
|
|
|
|
|
|
_ConvNd, _ConvTransposeMixin = _get_conv()
|
|
DataLoader, PoolDataLoader = _get_dataloader()
|
|
BuildExtension, CppExtension, CUDAExtension = _get_extension()
|
|
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
|
|
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
|
|
|
|
|
|
class SyncBatchNorm(SyncBatchNorm_):
|
|
|
|
def _check_input_dim(self, input):
|
|
if TORCH_VERSION == 'parrots':
|
|
if input.dim() < 2:
|
|
raise ValueError(
|
|
f'expected at least 2D input (got {input.dim()}D input)')
|
|
else:
|
|
super()._check_input_dim(input)
|
|
|