|
import typing |
|
import argparse |
|
import os |
|
import multiprocessing as mp |
|
import sys |
|
import signal |
|
|
|
import torch |
|
import torch.distributed as dist |
|
from torch.multiprocessing import _prctl_pr_set_pdeathsig |
|
|
|
from ..errors import EarlyStopping |
|
|
|
|
|
def reduce_scalar(scalar: float) -> float:
    """Average ``scalar`` across all workers in the process group.

    Args:
        scalar: The local value contributed by this rank.

    Returns:
        The mean of ``scalar`` over all ranks when a distributed process
        group is initialized; otherwise ``scalar`` unchanged, so callers
        need not special-case single-process runs.
    """
    if dist.is_available() and dist.is_initialized():
        # torch.cuda.FloatTensor is a deprecated legacy constructor; use the
        # modern factory to build the float32 tensor on the GPU explicitly.
        # NOTE(review): this assumes a CUDA backend (e.g. NCCL) — same
        # assumption the original constructor made.
        float_tensor = torch.tensor([scalar], dtype=torch.float32, device="cuda")
        dist.all_reduce(float_tensor)  # default op is SUM across ranks
        float_tensor /= dist.get_world_size()
        scalar = float_tensor.item()
    return scalar
|
|
|
|
|
def barrier_if_distributed() -> None:
    """Synchronize all workers when running distributed; no-op otherwise."""
    if not (dist.is_available() and dist.is_initialized()):
        return
    dist.barrier()
|
|
|
|
|
def _wrap(fn, kwargs, error_queue):
    """Worker entry point: run ``fn(**kwargs)`` and map its outcome to an
    exit code the parent's ``ProcessContext.join`` understands.

    Exit codes: 0 on success or KeyboardInterrupt, SIGUSR1's number on
    ``EarlyStopping`` (deliberate stop), 1 on any other exception — in which
    case the formatted traceback is pushed onto ``error_queue`` first.
    """
    # Ask the kernel to deliver SIGINT to this worker if the parent dies,
    # so workers are never left orphaned.
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(**kwargs)
        return
    except KeyboardInterrupt:
        # Parent is tearing the group down; exit quietly.
        return
    except EarlyStopping:
        # Distinguished exit code: the parent treats this as success.
        exit_code = signal.SIGUSR1
    except Exception:
        # Ship the traceback to the parent before dying.
        import traceback
        error_queue.put(traceback.format_exc())
        exit_code = 1
    sys.exit(exit_code)
