zhenyundeng
add files
e62781a
raw
history blame contribute delete
No virus
4.04 kB
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Data processing/loading helpers."""
import numpy as np
import logging
import unicodedata
from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler
from .vector import vectorize
logger = logging.getLogger(__name__)
# ------------------------------------------------------------------------------
# Dictionary class for tokens.
# ------------------------------------------------------------------------------
class Dictionary(object):
    """Two-way mapping between tokens (str) and integer indices.

    Tokens are normalized to Unicode NFD form before they are stored or
    looked up, so canonically equivalent spellings share one entry. Indices
    0 and 1 are reserved for the ``NULL`` and ``UNK`` special tokens.
    """

    NULL = '<NULL>'
    UNK = '<UNK>'
    # Not referenced inside this class; presumably the first index that is
    # available to regular tokens -- confirm against callers.
    START = 2

    @staticmethod
    def normalize(token):
        """Return ``token`` in Unicode NFD (canonical decomposition) form."""
        return unicodedata.normalize('NFD', token)

    def __init__(self):
        """Create a dictionary that contains only the special tokens."""
        self.tok2ind = {self.NULL: 0, self.UNK: 1}
        self.ind2tok = {0: self.NULL, 1: self.UNK}

    def __len__(self):
        """Return the number of tokens, special tokens included."""
        return len(self.tok2ind)

    def __iter__(self):
        """Iterate over the stored (normalized) tokens."""
        return iter(self.tok2ind)

    def __contains__(self, key):
        """Return True if ``key`` is a known index (int) or token (str).

        Keys of any other type are never contained.
        """
        if isinstance(key, int):
            return key in self.ind2tok
        if isinstance(key, str):
            return self.normalize(key) in self.tok2ind
        return False

    def __getitem__(self, key):
        """Look up a token by index, or an index by token.

        An unknown index yields the ``UNK`` token and an unknown token yields
        the index of ``UNK``. Keys that are neither int nor str yield None.
        """
        if isinstance(key, int):
            return self.ind2tok.get(key, self.UNK)
        if isinstance(key, str):
            return self.tok2ind.get(self.normalize(key),
                                    self.tok2ind.get(self.UNK))
        return None

    def __setitem__(self, key, item):
        """Set one direction of the mapping.

        ``d[int] = str`` updates index -> token and ``d[str] = int`` updates
        token -> index. The other direction is not touched.

        Raises:
            RuntimeError: if (key, item) is not an (int, str) or (str, int)
                pair.
        """
        # Tokens are normalized here as well: lookups always normalize, so an
        # entry stored in a different Unicode form could never be found again.
        if isinstance(key, int) and isinstance(item, str):
            self.ind2tok[key] = self.normalize(item)
        elif isinstance(key, str) and isinstance(item, int):
            self.tok2ind[self.normalize(key)] = item
        else:
            raise RuntimeError('Invalid (key, item) types.')

    def add(self, token):
        """Add ``token`` under the next free index if it is not yet known."""
        token = self.normalize(token)
        if token not in self.tok2ind:
            index = len(self.tok2ind)
            self.tok2ind[token] = index
            self.ind2tok[index] = token

    def tokens(self):
        """Get dictionary tokens.

        Return all the words indexed by this dictionary, except for special
        tokens.
        """
        special = {self.NULL, self.UNK}
        return [tok for tok in self.tok2ind if tok not in special]
# ------------------------------------------------------------------------------
# PyTorch dataset class for SQuAD (and SQuAD-like) data.
# ------------------------------------------------------------------------------
class ReaderDataset(Dataset):
    """Dataset over reader examples that vectorizes each item on access."""

    def __init__(self, examples, model, single_answer=False):
        """Store the examples and the arguments forwarded to ``vectorize``.

        Args:
            examples: indexable collection of examples; ``lengths`` reads the
                'document' and 'question' entries of each one.
            model: passed through unchanged to ``vectorize``.
            single_answer: passed through unchanged to ``vectorize``.
        """
        self.examples = examples
        self.single_answer = single_answer
        self.model = model

    def __len__(self):
        """Return how many examples the dataset holds."""
        count = len(self.examples)
        return count

    def __getitem__(self, index):
        """Return the vectorized form of the example at ``index``."""
        example = self.examples[index]
        return vectorize(example, self.model, self.single_answer)

    def lengths(self):
        """Return one ``(document length, question length)`` pair per example."""
        pairs = []
        for example in self.examples:
            doc_len = len(example['document'])
            question_len = len(example['question'])
            pairs.append((doc_len, question_len))
        return pairs
# ------------------------------------------------------------------------------
# PyTorch sampler returning batches of indices sorted by length (by doc and question).
# ------------------------------------------------------------------------------
class SortedBatchSampler(Sampler):
    """Sampler that yields indices grouped into batches of similar length.

    Indices are sorted by decreasing first length, then decreasing second
    length, with ties broken at random. The sorted indices are cut into
    consecutive chunks of ``batch_size``, and the chunk order is optionally
    shuffled.
    """

    def __init__(self, lengths, batch_size, shuffle=True):
        """Store the sampler configuration.

        Args:
            lengths: sequence of (first length, second length) pairs, one per
                example.
            batch_size: number of indices per chunk.
            shuffle: if True, randomize the order of the chunks (not the
                order of the indices inside a chunk).
        """
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        """Return an iterator over all example indices, chunk by chunk."""
        # Negated lengths make the ascending argsort put the longest first.
        # The random column only breaks ties between equal-length examples.
        # np.float64 replaces np.float_, which was removed in NumPy 2.0.
        sort_keys = np.array(
            [(-pair[0], -pair[1], np.random.random())
             for pair in self.lengths],
            dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float64)]
        )
        indices = np.argsort(sort_keys, order=('l1', 'l2', 'rand'))
        batches = [indices[start:start + self.batch_size]
                   for start in range(0, len(indices), self.batch_size)]
        if self.shuffle:
            np.random.shuffle(batches)
        return iter([i for batch in batches for i in batch])

    def __len__(self):
        """Return the total number of indices (not the number of batches)."""
        return len(self.lengths)