# File: GH29BERT / tape / utils / _sampler.py
# Last commit: 212111c "Upload 26 files" (KeXing); size 3.91 kB.
"""Implementation of a bucketed data sampler from PyTorch-NLP.
Modified by Roshan Rao.
See https://github.com/PetrochukM/PyTorch-NLP/
"""
import typing
import math
import operator
from torch.utils.data.sampler import Sampler
from torch.utils.data.sampler import BatchSampler
from torch.utils.data.sampler import SubsetRandomSampler
class SortedSampler(Sampler):
    """Samples dataset indices in ascending order of ``sort_key``.

    The order is computed once, at construction time, and every iteration
    yields that same sequence. ``sorted`` is stable, so elements with equal
    keys keep their original relative order.

    Args:
        dataset: The data to sort. Must be iterable when ``indices`` is
            ``None`` and indexable (``dataset[i]``) otherwise.
        sort_key (callable): Function of one argument, applied to each element
            of ``dataset``, that returns the comparison key.
        indices (iterable of int, optional): If given, only these indices of
            ``dataset`` are sampled, ordered by ``sort_key(dataset[i])``.
            Defaults to every position of ``dataset``.

    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
        >>> list(SortedSampler(['ccc', 'a', 'bb'], sort_key=len, indices=[0, 2]))
        [2, 0]
    """

    def __init__(self,
                 dataset,
                 sort_key: typing.Callable[[typing.Any], typing.Any],
                 indices: typing.Optional[typing.Iterable[int]] = None):
        super().__init__(dataset)
        self.dataset = dataset
        self.sort_key = sort_key
        if indices is None:
            # Pair every element with its position so that the sorted result
            # holds indices. Mapping sort_key alone would yield bare keys,
            # which cannot be unpacked into (index, key) below.
            sort_keys = ((i, sort_key(item)) for i, item in enumerate(dataset))
        else:
            sort_keys = ((i, sort_key(dataset[i])) for i in indices)
        self.sorted_indices = [i for i, _ in sorted(sort_keys, key=operator.itemgetter(1))]

    def __iter__(self):
        return iter(self.sorted_indices)

    def __len__(self):
        # Report the number of indices actually yielded. This is smaller than
        # len(dataset) when a subset of ``indices`` was supplied.
        return len(self.sorted_indices)
class BucketBatchSampler(BatchSampler):
    """Batch sampler that groups elements with similar ``sort_key`` values.

    Indices drawn from ``sampler`` are collected into buckets of
    ``batch_size * bucket_size_multiplier`` indices (capped at ``len(sampler)``).
    Within each bucket the indices are sorted by ``sort_key(dataset[i])`` and
    cut into batches of ``batch_size``. The batches of a bucket are then yielded
    in a random order. Typically ``sampler`` is a ``RandomSampler``, so buckets
    are random while each batch contains similarly-keyed elements. A larger
    ``bucket_size_multiplier`` yields more sorted batches, and vice versa. The
    PyTorch-NLP documentation reports a ~10-25 percent speedup.

    Background:
        ``BucketBatchSampler`` is similar to a ``BucketIterator`` found in popular
        libraries like ``AllenNLP`` and ``torchtext``. A ``BucketIterator`` pools together
        examples with a similar size length to reduce the padding required for each batch
        while maintaining some noise through bucketing.

    Args:
        sampler (torch.utils.data.sampler.Sampler): Yields indices into ``dataset``.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the last batch of each bucket is dropped
            when it is smaller than ``batch_size``. Buckets other than the last
            are multiples of ``batch_size``, so at most one batch is dropped.
        sort_key (callable): Applied to ``dataset[i]`` to get the sorting key.
        dataset: Indexable data that the indices from ``sampler`` refer to.
        bucket_size_multiplier (int, optional): Buckets are of size
            ``batch_size * bucket_size_multiplier``.

    Example:
        >>> from torch.utils.data.sampler import SequentialSampler
        >>> data = list(range(10))
        >>> sampler = SequentialSampler(data)
        >>> batches = list(BucketBatchSampler(sampler, batch_size=3, drop_last=False,
        ...                                   sort_key=lambda x: x, dataset=data))
        >>> sorted(batches)  # batches are yielded in a random order
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> batches = list(BucketBatchSampler(sampler, batch_size=3, drop_last=True,
        ...                                   sort_key=lambda x: x, dataset=data))
        >>> sorted(batches)
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self,
                 sampler,
                 batch_size,
                 drop_last,
                 sort_key,
                 dataset,
                 bucket_size_multiplier=100):
        super().__init__(sampler, batch_size, drop_last)
        self.sort_key = sort_key
        self.dataset = dataset
        num_samples = len(sampler)
        # A bucket never needs to be larger than the whole sampler. For an
        # empty sampler the capped size would be 0, which BatchSampler rejects
        # with a ValueError. Any positive size works there since nothing is
        # bucketed, so use 1.
        if num_samples:
            bucket_size = min(batch_size * bucket_size_multiplier, num_samples)
        else:
            bucket_size = 1
        self.bucket_sampler = BatchSampler(sampler, bucket_size, False)

    def __iter__(self):
        for bucket in self.bucket_sampler:
            # Order this bucket's indices by sort_key so that each batch
            # contains similarly-keyed elements.
            sorted_sampler = SortedSampler(self.dataset, self.sort_key, indices=bucket)
            batches = list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))
            # Yield the bucket's batches in a random order rather than in
            # ascending key order.
            yield from SubsetRandomSampler(batches)

    def __len__(self):
        # Every bucket except the last holds batch_size * bucket_size_multiplier
        # indices (or the single bucket holds all of them). Applying drop_last
        # per bucket therefore affects only the final partial batch, and the
        # counts below match what __iter__ yields.
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return math.ceil(len(self.sampler) / self.batch_size)