|
|
|
|
|
class ProcessContext:
    """Tracks a group of worker processes and their per-process error queues.

    Mirrors ``torch.multiprocessing.spawn``'s context object: ``join`` waits
    on process sentinels, and when one worker fails it tears down the rest
    and re-raises the first worker's error in the parent.
    """

    def __init__(self, processes, error_queues):
        self.error_queues = error_queues
        self.processes = processes
        # Map each process's waitable sentinel handle back to its index so a
        # ready sentinel can be resolved to the process that exited.
        self.sentinels = {
            process.sentinel: index
            for index, process in enumerate(processes)
        }

    def pids(self):
        """Return the OS pids of all managed processes."""
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this process context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Arguments:
            timeout (float): Wait this long before giving up on waiting.
        """
        if len(self.sentinels) == 0:
            return True

        # ``import multiprocessing`` does not guarantee that the
        # ``connection`` submodule is loaded, so ``mp.connection.wait`` can
        # raise AttributeError; import the submodule explicitly (this is
        # what torch.multiprocessing.spawn does as well).
        import multiprocessing.connection

        # Block until at least one worker's sentinel becomes ready.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )
        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        if error_index is None:
            # All exits observed so far were clean; we are fully joined once
            # every sentinel has been consumed.
            return len(self.sentinels) == 0

        # One worker failed: bring down the survivors before reporting.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            process.join()

        if self.error_queues[error_index].empty():
            # No traceback was recorded, so the worker died without reaching
            # the generic exception handler in ``_wrap``.
            exitcode = self.processes[error_index].exitcode
            if exitcode == signal.SIGUSR1:
                # ``_wrap`` exits with SIGUSR1's number on EarlyStopping;
                # deliberate early stopping counts as a successful join.
                return True
            elif exitcode < 0:
                # Negative exitcode means the process was killed by a signal.
                name = signal.Signals(-exitcode).name
                raise Exception(
                    "process %d terminated with signal %s" %
                    (error_index, name)
                )
            else:
                raise Exception(
                    "process %d terminated with exit code %d" %
                    (error_index, exitcode)
                )

        # Re-raise the worker's formatted traceback in the parent.
        original_trace = self.error_queues[error_index].get()
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise Exception(msg)
|
|
|
|
|
def launch_process_group(func: typing.Callable,
                         args: argparse.Namespace,
                         num_processes: int,
                         num_nodes: int = 1,
                         node_rank: int = 0,
                         master_addr: str = "127.0.0.1",
                         master_port: int = 29500,
                         join: bool = True,
                         daemon: bool = False):
    """Launch ``func`` in ``num_processes`` worker processes for distributed
    training, wiring up the standard ``torch.distributed`` environment.

    Each worker is called as ``func(args=args, env=env)`` where ``env`` is a
    copy of ``os.environ`` extended with MASTER_ADDR, MASTER_PORT,
    WORLD_SIZE, RANK and LOCAL_RANK; ``args.local_rank`` is also set.

    Args:
        func: Worker entry point; must accept ``args`` and ``env`` kwargs.
        args: Namespace forwarded to every worker (``local_rank`` is set).
        num_processes: Workers to spawn on this node.
        num_nodes: Total participating nodes (world size multiplier).
        node_rank: This node's index; determines the global rank offset.
        master_addr: Rendezvous address exported as MASTER_ADDR.
        master_port: Rendezvous port exported as MASTER_PORT.
        join: If True, block until all workers finish (errors re-raised
            in the parent); if False, return the ProcessContext immediately.
        daemon: Whether worker processes are daemonic.

    Returns:
        The ``ProcessContext`` when ``join`` is False, otherwise ``None``
        after all workers have been joined.
    """
    dist_world_size = num_processes * num_nodes

    # Base environment shared by every worker; per-rank keys are set below.
    current_env = os.environ.copy()
    current_env["MASTER_ADDR"] = master_addr
    current_env["MASTER_PORT"] = str(master_port)
    current_env["WORLD_SIZE"] = str(dist_world_size)
    # Match torch.distributed.launch: cap OMP threads when running several
    # workers per node to avoid CPU oversubscription.
    if 'OMP_NUM_THREADS' not in os.environ and num_processes > 1:
        current_env["OMP_NUM_THREADS"] = str(4)

    error_queues = []
    processes = []

    for local_rank in range(num_processes):
        dist_rank = num_processes * node_rank + local_rank
        current_env["RANK"] = str(dist_rank)
        current_env["LOCAL_RANK"] = str(local_rank)
        args.local_rank = local_rank

        # _wrap puts a formatted traceback (a str, not an Exception) on this
        # queue when the worker fails.
        error_queue: "mp.SimpleQueue[str]" = mp.SimpleQueue()
        # Snapshot the env so each worker keeps its own RANK/LOCAL_RANK even
        # though current_env is mutated again on the next iteration.
        kwargs = {'args': args, 'env': current_env.copy()}
        process = mp.Process(
            target=_wrap,
            args=(func, kwargs, error_queue),
            daemon=daemon)
        process.start()
        error_queues.append(error_queue)
        processes.append(process)

    process_context = ProcessContext(processes, error_queues)
    if not join:
        return process_context

    # Keep joining until every worker has exited (join() raises on failure).
    while not process_context.join():
        pass
|
|