|
|
|
import importlib
|
|
import os
|
|
import pkgutil
|
|
import warnings
|
|
from collections import namedtuple
|
|
|
|
import torch
|
|
|
|
if torch.__version__ != 'parrots':

    def load_ext(name, funcs):
        """Import the compiled extension ``mmcv.<name>`` and verify its ops.

        Args:
            name (str): Name of the extension submodule inside the ``mmcv``
                package, e.g. ``'_ext'``.
            funcs (list[str]): Names of functions the extension must expose.

        Returns:
            module: The imported extension module.

        Raises:
            ImportError: If ``mmcv.<name>`` cannot be imported.
            AssertionError: If a name in ``funcs`` is not an attribute of the
                imported module.
        """
        ext = importlib.import_module('mmcv.' + name)
        for fun in funcs:
            # An explicit raise (rather than ``assert``) keeps this check
            # active under ``python -O``; AssertionError is preserved so
            # callers see the same exception type and message as before.
            if not hasattr(ext, fun):
                raise AssertionError(f'{fun} miss in module {name}')
        return ext

else:
    from parrots import extension
    from parrots.base import ParrotsException

    # Ops for which the parrots ``load_ext`` binds ``.op`` (presumably the
    # value-returning variant, judging by the name); every other op is bound
    # via ``.op_``.
    has_return_value_ops = [
        'nms',
        'softnms',
        'nms_match',
        'nms_rotated',
        'top_pool_forward',
        'top_pool_backward',
        'bottom_pool_forward',
        'bottom_pool_backward',
        'left_pool_forward',
        'left_pool_backward',
        'right_pool_forward',
        'right_pool_backward',
        'fused_bias_leakyrelu',
        'upfirdn2d',
        'ms_deform_attn_forward',
        'pixel_group',
        'contour_expand',
    ]

    def get_fake_func(name, e):
        """Return a placeholder for an op that failed to load.

        Args:
            name (str): Name of the op, used in the warning message.
            e (Exception): The exception captured when loading failed.

        Returns:
            callable: A function accepting any arguments that emits a
            warning naming ``name`` and then re-raises ``e``.
        """

        def fake_func(*args, **kwargs):
            warnings.warn(f'{name} is not supported in parrots now')
            raise e

        return fake_func

    def load_ext(name, funcs):
        """Load each op in ``funcs`` from the parrots extension ``name``.

        An op whose loading raises ``ParrotsException`` is replaced by a
        placeholder from ``get_fake_func``, so the failure is only raised
        when that op is actually called.

        Args:
            name (str): Name of the parrots extension to load ops from.
            funcs (list[str]): Names of the ops to load; each becomes a
                field of the returned namedtuple.

        Returns:
            ExtModule: A namedtuple with one field per entry of ``funcs``,
            holding the loaded op (or its placeholder).
        """
        ExtModule = namedtuple('ExtModule', funcs)
        ext_list = []
        # The parent of this file's directory is passed as the library dir.
        lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        for fun in funcs:
            try:
                ext_fun = extension.load(fun, name, lib_dir=lib_root)
            except ParrotsException as e:
                # 'No element registered' is not warned about here (presumably
                # the op simply was not built); other load errors are surfaced
                # as warnings right away.
                if 'No element registered' not in e.message:
                    warnings.warn(e.message)
                ext_fun = get_fake_func(fun, e)
                ext_list.append(ext_fun)
            else:
                if fun in has_return_value_ops:
                    ext_list.append(ext_fun.op)
                else:
                    ext_list.append(ext_fun.op_)
        return ExtModule(*ext_list)
|
|
|
|
|
|
def check_ops_exist(module_name='mmcv._ext'):
    """Return whether the compiled extension module can be located.

    Only the module's import spec is looked up; the module itself is not
    executed, although its parent packages are imported as part of the
    lookup.

    Args:
        module_name (str): Fully qualified name of the module to look for.
            Defaults to ``'mmcv._ext'``.

    Returns:
        bool: True if a spec with a loader exists for ``module_name``,
        otherwise False.

    Raises:
        ImportError: If a parent package of ``module_name`` cannot be
            imported.
    """
    # ``pkgutil.find_loader`` is deprecated since Python 3.12 and removed in
    # 3.14; ``find_spec`` is the documented replacement.
    from importlib.util import find_spec

    spec = find_spec(module_name)
    return spec is not None and spec.loader is not None
|
|
|