File size: 5,946 Bytes
212111c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import typing
import argparse
import os
import multiprocessing as mp
import sys
import signal

import torch
import torch.distributed as dist
from torch.multiprocessing import _prctl_pr_set_pdeathsig  # type: ignore

from ..errors import EarlyStopping


def reduce_scalar(scalar: float) -> float:
    """Average a scalar over every worker of the distributed process group.

    When ``torch.distributed`` is unavailable or no process group has been
    initialized, the value is returned untouched. Otherwise the value is
    combined across ranks with ``all_reduce`` and divided by the world size.

    Args:
        scalar: The local value contributed by this process.

    Returns:
        The average of ``scalar`` across all ranks, or ``scalar`` itself when
        not running in a distributed context.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return scalar

    # Single-element float32 buffer on the current CUDA device. ``torch.tensor``
    # replaces the legacy ``torch.cuda.FloatTensor`` constructor, which PyTorch
    # no longer recommends.
    buffer = torch.tensor([scalar], dtype=torch.float32, device="cuda")
    dist.all_reduce(buffer)  # relies on the default reduce op (sum)
    buffer /= dist.get_world_size()
    return buffer.item()


def barrier_if_distributed() -> None:
    """Wait at a barrier when running distributed; otherwise do nothing.

    ``dist.barrier()`` is only called if ``torch.distributed`` is available
    and a process group has been initialized.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    dist.barrier()


def _wrap(fn, kwargs, error_queue):
    # prctl(2) is a Linux specific system call.
    # On other systems the following function call has no effect.
    # This is set to ensure that non-daemonic child processes can
    # terminate if their parent terminates before they do.
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(**kwargs)
    except KeyboardInterrupt:
        pass  # SIGINT; Killed by parent, do nothing
    except EarlyStopping:
        sys.exit(signal.SIGUSR1)  # tape early stop exception
    except Exception:
        # Propagate exception to parent process, keeping original traceback
        import traceback
        error_queue.put(traceback.format_exc())
        sys.exit(1)


