# -*- coding: utf-8 -*-
import importlib

import torch
import torch.distributed as dist


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as "package.module.ClassName" to the named object."""
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def get_obj_from_config(config):
    """Return the class/function named by config["target"] without instantiating it."""
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])


def instantiate_from_config(config, **kwargs):
    """Instantiate config["target"] with config["params"] merged over any extra kwargs."""
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    cls = get_obj_from_str(config["target"])
    params = config.get("params", dict())
    # Merge the config params into the caller-supplied kwargs; on a name
    # collision the value from the config takes precedence.
    kwargs.update(params)
    instance = cls(**kwargs)
    return instance


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def all_gather_batch(tensors):
    """
    Performs an all_gather operation on each of the provided tensors and
    concatenates the gathered shards along the batch dimension.
    """
    world_size = get_world_size()
    # Nothing to gather in the single-process case.
    if world_size == 1:
        return tensors
    tensor_list = []
    output_tensor = []
    for tensor in tensors:
        # Allocate one receive buffer per rank, then gather synchronously.
        tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
        dist.all_gather(tensor_all, tensor, async_op=False)
        tensor_list.append(tensor_all)
    for tensor_all in tensor_list:
        output_tensor.append(torch.cat(tensor_all, dim=0))
    return output_tensor
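

# Minimal usage sketch of `instantiate_from_config` and the distributed
# fallbacks above. `torch.nn.Linear` is used purely as an illustrative
# target; any importable class whose constructor accepts the given params
# would work the same way. Without an initialized process group,
# get_rank()/get_world_size() fall back to 0/1 and all_gather_batch
# returns its inputs unchanged.
if __name__ == "__main__":
    example_config = {
        "target": "torch.nn.Linear",
        "params": {"in_features": 8, "out_features": 4},
    }
    # `bias=True` is passed as an extra kwarg; config params take precedence
    # on any name collision.
    layer = instantiate_from_config(example_config, bias=True)
    print(type(layer).__name__, layer.in_features, layer.out_features)

    print("rank:", get_rank(), "world size:", get_world_size())
    x = torch.randn(2, 8)
    # With world_size == 1 this is a no-op and returns [x] unchanged.
    gathered = all_gather_batch([x])
    print(gathered[0].shape)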