#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Data processing/loading helpers."""

import numpy as np
import logging
import unicodedata

from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler
from .vector import vectorize

logger = logging.getLogger(__name__)


# ------------------------------------------------------------------------------
# Dictionary class for tokens.
# ------------------------------------------------------------------------------


class Dictionary(object):
    NULL = '<NULL>'
    UNK = '<UNK>'
    START = 2  # first index available for regular (non-special) tokens

    @staticmethod
    def normalize(token):
        # Normalize to a canonical Unicode form (NFD) so equivalent
        # tokens compare equal regardless of their encoding.
        return unicodedata.normalize('NFD', token)

    def __init__(self):
        self.tok2ind = {self.NULL: 0, self.UNK: 1}
        self.ind2tok = {0: self.NULL, 1: self.UNK}

    def __len__(self):
        return len(self.tok2ind)

    def __iter__(self):
        return iter(self.tok2ind)

    def __contains__(self, key):
        if type(key) == int:
            return key in self.ind2tok
        elif type(key) == str:
            return self.normalize(key) in self.tok2ind

    def __getitem__(self, key):
        if type(key) == int:
            return self.ind2tok.get(key, self.UNK)
        if type(key) == str:
            return self.tok2ind.get(self.normalize(key),
                                    self.tok2ind.get(self.UNK))

    def __setitem__(self, key, item):
        if type(key) == int and type(item) == str:
            self.ind2tok[key] = item
        elif type(key) == str and type(item) == int:
            self.tok2ind[key] = item
        else:
            raise RuntimeError('Invalid (key, item) types.')

    def add(self, token):
        token = self.normalize(token)
        if token not in self.tok2ind:
            index = len(self.tok2ind)
            self.tok2ind[token] = index
            self.ind2tok[index] = token

    def tokens(self):
        """Get dictionary tokens.

        Return all the words indexed by this dictionary, except for special
        tokens.
        """
        tokens = [k for k in self.tok2ind.keys()
                  if k not in {'<NULL>', '<UNK>'}]
        return tokens


# ------------------------------------------------------------------------------
# PyTorch dataset class for SQuAD (and SQuAD-like) data.
# ------------------------------------------------------------------------------


class ReaderDataset(Dataset):

    def __init__(self, examples, model, single_answer=False):
        self.model = model
        self.examples = examples
        self.single_answer = single_answer

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        return vectorize(self.examples[index], self.model, self.single_answer)

    def lengths(self):
        return [(len(ex['document']), len(ex['question']))
                for ex in self.examples]


# ------------------------------------------------------------------------------
# PyTorch sampler returning batches of sorted lengths (by doc and question).
# ------------------------------------------------------------------------------


class SortedBatchSampler(Sampler):

    def __init__(self, lengths, batch_size, shuffle=True):
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        # Sort by descending document length, then descending question
        # length; a random value breaks ties so equal-length examples are
        # not always grouped the same way across epochs.
        lengths = np.array(
            [(-l[0], -l[1], np.random.random()) for l in self.lengths],
            dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float64)]
        )
        indices = np.argsort(lengths, order=('l1', 'l2', 'rand'))
        batches = [indices[i:i + self.batch_size]
                   for i in range(0, len(indices), self.batch_size)]
        # Shuffle whole batches, not individual examples, so each batch
        # still contains similarly sized examples.
        if self.shuffle:
            np.random.shuffle(batches)
        return iter([i for batch in batches for i in batch])

    def __len__(self):
        return len(self.lengths)
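

# ------------------------------------------------------------------------------
# Usage sketch (illustrative only).
# ------------------------------------------------------------------------------

# A minimal sketch of how Dictionary and SortedBatchSampler fit together,
# using hand-made (doc_len, q_len) pairs in place of a real ReaderDataset
# (which would also need a trained model and the vectorize helper). The
# token list and lengths below are hypothetical examples. Because this
# module uses a relative import, run it as a module (e.g.
# `python -m <package>.data`, with the package path depending on where
# this file lives) rather than as a standalone script.

if __name__ == '__main__':
    # Indices 0 and 1 are reserved for <NULL> and <UNK>, so the first
    # real token lands at Dictionary.START (2).
    word_dict = Dictionary()
    for w in ['the', 'quick', 'brown', 'fox']:
        word_dict.add(w)
    assert word_dict['the'] == 2             # first non-special index
    assert word_dict['unseen'] == 1          # unknown words map to <UNK>
    assert word_dict[0] == Dictionary.NULL   # index lookup returns the token

    # Hypothetical (doc_len, q_len) pairs standing in for the output of
    # ReaderDataset.lengths(). The sampler groups similarly sized examples
    # into batches so per-batch padding stays small.
    fake_lengths = [(120, 10), (80, 12), (120, 8), (40, 5), (80, 9)]
    sampler = SortedBatchSampler(fake_lengths, batch_size=2, shuffle=False)
    # With shuffle=False this yields indices ordered by descending doc
    # length, then question length: 0, 2, 1, 4, 3.
    print(list(sampler))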