import typing
import argparse
import os
import multiprocessing as mp
import multiprocessing.connection  # mp.connection is not guaranteed to be loaded by `import multiprocessing` alone
import sys
import signal
import torch
import torch.distributed as dist
from torch.multiprocessing import _prctl_pr_set_pdeathsig  # type: ignore
from ..errors import EarlyStopping


def reduce_scalar(scalar: float) -> float:
    """Averages a scalar across all workers if in a distributed context."""
    if dist.is_available() and dist.is_initialized():
        float_tensor = torch.cuda.FloatTensor([scalar])  # type: ignore
        dist.all_reduce(float_tensor)
        float_tensor /= dist.get_world_size()
        scalar = float_tensor.item()
    return scalar


def barrier_if_distributed() -> None:
    """Enters a barrier if in a distributed context, otherwise does nothing."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier()


def _wrap(fn, kwargs, error_queue):
    # prctl(2) is a Linux-specific system call.
    # On other systems the following function call has no effect.
    # This is set to ensure that non-daemonic child processes can
    # terminate if their parent terminates before they do.
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(**kwargs)
    except KeyboardInterrupt:
        pass  # SIGINT; killed by parent, do nothing
    except EarlyStopping:
        sys.exit(signal.SIGUSR1)  # tape early-stop exception
    except Exception:
        # Propagate exception to parent process, keeping original traceback
        import traceback
        error_queue.put(traceback.format_exc())
        sys.exit(1)


class ProcessContext:

    def __init__(self, processes, error_queues):
        self.error_queues = error_queues
        self.processes = processes
        self.sentinels = {
            process.sentinel: index
            for index, process in enumerate(processes)
        }

    def pids(self):
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""Tries to join one or more processes in this process context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the
        cause of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Arguments:
            timeout (float): Wait this long before giving up on waiting.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = mp.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure. Terminate processes that are still alive.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            process.join()

        # There won't be an error on the queue if the process crashed.
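        # Exit-code protocol established by _wrap above: a positive exit
        # code equal to SIGUSR1 marks an EarlyStopping shutdown and is
        # treated as success; a negative exit code means the process was
        # killed by a signal; any other non-zero code is a plain failure.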
        if self.error_queues[error_index].empty():
            exitcode = self.processes[error_index].exitcode
            if exitcode == signal.SIGUSR1:
                return True
            elif exitcode < 0:
                name = signal.Signals(-exitcode).name
                raise Exception(
                    "process %d terminated with signal %s" % (error_index, name))
            else:
                raise Exception(
                    "process %d terminated with exit code %d" % (error_index, exitcode))

        original_trace = self.error_queues[error_index].get()
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise Exception(msg)


def launch_process_group(func: typing.Callable,
                         args: argparse.Namespace,
                         num_processes: int,
                         num_nodes: int = 1,
                         node_rank: int = 0,
                         master_addr: str = "127.0.0.1",
                         master_port: int = 29500,
                         join: bool = True,
                         daemon: bool = False):
    # world size in terms of number of processes
    dist_world_size = num_processes * num_nodes

    # set PyTorch distributed related environment variables
    current_env = os.environ.copy()
    current_env["MASTER_ADDR"] = master_addr
    current_env["MASTER_PORT"] = str(master_port)
    current_env["WORLD_SIZE"] = str(dist_world_size)
    if 'OMP_NUM_THREADS' not in os.environ and num_processes > 1:
        current_env["OMP_NUM_THREADS"] = str(4)

    error_queues = []
    processes = []

    for local_rank in range(num_processes):
        # each process's global rank
        dist_rank = num_processes * node_rank + local_rank
        current_env["RANK"] = str(dist_rank)
        current_env["LOCAL_RANK"] = str(local_rank)
        args.local_rank = local_rank

        # Mutating `current_env` and `args` in place is safe because
        # `process.start()` snapshots them (via fork or pickling) before
        # the next iteration mutates them again.
        error_queue: mp.SimpleQueue = mp.SimpleQueue()  # carries traceback strings from _wrap
        kwargs = {'args': args, 'env': current_env}
        process = mp.Process(
            target=_wrap,
            args=(func, kwargs, error_queue),
            daemon=daemon)
        process.start()
        error_queues.append(error_queue)
        processes.append(process)

    process_context = ProcessContext(processes, error_queues)
    if not join:
        return process_context

    while not process_context.join():
        pass
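# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API). Because
# _wrap invokes the target as ``fn(**kwargs)`` with
# ``kwargs = {'args': ..., 'env': ...}``, any callable passed to
# ``launch_process_group`` must accept ``args`` and ``env`` keyword
# arguments. The worker below, including its name and body, is a
# hypothetical example; it would live in a caller's training script, which
# is why it is shown here in comment form:
#
#     def train_worker(args: argparse.Namespace,
#                      env: typing.Dict[str, str]) -> None:
#         os.environ.update(env)  # expose RANK / LOCAL_RANK / MASTER_* to torch
#         dist.init_process_group(backend='nccl')  # reads the env:// variables
#         ...  # build model, wrap in DistributedDataParallel, train
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--local_rank', type=int, default=0)
#     launch_process_group(
#         train_worker, parser.parse_args(),
#         num_processes=torch.cuda.device_count())
# ---------------------------------------------------------------------------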