is_virtualstaging / fastai /distributed.py
AItool's picture
Upload 127 files
a983ebc
raw
history blame
No virus
9.78 kB
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/20a_distributed.ipynb.
# %% ../nbs/20a_distributed.ipynb 2
from __future__ import annotations
from .basics import *
from .callback.progress import ProgressCallback
from torch.nn.parallel import DistributedDataParallel, DataParallel
from .data.load import _FakeLoader,_loaders
from .optimizer import OptimWrapper
try: from accelerate import Accelerator
except ModuleNotFoundError: pass
# %% auto 0
__all__ = ['ParallelTrainer', 'setup_distrib', 'teardown_distrib', 'DistributedDL', 'DistributedTrainer', 'rank0_first']
# %% ../nbs/20a_distributed.ipynb 6
@patch
def reset(self: DataParallel):
"Patch required `reset` call into `DataParallel`"
if hasattr(self.module, 'reset'): self.module.reset()
# %% ../nbs/20a_distributed.ipynb 7
class ParallelTrainer(Callback):
"Wrap a model `DataParallel` automatically"
run_after,run_before = TrainEvalCallback,Recorder
def __init__(self, device_ids): self.device_ids = device_ids
def before_fit(self): self.learn.model = DataParallel(self.learn.model, device_ids=self.device_ids)
def after_fit(self): self.learn.model = self.learn.model.module
# %% ../nbs/20a_distributed.ipynb 8
@patch
def to_parallel(self: Learner, device_ids=None):
"Add `ParallelTrainer` callback to a `Learner`"
self.add_cb(ParallelTrainer(device_ids))
return self
# %% ../nbs/20a_distributed.ipynb 9
@patch
def detach_parallel(self: Learner):
"Remove `ParallelTrainer` callback from a Learner"
self.remove_cb(ParallelTrainer)
return self
# %% ../nbs/20a_distributed.ipynb 10
@patch
@contextmanager
def parallel_ctx(self: Learner, device_ids=None):
"A context manager to adapt a learner to train in data parallel mode."
try:
self.to_parallel(device_ids)
yield self
finally: self.detach_parallel()
# %% ../nbs/20a_distributed.ipynb 13
@patch
def reset(self: DistributedDataParallel):
"Patch required `reset` call into `DistributedDataParallel`"
if hasattr(self.module, 'reset'): self.module.reset()
# %% ../nbs/20a_distributed.ipynb 14
def setup_distrib(gpu=None):
"Setup this process to participate in distributed training"
if gpu is None: return gpu
gpu = int(gpu)
torch.cuda.set_device(int(gpu))
if num_distrib() > 0: torch.distributed.init_process_group(backend='nccl', init_method='env://')
return gpu
# %% ../nbs/20a_distributed.ipynb 15
def teardown_distrib():
"Free distributed training resources"
if torch.distributed.is_initialized(): torch.distributed.destroy_process_group()
# %% ../nbs/20a_distributed.ipynb 17
def _round_to_multiple(number,multiple): return int(math.ceil(number/multiple)*multiple)
# %% ../nbs/20a_distributed.ipynb 18
class DistributedDL(TfmdDL):
"A `TfmdDL` which splits a batch into equal size pieces for each worker"
def __init__(self,dl,rank=None,world_size=None):
if rank is None: rank=rank_distrib()
if world_size is None: world_size=num_distrib()
store_attr()
if type(dl) == torch.utils.data.DataLoader:
shuffle = True if eq(type(dl.sampler), torch.utils.data.RandomSampler) else False
self.dl = DataLoader(dataset=dl.dataset, bs=dl.batch_size, num_workers=dl.num_workers, \
pin_memory=dl.pin_memory, timeout=dl.timeout, shuffle=shuffle, drop_last=dl.drop_last, persistent_workers=dl.persistent_workers)
self.bs,self.device,self.drop_last,self.dataset,fake,self.num_workers,self.offs,self.pin_memory = \
attrgetter('bs','device','drop_last','dataset','fake_l','num_workers','offs','pin_memory')(self.dl)
self.fake_l = _FakeLoader(self, fake.pin_memory, fake.num_workers, fake.timeout,
persistent_workers=fake.persistent_workers,
pin_memory_device=fake.pin_memory_device)
def _broadcast(self,t,rank):
"Broadcasts t from rank `rank` to all other ranks. Returns t so t is same for all ranks after call."
t = LongTensor(t).cuda() # nccl only works with cuda tensors
torch.distributed.broadcast(t,rank)
return t.cpu().tolist()
def _to_detach(self,b,cpu=True,gather=True): return to_detach(b,cpu,gather) # member func so we can override for test
def __len__(self): return _round_to_multiple(len(self.dl),self.world_size)//self.world_size
def get_idxs(self):
idxs = list(self.dl.get_idxs()) # compute get_idxs in all ranks (we'll only use rank 0 but size must be consistent)
idxs = self._broadcast(idxs,0) # broadcast and receive it from rank 0 to all
self.n = len(idxs) # we assumed n was dl.n but we really care about number of idxs
# add extra samples to make it evenly divisible
self.n_padded = _round_to_multiple(self.n,self.world_size)
idxs += (idxs * (self.n_padded//self.n))[:self.n_padded-self.n] # idx needs to be repeated when n_padded>>n
# slice padded idxs so that each rank gets self.n_padded//self.world_size tensors
return idxs[self.rank*self.n_padded//self.world_size:(self.rank+1)*self.n_padded//self.world_size]
def before_iter(self):
self.i = 0
self.dl.before_iter()
def randomize(self): self.dl.randomize()
def after_batch(self,b):
self.i += find_bs(b)
return self.dl.after_batch(b)
def after_iter(self): self.dl.after_iter()
def create_batches(self,samps): return self.dl.create_batches(samps)
def to_detach(self,b, cpu=True, gather=True):
b = self._to_detach(b, cpu, gather)
def _inner(b):
if b.ndim>0:
# for each rank, compute overflow of read idxs vs self.n and accumulate them to unpad totals after gathering
n = sum([min(0,max(-len(b)//self.world_size,
self.n-(self.i+r*self.n_padded//self.world_size))) for r in range(self.world_size)])
b = b[:n or None]
return b
return apply(_inner,b) if gather and all(hasattr(self,o) for o in ('i','n','n_padded')) else b
# %% ../nbs/20a_distributed.ipynb 29
_hidden_params = ["mixed_precision", "fp16", "log_with", "logging_dir", "step_scheduler_with_optimizer"]
# %% ../nbs/20a_distributed.ipynb 30
class DistributedTrainer(Callback):
"Wrap `model` in `DistributedDataParallel` and `dls` in `DistributedDL`"
order = 11
@delegates(Accelerator, but=_hidden_params)
def __init__(self,
sync_bn=True, # Whether to replace all batch norm with `nn.SyncBatchNorm`
**kwargs
):
store_attr()
self.accelerator = Accelerator(**kwargs)
def before_fit(self):
self.learn.model = self.accelerator.prepare(
nn.SyncBatchNorm.convert_sync_batchnorm(self.model) if self.sync_bn else self.model
)
self.old_dls = list(self.dls)
self.learn.dls.loaders = [self._wrap_dl(dl) for dl in self.dls]
if rank_distrib(): self.learn.logger=noop
def _wrap_dl(self, dl): return dl if isinstance(dl,DistributedDL) else DistributedDL(dl)
def _backward(self): self.accelerator.backward(self.learn.loss_grad)
def before_train(self): self.learn.dl = self._wrap_dl(self.learn.dl)
def before_validate(self): self.learn.dl = self._wrap_dl(self.learn.dl)
def after_fit(self): self.learn.model,self.learn.dls.loaders = self.learn.model.module,self.old_dls
# %% ../nbs/20a_distributed.ipynb 31
@patch
@delegates(Accelerator, but=_hidden_params)
def to_distributed(self: Learner,
sync_bn=True, # Whether to replace all batch norm with `nn.SyncBatchNorm`
**kwargs
):
"Add `AcceleratedTrainer` to a learner, and configures an Accelerator"
self.add_cb(DistributedTrainer(sync_bn, **kwargs))
if rank_distrib(): self.remove_cb(ProgressCallback)
return self
# %% ../nbs/20a_distributed.ipynb 32
@patch
def detach_distributed(self: Learner):
"Remove `DistributedTrainer` from a learner"
if num_distrib() <=1: return self
self.remove_cb(DistributedTrainer)
if rank_distrib() and not hasattr(self, 'progress'): self.add_cb(ProgressCallback())
return self
# %% ../nbs/20a_distributed.ipynb 34
@patch
@contextmanager
@delegates(Accelerator, but=_hidden_params)
def distrib_ctx(self: Learner,
sync_bn=True, # Whether to replace all batch norm with `nn.SyncBatchNorm`
in_notebook=False, # Whether we are launching from a notebook or not
**kwargs
):
"A context manager to adapt a learner to train in distributed data parallel mode."
try: import accelerate
except ImportError as e:
e.args = ["Accelerate is required. Install with `pip install accelerate`"]
raise
# Adapt self to DistributedDataParallel, yield, and cleanup afterwards.
cleanup_dpg = False
try:
if in_notebook:
cuda_id = rank_distrib()
if not torch.distributed.is_initialized():
setup_distrib(cuda_id)
cleanup_dpg = torch.distributed.is_initialized()
if not rank_distrib(): print("Training Learner...")
if num_distrib(): self.to_distributed(sync_bn, **kwargs)
yield self
finally:
self.detach_distributed()
if cleanup_dpg: teardown_distrib()
# %% ../nbs/20a_distributed.ipynb 36
def rank0_first(func, *args, **kwargs):
"Execute `func` in the Rank-0 process first, then in other ranks in parallel."
if args or kwargs: func = partial(func, *args, **kwargs)
dummy_l = Learner(DataLoaders(device='cpu'), nn.Linear(1,1), loss_func=lambda: 0)
with dummy_l.distrib_ctx():
if not rank_distrib(): res = func()
distrib_barrier()
if rank_distrib(): res = func()
return res