"""Implementation of a bucketed data sampler from PyTorch-NLP.

Modified by Roshan Rao.

See https://github.com/PetrochukM/PyTorch-NLP/
"""
|
import typing |
|
import math |
|
import operator |
|
from torch.utils.data.sampler import Sampler |
|
from torch.utils.data.sampler import BatchSampler |
|
from torch.utils.data.sampler import SubsetRandomSampler |
|
|
|
|
|
class SortedSampler(Sampler):
    """Sample dataset positions in ascending order of a per-example sort key.

    The ordering is computed once, at construction time, so every iteration
    yields the same sequence of indices.

    Args:
        dataset: Collection of examples. It is iterated when ``indices`` is
            ``None`` and indexed as ``dataset[i]`` otherwise.
        sort_key (callable): Function of one argument (an example) that
            returns the comparison key for that example.
        indices (iterable of int, optional): Restrict sampling to these
            positions of ``dataset``. When ``None`` every example is used.

    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
        >>> data = [5, 3, 8, 1]
        >>> list(SortedSampler(data, sort_key=lambda x: x, indices=[0, 2, 3]))
        [3, 0, 2]
    """

    def __init__(self,
                 dataset,
                 sort_key: typing.Callable[[typing.Any], typing.Any],
                 indices: typing.Optional[typing.Iterable[int]] = None):
        # PyTorch releases disagree on the base-class constructor: older ones
        # require a positional ``data_source`` while newer ones deprecate it.
        # Try the argument-free form first and fall back to the legacy call.
        try:
            super().__init__()
        except TypeError:
            super().__init__(dataset)

        self.dataset = dataset
        self.sort_key = sort_key

        # Build (position, key) pairs in both branches so the sort below can
        # always order by the key and then keep only the position.
        if indices is None:
            keyed = ((i, sort_key(example)) for i, example in enumerate(dataset))
        else:
            keyed = ((i, sort_key(dataset[i])) for i in indices)

        # ``sorted`` is stable, so examples with equal keys keep input order.
        ranked = sorted(keyed, key=operator.itemgetter(1))
        self.sorted_indices = [i for i, _ in ranked]

    def __iter__(self):
        """Return an iterator over the precomputed, key-sorted indices."""
        return iter(self.sorted_indices)

    def __len__(self):
        """Return the number of indices this sampler yields."""
        # Counts the sampled positions, which is fewer than ``len(dataset)``
        # whenever ``indices`` selected only a subset.
        return len(self.sorted_indices)
|
|
|
|
|
class BucketBatchSampler(BatchSampler):
    """Batch sampler that groups examples with similar sort keys.

    Indices drawn from ``sampler`` are first collected into large buckets.
    Each bucket is sorted by ``sort_key``, cut into batches of ``batch_size``
    and those batches are then emitted in random order. Typically ``sampler``
    is a ``RandomSampler``; a larger ``bucket_size_multiplier`` then gives
    batches that are more sorted, and a smaller one batches that are more
    random.

    Background:
        ``BucketBatchSampler`` is similar to a ``BucketIterator`` found in
        popular libraries like ``AllenNLP`` and ``torchtext``. A
        ``BucketIterator`` pools together examples of similar length to
        reduce the padding required for each batch while maintaining some
        noise through bucketing.

    Args:
        sampler (torch.utils.data.sampler.Sampler): Source of dataset
            indices. Must support ``len()``.
        batch_size (int): Size of each mini-batch.
        drop_last (bool): If ``True``, drop a final batch that would be
            smaller than ``batch_size``.
        sort_key (callable): Function of one argument (an example from
            ``dataset``) returning the key used to sort within a bucket.
        dataset: Collection indexed by the values ``sampler`` yields; each
            example is passed to ``sort_key``.
        bucket_size_multiplier (int, optional): Buckets hold
            ``batch_size * bucket_size_multiplier`` indices, capped at
            ``len(sampler)``.

    Example:
        >>> from torch.utils.data.sampler import SequentialSampler
        >>> data = list(range(10))
        >>> sampler = SequentialSampler(data)
        >>> batches = BucketBatchSampler(
        ...     sampler, batch_size=3, drop_last=False,
        ...     sort_key=lambda x: x, dataset=data)
        >>> sorted(batches)  # batch order itself is random
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> batches = BucketBatchSampler(
        ...     sampler, batch_size=3, drop_last=True,
        ...     sort_key=lambda x: x, dataset=data)
        >>> sorted(batches)
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self,
                 sampler,
                 batch_size,
                 drop_last,
                 sort_key,
                 dataset,
                 bucket_size_multiplier=100):
        super().__init__(sampler, batch_size, drop_last)
        self.sort_key = sort_key
        self.dataset = dataset

        bucket_size = batch_size * bucket_size_multiplier
        num_indices = len(sampler)
        # Cap the bucket at the sampler length, but only when the sampler is
        # non-empty: a cap of 0 would make ``BatchSampler`` reject the size
        # and turn an empty dataset into a construction-time crash.
        if num_indices > 0:
            bucket_size = min(bucket_size, num_indices)
        # Buckets never drop indices; ``drop_last`` applies to batches only.
        self.bucket_sampler = BatchSampler(sampler, bucket_size, False)

    def __iter__(self):
        """Yield batches (lists of indices), shuffled within each bucket."""
        for bucket in self.bucket_sampler:
            # Order this bucket's indices by the sort key of their examples.
            in_key_order = SortedSampler(self.dataset, self.sort_key, indices=bucket)
            bucket_batches = list(
                BatchSampler(in_key_order, self.batch_size, self.drop_last))
            # SubsetRandomSampler yields the given items in random order, so
            # this shuffles whole batches while each batch stays sorted.
            for batch in SubsetRandomSampler(bucket_batches):
                yield batch

    def __len__(self):
        """Return the number of batches produced per pass over ``sampler``."""
        num_indices = len(self.sampler)
        if self.drop_last:
            return num_indices // self.batch_size
        # Integer ceiling division; avoids a round trip through float.
        return (num_indices + self.batch_size - 1) // self.batch_size
|
|