# tape/utils/distributed_utils.py
import typing
import argparse
import os
import multiprocessing as mp
import multiprocessing.connection  # mp.connection.wait below needs the submodule loaded
import sys
import signal

import torch
import torch.distributed as dist
from torch.multiprocessing import _prctl_pr_set_pdeathsig  # type: ignore

from ..errors import EarlyStopping


def reduce_scalar(scalar: float) -> float:
    """Averages a python scalar across all workers; returns it unchanged
    when not running in an initialized distributed context."""
    if dist.is_available() and dist.is_initialized():
        float_tensor = torch.cuda.FloatTensor([scalar])  # type: ignore
        dist.all_reduce(float_tensor)
        float_tensor /= dist.get_world_size()
        scalar = float_tensor.item()
    return scalar
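
# Usage sketch for reduce_scalar (`validate_step` is a hypothetical helper,
# shown for illustration only): every rank computes a local value, then all
# ranks receive the same mean. Requires an initialized CUDA process group,
# since reduce_scalar allocates a torch.cuda.FloatTensor.
#
#     local_loss = validate_step(batch)      # differs per rank
#     mean_loss = reduce_scalar(local_loss)  # identical on every rank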


def barrier_if_distributed() -> None:
    """Raises a barrier if in a distributed context, otherwise does nothing."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
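
# Usage sketch for barrier_if_distributed: have rank 0 write a checkpoint
# before any rank reads it back (`save_checkpoint` / `load_checkpoint` are
# hypothetical helpers, not part of this module).
#
#     if dist.get_rank() == 0:
#         save_checkpoint(model)
#     barrier_if_distributed()  # all ranks block until rank 0 is done
#     state = load_checkpoint()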


def _wrap(fn, kwargs, error_queue):
    # prctl(2) is a Linux-specific system call; on other systems the
    # following call has no effect. It ensures that non-daemonic child
    # processes terminate when their parent terminates before they do.
    _prctl_pr_set_pdeathsig(signal.SIGINT)
    try:
        fn(**kwargs)
    except KeyboardInterrupt:
        pass  # SIGINT; killed by parent, do nothing
    except EarlyStopping:
        sys.exit(signal.SIGUSR1)  # tape early-stop exception
    except Exception:
        # Propagate exception to parent process, keeping original traceback
        import traceback
        error_queue.put(traceback.format_exc())
        sys.exit(1)


class ProcessContext:
    def __init__(self, processes, error_queues):
        self.error_queues = error_queues
        self.processes = processes
        self.sentinels = {
            process.sentinel: index
            for index, process in enumerate(processes)
        }

    def pids(self):
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this process context.

        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process that exited.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Arguments:
            timeout (float): Wait this long before giving up on waiting.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = mp.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure. Terminate processes that are still alive.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            process.join()

        # There won't be an error on the queue if the process crashed.
        if self.error_queues[error_index].empty():
            exitcode = self.processes[error_index].exitcode
            if exitcode == signal.SIGUSR1:
                # A worker that raised EarlyStopping exits with SIGUSR1's
                # value (see _wrap); treat this as a clean shutdown.
                return True
            elif exitcode < 0:
                name = signal.Signals(-exitcode).name
                raise Exception(
                    "process %d terminated with signal %s" %
                    (error_index, name)
                )
            else:
                raise Exception(
                    "process %d terminated with exit code %d" %
                    (error_index, exitcode)
                )

        original_trace = self.error_queues[error_index].get()
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise Exception(msg)
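
# Usage sketch for non-blocking supervision: launch with join=False (see
# launch_process_group below) and poll the context periodically
# (`log_heartbeat` is a hypothetical helper).
#
#     ctx = launch_process_group(worker_fn, args, num_processes=4, join=False)
#     while not ctx.join(timeout=5.0):
#         log_heartbeat(ctx.pids())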


def launch_process_group(func: typing.Callable,
                         args: argparse.Namespace,
                         num_processes: int,
                         num_nodes: int = 1,
                         node_rank: int = 0,
                         master_addr: str = "127.0.0.1",
                         master_port: int = 29500,
                         join: bool = True,
                         daemon: bool = False):
    # world size in terms of number of processes
    dist_world_size = num_processes * num_nodes

    # set PyTorch distributed related environment variables
    current_env = os.environ.copy()
    current_env["MASTER_ADDR"] = master_addr
    current_env["MASTER_PORT"] = str(master_port)
    current_env["WORLD_SIZE"] = str(dist_world_size)
    if 'OMP_NUM_THREADS' not in os.environ and num_processes > 1:
        current_env["OMP_NUM_THREADS"] = str(4)

    error_queues = []
    processes = []
    for local_rank in range(num_processes):
        # each process's global rank
        dist_rank = num_processes * node_rank + local_rank
        # NOTE: current_env and args are shared objects mutated on every
        # iteration; this is safe only because each process is started
        # (and therefore snapshots them) before the next iteration.
        current_env["RANK"] = str(dist_rank)
        current_env["LOCAL_RANK"] = str(local_rank)
        args.local_rank = local_rank

        error_queue = mp.SimpleQueue()  # receives formatted tracebacks (str) from _wrap
        kwargs = {'args': args, 'env': current_env}
        process = mp.Process(
            target=_wrap,
            args=(func, kwargs, error_queue),
            daemon=daemon)
        process.start()
        error_queues.append(error_queue)
        processes.append(process)

    process_context = ProcessContext(processes, error_queues)
    if not join:
        return process_context

    while not process_context.join():
        pass
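

# The sketch below is illustrative only and not part of the original tape API:
# it launches two workers on one node and prints their ranks. `_example_worker`
# is a hypothetical stand-in for a real training entry point.
def _example_worker(args: argparse.Namespace, env: typing.Dict[str, str]) -> None:
    """Hypothetical demo worker: _wrap invokes it as fn(args=..., env=...),
    matching the kwargs dict built in launch_process_group."""
    os.environ.update(env)  # expose RANK / LOCAL_RANK / MASTER_* to this process
    # A real worker would now call dist.init_process_group(...) and train.
    print("worker RANK=%s local_rank=%d" % (env["RANK"], args.local_rank))


if __name__ == "__main__":
    # Minimal launch sketch, assuming a single node with two processes.
    # Run as a module (python -m tape.utils.distributed_utils) so that the
    # relative import of EarlyStopping above resolves.
    launch_process_group(_example_worker, argparse.Namespace(local_rank=0),
                         num_processes=2)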