# Copied from https://github.com/huggingface/accelerate/blob/e2ae254008061b3e53fc1c97f88d65743a857e75/src/accelerate/data_loader.py
from torch.utils.data import BatchSampler, DataLoader, IterableDataset


# kwargs of the DataLoader in min version 1.4.0.
_PYTORCH_DATALOADER_KWARGS = {
    "batch_size": 1,
    "shuffle": False,
    "sampler": None,
    "batch_sampler": None,
    "num_workers": 0,
    "collate_fn": None,
    "pin_memory": False,
    "drop_last": False,
    "timeout": 0,
    "worker_init_fn": None,
    "multiprocessing_context": None,
    "generator": None,
    "prefetch_factor": 2,
    "persistent_workers": False,
}


class SkipBatchSampler(BatchSampler):
    """
    A `torch.utils.data.BatchSampler` that skips the first `skip_batches` batches of another
    `torch.utils.data.BatchSampler`.
    """

    def __init__(self, batch_sampler, skip_batches=0):
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches

    def __iter__(self):
        # Iterate the wrapped sampler, discarding the first `skip_batches` batches.
        for index, samples in enumerate(self.batch_sampler):
            if index >= self.skip_batches:
                yield samples

    @property
    def total_length(self):
        return len(self.batch_sampler)

    def __len__(self):
        return len(self.batch_sampler) - self.skip_batches


class SkipDataLoader(DataLoader):
    """
    Subclass of a PyTorch `DataLoader` that will skip the first batches.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        kwargs:
            All other keyword arguments to pass to the regular `DataLoader` initialization.
    """

    def __init__(self, dataset, skip_batches=0, **kwargs):
        super().__init__(dataset, **kwargs)
        self.skip_batches = skip_batches

    def __iter__(self):
        # Skipping happens at the batch level here, since an `IterableDataset`
        # has no sampler we could wrap instead.
        for index, batch in enumerate(super().__iter__()):
            if index >= self.skip_batches:
                yield batch


# Adapted from https://github.com/huggingface/accelerate
def skip_first_batches(dataloader, num_batches=0):
    """
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
    """
    dataset = dataloader.dataset
    sampler_is_batch_sampler = False
    if isinstance(dataset, IterableDataset):
        new_batch_sampler = None
    else:
        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
        batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
        new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)

    # We ignore all of those since they are all dealt with by our new_batch_sampler
    ignore_kwargs = [
        "batch_size",
        "shuffle",
        "sampler",
        "batch_sampler",
        "drop_last",
    ]

    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
        for k in _PYTORCH_DATALOADER_KWARGS
        if k not in ignore_kwargs
    }

    # Need to provide `batch_size` ourselves, as `batch_sampler` is None for an `IterableDataset`.
    if new_batch_sampler is None:
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = dataloader.batch_size

    if new_batch_sampler is None:
        # Need to manually skip batches in the dataloader
        dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
    elif sampler_is_batch_sampler:
        # The original batch sampler was passed in the `sampler` slot, so keep the
        # wrapped one there too (this mirrors the upstream accelerate implementation).
        dataloader = DataLoader(
            dataset,
            sampler=new_batch_sampler,
            batch_size=dataloader.batch_size,
            **kwargs,
        )
    else:
        dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

    return dataloader
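

# A minimal usage sketch (not part of the original accelerate module): resuming a
# training epoch from a given batch index. The `TensorDataset` setup below is an
# illustrative assumption chosen only to make the example self-contained.
if __name__ == "__main__":
    import torch
    from torch.utils.data import TensorDataset

    dataset = TensorDataset(torch.arange(10))
    dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

    # Skip the first 2 batches (i.e. the first 4 samples at batch_size=2); the
    # remaining loader yields batches containing samples 4..9.
    resumed = skip_first_batches(dataloader, num_batches=2)
    for batch in resumed:
        print(batch)

    # `len` reflects only the remaining batches: 5 total - 2 skipped == 3.
    assert len(resumed) == 3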