class ProcessContext:
    """Tracks a group of child processes together with their error queues.

    ``processes[i]`` and ``error_queues[i]`` belong to the same child.
    ``sentinels`` maps the sentinel of every process that has not been joined
    yet to its index in ``processes``; it is empty once everything is joined.
    """

    def __init__(self, processes, error_queues):
        self.error_queues = error_queues
        self.processes = processes
        # sentinel -> index, for every process still waiting to be joined.
        self.sentinels = {
            process.sentinel: index
            for index, process in enumerate(processes)
        }

    def pids(self):
        """Return the pid of every process in this context, as ints."""
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this process context.
        If one of them exited with a non-zero exit status, this function
        terminates the remaining processes and raises an exception with the
        cause of the first process exiting.

        A process whose exit status equals ``signal.SIGUSR1`` (and that left
        nothing on its error queue) is treated as a deliberate early stop:
        the remaining processes are terminated and ``True`` is returned
        instead of raising.

        Once a failure or early stop has been handled every process has been
        joined, so further calls return ``True`` immediately.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Arguments:
            timeout (float): Wait this long before giving up on waiting.

        Raises:
            Exception: if a process exited with a non-zero status other than
                the early-stop status. The message carries the child's
                traceback when one was reported, otherwise the signal name
                or exit code.
        """
        # Imported here so the submodule is guaranteed to be loaded; a plain
        # ``import multiprocessing`` is not guaranteed to import
        # ``multiprocessing.connection``.
        from multiprocessing.connection import wait

        # Ensure this function can be called even when we're done.
        if not self.sentinels:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = wait(self.sentinels.keys(), timeout=timeout)

        failed_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                failed_index = index
                break

        # No failure: report whether every process has been joined.
        if failed_index is None:
            return not self.sentinels

        # Assume failure. Terminate processes that are still alive.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            process.join()
        # Everything is joined now. Forget the remaining sentinels so a later
        # join() does not mistake the processes terminated above for failures.
        self.sentinels.clear()

        failed_queue = self.error_queues[failed_index]

        # There won't be an error on the queue if the process crashed.
        if failed_queue.empty():
            exitcode = self.processes[failed_index].exitcode
            if exitcode == signal.SIGUSR1:
                # Early stop requested by the child: not an error.
                return True
            if exitcode < 0:
                # A negative exit code means the child was killed by a signal.
                try:
                    name = signal.Signals(-exitcode).name
                except ValueError:
                    # Signal number without an entry in the enum.
                    name = "<Unknown signal %d>" % (-exitcode,)
                raise Exception(
                    "process %d terminated with signal %s" %
                    (failed_index, name)
                )
            raise Exception(
                "process %d terminated with exit code %d" %
                (failed_index, exitcode)
            )

        original_trace = failed_queue.get()
        msg = "\n\n-- Process %d terminated with the following error:\n" % failed_index
        msg += original_trace
        raise Exception(msg)


def launch_process_group(func: typing.Callable,
                         args: argparse.Namespace,
                         num_processes: int,
                         num_nodes: int = 1,
                         node_rank: int = 0,
                         master_addr: str = "127.0.0.1",
                         master_port: int = 29500,
                         join: bool = True,
                         daemon: bool = False):
    """Start ``num_processes`` local worker processes that each run ``func``.

    Every worker is started through ``_wrap`` and ends up calling
    ``func(args=args, env=env)``. ``env`` is a copy of ``os.environ`` extended
    with ``MASTER_ADDR``, ``MASTER_PORT``, ``WORLD_SIZE``, ``RANK`` and
    ``LOCAL_RANK``; it is passed to ``func`` as an argument, this function
    does not modify ``os.environ`` itself. ``OMP_NUM_THREADS`` is set to 4 in
    ``env`` when more than one process is launched and the variable is not
    already present in ``os.environ``.

    Note that ``args.local_rank`` is overwritten on the namespace passed in
    before each worker is started.

    Args:
        func: Callable run in every worker with keyword arguments ``args``
            and ``env``.
        args: Namespace handed to ``func``; receives a ``local_rank``
            attribute.
        num_processes: Number of workers to start on this node.
        num_nodes: Number of nodes; the world size is
            ``num_processes * num_nodes``.
        node_rank: Rank of this node; a worker's global rank is
            ``num_processes * node_rank + local_rank``.
        master_addr: Value stored under ``MASTER_ADDR``.
        master_port: Value stored (as a string) under ``MASTER_PORT``.
        join: When true, block until ``ProcessContext.join`` reports that all
            workers are done; when false, return right after starting them.
        daemon: Forwarded to ``mp.Process`` as its ``daemon`` flag.

    Returns:
        The ``ProcessContext`` for the started workers when ``join`` is
        false; ``None`` otherwise.
    """
    # Environment handed to the workers, seeded from the current one.
    worker_env = os.environ.copy()
    worker_env.update({
        "MASTER_ADDR": master_addr,
        "MASTER_PORT": str(master_port),
        # world size in terms of number of processes
        "WORLD_SIZE": str(num_processes * num_nodes),
    })
    if num_processes > 1 and 'OMP_NUM_THREADS' not in os.environ:
        worker_env["OMP_NUM_THREADS"] = str(4)

    workers = []
    queues = []
    for local_rank in range(num_processes):
        # Per-worker rank information; the same dict is updated in place
        # right before each worker is started.
        worker_env["RANK"] = str(num_processes * node_rank + local_rank)
        worker_env["LOCAL_RANK"] = str(local_rank)
        args.local_rank = local_rank

        # _wrap puts the formatted traceback here if the worker fails.
        queue = mp.SimpleQueue()
        worker = mp.Process(
            target=_wrap,
            args=(func, {'args': args, 'env': worker_env}, queue),
            daemon=daemon)
        worker.start()
        queues.append(queue)
        workers.append(worker)

    process_context = ProcessContext(workers, queues)
    if join:
        # Keep joining until every worker is accounted for.
        while not process_context.join():
            pass
        return None
    return process_context