Inpaint / src /util /data_loader.py
ZehanWang's picture
Upload folder using huggingface_hub
864ec44 verified
raw
history blame
3.61 kB
# Copied from https://github.com/huggingface/accelerate/blob/e2ae254008061b3e53fc1c97f88d65743a857e75/src/accelerate/data_loader.py
from torch.utils.data import BatchSampler, DataLoader, IterableDataset
# kwargs of the DataLoader in min version 1.4.0.
_PYTORCH_DATALOADER_KWARGS = {
"batch_size": 1,
"shuffle": False,
"sampler": None,
"batch_sampler": None,
"num_workers": 0,
"collate_fn": None,
"pin_memory": False,
"drop_last": False,
"timeout": 0,
"worker_init_fn": None,
"multiprocessing_context": None,
"generator": None,
"prefetch_factor": 2,
"persistent_workers": False,
}
class SkipBatchSampler(BatchSampler):
    """
    A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.

    Args:
        batch_sampler:
            The sampler whose batches are re-emitted. Only iteration and ``len()``
            are used on it.
        skip_batches (`int`, *optional*, defaults to 0):
            Number of leading batches to drop. Values <= 0 drop nothing.
    """

    def __init__(self, batch_sampler, skip_batches=0):
        # BatchSampler.__init__ is deliberately not called: it would require a
        # sampler/batch_size/drop_last triple, while this class only re-emits the
        # batches already produced by ``batch_sampler``.
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches

    def __iter__(self):
        for index, samples in enumerate(self.batch_sampler):
            # Skipped batches are still pulled from the wrapped sampler; they are
            # just not yielded.
            if index >= self.skip_batches:
                yield samples

    @property
    def total_length(self):
        """Number of batches of the wrapped sampler, skipped ones included."""
        return len(self.batch_sampler)

    def __len__(self):
        # Clamp at 0: skipping more batches than exist must report an empty
        # sampler (matching ``__iter__``) rather than a negative length, which
        # makes ``len()`` raise ValueError. Negative skips drop nothing, exactly
        # as in ``__iter__``.
        remaining = len(self.batch_sampler) - max(self.skip_batches, 0)
        return max(remaining, 0)
class SkipDataLoader(DataLoader):
    """
    `torch.utils.data.DataLoader` subclass that drops the leading batches it produces.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to build this dataloader from.
        skip_batches (`int`, *optional*, defaults to 0):
            How many batches to discard at the start of every iteration.
        kwargs:
            Any other keyword argument, forwarded unchanged to `DataLoader.__init__`.
    """

    def __init__(self, dataset, skip_batches=0, **kwargs):
        super().__init__(dataset, **kwargs)
        self.skip_batches = skip_batches

    def __iter__(self):
        # The parent iterator still loads and collates the skipped batches;
        # they are simply not handed to the caller.
        produced = super().__iter__()
        for position, batch in enumerate(produced):
            if position < self.skip_batches:
                continue
            yield batch
# Adapted from https://github.com/huggingface/accelerate
def skip_first_batches(dataloader, num_batches=0):
    """
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.

    The new loader reuses ``dataloader``'s dataset, sampling order and its other
    settings (workers, collate_fn, pin_memory, ...), read back from the loader's
    attributes with `_PYTORCH_DATALOADER_KWARGS` as fallbacks.

    Args:
        dataloader (`torch.utils.data.DataLoader`):
            The loader to resume from.
        num_batches (`int`, *optional*, defaults to 0):
            How many leading batches to drop.

    Returns:
        `torch.utils.data.DataLoader`: a new loader. For an `IterableDataset` it is
        a `SkipDataLoader`, which still produces the skipped batches and discards them.
        For map-style datasets only sampler indices are skipped, so no data is loaded
        for the dropped batches.
    """
    dataset = dataloader.dataset

    # These are all re-derived from the (wrapped) sampler built below, so they
    # must not be copied over from the original loader.
    ignore_kwargs = (
        "batch_size",
        "shuffle",
        "sampler",
        "batch_sampler",
        "drop_last",
    )
    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
        for k in _PYTORCH_DATALOADER_KWARGS
        if k not in ignore_kwargs
    }

    if isinstance(dataset, IterableDataset):
        # There is no index sampler to wrap: batching must be configured again
        # and the batches skipped after they are produced.
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = dataloader.batch_size
        return SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)

    if dataloader.batch_sampler is not None:
        # Automatic batching: skip whole batches of indices. Wrapping the loader's
        # own batch_sampler (rather than its sampler) keeps the dataset and
        # collate_fn inputs identical to the original loader, even when the user's
        # ``sampler`` is itself a BatchSampler.
        new_batch_sampler = SkipBatchSampler(
            dataloader.batch_sampler, skip_batches=num_batches
        )
        return DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

    # Automatic batching disabled (batch_size=None): every element the sampler
    # yields -- a single index, or a list of indices when a BatchSampler is used as
    # ``sampler`` -- is one batch. Keep it a plain sampler with batch_size=None so
    # the dataset receives exactly what it received from the original loader
    # (previously this path either crashed on ``SkipBatchSampler(None)`` or
    # silently re-batched the indices).
    new_sampler = SkipBatchSampler(dataloader.sampler, skip_batches=num_batches)
    return DataLoader(dataset, sampler=new_sampler, batch_size=None, **kwargs)