from typing import Union, List, Tuple, Sequence, Dict, Any, Optional, Collection
from copy import copy
from pathlib import Path
import pickle as pkl
import logging
import random

import lmdb
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from scipy.spatial.distance import pdist, squareform

from .tokenizers import TAPETokenizer
from .registry import registry

logger = logging.getLogger(__name__)


def dataset_factory(data_file: Union[str, Path], *args, **kwargs) -> Dataset:
    """Build the raw dataset matching the type of ``data_file``.

    Dispatch is on the path suffix first ('.lmdb', fasta-style suffixes,
    '.json'); a path with none of those suffixes that is a directory is
    treated as a directory of npz files.

    Args:
        data_file (Union[str, Path]): Path to the data file or directory.
        *args, **kwargs: Forwarded unchanged to the selected dataset class.

    Returns:
        Dataset: One of LMDBDataset, FastaDataset, JSONDataset, NPZDataset.

    Raises:
        FileNotFoundError: If ``data_file`` does not exist.
        ValueError: If the file type is not recognized.
    """
    data_file = Path(data_file)
    if not data_file.exists():
        raise FileNotFoundError(data_file)
    if data_file.suffix == '.lmdb':
        return LMDBDataset(data_file, *args, **kwargs)
    elif data_file.suffix in {'.fasta', '.fna', '.ffn', '.faa', '.frn'}:
        return FastaDataset(data_file, *args, **kwargs)
    elif data_file.suffix == '.json':
        return JSONDataset(data_file, *args, **kwargs)
    elif data_file.is_dir():
        return NPZDataset(data_file, *args, **kwargs)
    else:
        raise ValueError(f"Unrecognized datafile type {data_file.suffix}")


def pad_sequences(sequences: Sequence,
                  constant_value=0,
                  dtype=None) -> Union[np.ndarray, torch.Tensor]:
    """Stack variable-shaped arrays into one batch, padding with a constant.

    The output has shape ``[len(sequences)] + elementwise_max_shape`` and each
    input is copied into the leading corner of its slot.

    Args:
        sequences (Sequence): Non-empty sequence of numpy arrays or torch
            tensors. The container type of the output follows the type of the
            first element.
        constant_value: Fill value used for padding. Default: 0.
        dtype: Output dtype. Defaults to the dtype of the first element.

    Returns:
        A numpy array or torch tensor holding the padded batch.

    Raises:
        ValueError: If ``sequences`` is empty.
        TypeError: If the elements are neither numpy arrays nor torch tensors.
    """
    batch_size = len(sequences)
    if batch_size == 0:
        # Previously this surfaced as an opaque numpy reduction error.
        raise ValueError("Cannot pad an empty batch of sequences")

    first = sequences[0]
    shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist()

    if dtype is None:
        dtype = first.dtype

    if isinstance(first, np.ndarray):
        array = np.full(shape, constant_value, dtype=dtype)
    elif isinstance(first, torch.Tensor):
        array = torch.full(shape, constant_value, dtype=dtype)
    else:
        # Without this branch ``array`` would be unbound below.
        raise TypeError(
            f"pad_sequences expects numpy arrays or torch tensors, "
            f"received {type(first)}")

    for arr, seq in zip(array, sequences):
        # Iterating ``array`` yields views, so this writes into the batch.
        arrslice = tuple(slice(dim) for dim in seq.shape)
        arr[arrslice] = seq

    return array


class FastaDataset(Dataset):
    """Creates a dataset from a fasta file.

    Args:
        data_file (Union[str, Path]): Path to fasta file.
        in_memory (bool, optional): Accepted for API compatibility with the
            other datasets. The file is always parsed fully into memory,
            whatever value is passed. Default: False.
    """

    def __init__(self,
                 data_file: Union[str, Path],
                 in_memory: bool = False):

        # Imported lazily so Biopython is only required for fasta input.
        from Bio import SeqIO
        data_file = Path(data_file)
        if not data_file.exists():
            raise FileNotFoundError(data_file)

        cache = list(SeqIO.parse(str(data_file), 'fasta'))

        self._cache = cache
        self._in_memory = in_memory
        self._num_examples = len(cache)

    def __len__(self) -> int:
        return self._num_examples

    def __getitem__(self, index: int):
        """Return a dict with 'id', 'primary' and 'protein_length'.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
        """
        if not 0 <= index < self._num_examples:
            raise IndexError(index)

        record = self._cache[index]

        item = {'id': record.id,
                'primary': str(record.seq),
                'protein_length': len(record.seq)}
        return item


class LMDBDataset(Dataset):
    """Creates a dataset from an lmdb file.

    Records are stored pickled under the stringified integer index, and the
    record count under the key ``b'num_examples'``.

    NOTE: records are decoded with ``pickle.loads``, which can execute
    arbitrary code. Only open LMDB files from a trusted source.

    Args:
        data_file (Union[str, Path]): Path to lmdb file.
        in_memory (bool, optional): Whether to cache each record in memory
            after it is first read. Default: False.
    """

    def __init__(self,
                 data_file: Union[str, Path],
                 in_memory: bool = False):

        data_file = Path(data_file)
        if not data_file.exists():
            raise FileNotFoundError(data_file)

        env = lmdb.open(str(data_file), max_readers=1, readonly=True,
                        lock=False, readahead=False, meminit=False)

        with env.begin(write=False) as txn:
            num_examples = pkl.loads(txn.get(b'num_examples'))

        if in_memory:
            # Filled lazily by __getitem__; None marks "not loaded yet".
            cache = [None] * num_examples
            self._cache = cache

        self._env = env
        self._in_memory = in_memory
        self._num_examples = num_examples

    def __len__(self) -> int:
        return self._num_examples

    def __getitem__(self, index: int):
        """Return the record at ``index``, adding an 'id' key if it has none.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
        """
        if not 0 <= index < self._num_examples:
            raise IndexError(index)

        if self._in_memory and self._cache[index] is not None:
            item = self._cache[index]
        else:
            with self._env.begin(write=False) as txn:
                item = pkl.loads(txn.get(str(index).encode()))
                if 'id' not in item:
                    item['id'] = str(index)
                if self._in_memory:
                    self._cache[index] = item
        return item


class JSONDataset(Dataset):
    """Creates a dataset from a json file. Assumes that data is a JSON
    serialized list of records, where each record is a dictionary.

    Args:
        data_file (Union[str, Path]): Path to json file.
        in_memory (bool): Dummy variable to match API of other datasets
    """

    def __init__(self,
                 data_file: Union[str, Path],
                 in_memory: bool = True):
        import json
        data_file = Path(data_file)
        if not data_file.exists():
            raise FileNotFoundError(data_file)
        records = json.loads(data_file.read_text())

        if not isinstance(records, list):
            raise TypeError(f"TAPE JSONDataset requires a json serialized list, "
                            f"received {type(records)}")
        self._records = records
        self._num_examples = len(records)

    def __len__(self) -> int:
        return self._num_examples

    def __getitem__(self, index: int):
        """Return the record at ``index``, adding an 'id' key if it has none.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
            TypeError: If the record is not a dictionary.
        """
        if not 0 <= index < self._num_examples:
            raise IndexError(index)

        item = self._records[index]
        if not isinstance(item, dict):
            raise TypeError(f"Expected dataset to contain a list of dictionary "
                            f"records, received record of type {type(item)}")
        if 'id' not in item:
            item['id'] = str(index)
        return item


class NPZDataset(Dataset):
    """Creates a dataset from a directory of npz files.

    Args:
        data_file (Union[str, Path]): Path to directory of npz files
        in_memory (bool): Dummy variable to match API of other datasets
        split_files (Optional[Collection[str]]): If given, only files whose
            name is in this collection are used, and every listed name must
            be present in the directory.
    """

    def __init__(self,
                 data_file: Union[str, Path],
                 in_memory: bool = True,
                 split_files: Optional[Collection[str]] = None):
        data_file = Path(data_file)
        if not data_file.exists():
            raise FileNotFoundError(data_file)
        if not data_file.is_dir():
            raise NotADirectoryError(data_file)

        file_glob = data_file.glob('*.npz')
        if split_files is None:
            file_list = list(file_glob)
        else:
            split_files = set(split_files)
            if len(split_files) == 0:
                raise ValueError("Passed an empty split file set")

            file_list = [f for f in file_glob if f.name in split_files]
            if len(file_list) != len(split_files):
                num_missing = len(split_files) - len(file_list)
                raise FileNotFoundError(
                    f"{num_missing} specified split files not found in directory")

        if len(file_list) == 0:
            raise FileNotFoundError(f"No .npz files found in {data_file}")

        self._file_list = file_list

    def __len__(self) -> int:
        return len(self._file_list)

    def __getitem__(self, index: int):
        """Load the npz file at ``index`` into a dict of arrays.

        An 'id' key holding the file stem is added if the archive has none.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError(index)

        path = self._file_list[index]
        # The context manager closes the archive once the arrays are read;
        # a bare np.load() would keep one file handle open per item.
        with np.load(path) as npz_file:
            item = dict(npz_file)
        if 'id' not in item:
            item['id'] = path.stem
        return item


@registry.register_task('embed')
class EmbedDataset(Dataset):
    """Dataset yielding ``(id, token_ids, input_mask)`` for embedding.

    Args:
        data_file (Union[str, Path]): Path to any file or directory accepted
            by ``dataset_factory``.
        tokenizer (Union[str, TAPETokenizer]): Tokenizer instance, or a vocab
            name used to build one. Default: 'iupac'.
        in_memory (bool, optional): Forwarded to the underlying dataset.
            Default: False.
        convert_tokens_to_ids (bool, optional): Currently unused.
    """

    def __init__(self,
                 data_file: Union[str, Path],
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False,
                 convert_tokens_to_ids: bool = True):
        super().__init__()

        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer
        # in_memory used to be dropped here, so the flag had no effect.
        self.data = dataset_factory(data_file, in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        item = self.data[index]
        token_ids = self.tokenizer.encode(item['primary'])
        input_mask = np.ones_like(token_ids)
        return item['id'], token_ids, input_mask

    def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
        """Pad a batch into 'ids', 'input_ids' and 'input_mask'."""
        ids, tokens, input_mask = zip(*batch)
        ids = list(ids)
        tokens = torch.from_numpy(pad_sequences(tokens))
        input_mask = torch.from_numpy(pad_sequences(input_mask))
        return {'ids': ids, 'input_ids': tokens, 'input_mask': input_mask}  # type: ignore


@registry.register_task('masked_language_modeling')
class MaskedLanguageModelingDataset(Dataset):
    """Creates the Masked Language Modeling Pfam Dataset

    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'valid', 'holdout'], specifies which data file to load.
        tokenizer (Union[str, TAPETokenizer]): Tokenizer instance, or a vocab
            name used to build one. Default: 'iupac'.
        in_memory (bool, optional): Whether to load the full dataset into memory.
            Default: False.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):
        super().__init__()
        if split not in ('train', 'valid', 'holdout'):
            raise ValueError(
                f"Unrecognized split: {split}. "
                f"Must be one of ['train', 'valid', 'holdout']")
        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        data_path = Path(data_path)
        data_file = f'pfam/pfam_{split}.lmdb'
        self.data = dataset_factory(data_path / data_file, in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        masked_token_ids, input_mask, labels = self._mask_and_encode(item['primary'])
        return masked_token_ids, input_mask, labels, item['clan'], item['family']

    def collate_fn(self, batch: List[Any]) -> Dict[str, torch.Tensor]:
        """Pad a batch into 'input_ids', 'input_mask' and 'targets'.

        The clan and family entries of each example are not part of the
        output.
        """
        input_ids, input_mask, lm_label_ids, _clan, _family = tuple(zip(*batch))
        return self._collate_masked(input_ids, input_mask, lm_label_ids)

    def _mask_and_encode(self, sequence: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Tokenize ``sequence``, apply BERT masking, and convert to ids."""
        tokens = self.tokenizer.tokenize(sequence)
        tokens = self.tokenizer.add_special_tokens(tokens)
        masked_tokens, labels = self._apply_bert_mask(tokens)
        masked_token_ids = np.array(
            self.tokenizer.convert_tokens_to_ids(masked_tokens), np.int64)
        input_mask = np.ones_like(masked_token_ids)
        return masked_token_ids, input_mask, labels

    @staticmethod
    def _collate_masked(input_ids, input_mask, lm_label_ids) -> Dict[str, torch.Tensor]:
        """Pad masked-LM inputs with 0 and labels with -1."""
        input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
        input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
        # ignore_index is -1
        lm_label_ids = torch.from_numpy(pad_sequences(lm_label_ids, -1))

        return {'input_ids': input_ids,
                'input_mask': input_mask,
                'targets': lm_label_ids}

    def _apply_bert_mask(self, tokens: List[str]) -> Tuple[List[str], np.ndarray]:
        """Apply BERT-style random masking to ``tokens``.

        Each non-special token is selected with probability 0.15. A selected
        token has its id stored in ``labels`` and is then replaced by the
        mask token (80%), replaced by a random vocabulary token (10%), or
        left unchanged (10%). Unselected positions keep label -1.

        Args:
            tokens (List[str]): Tokens including the start and stop tokens.

        Returns:
            Tuple of the masked token list (a copy; the input list is not
            modified) and an int64 array of labels.
        """
        masked_tokens = copy(tokens)
        labels = np.zeros([len(tokens)], np.int64) - 1
        special_tokens = (self.tokenizer.start_token, self.tokenizer.stop_token)

        for i, token in enumerate(tokens):
            # Tokens begin and end with start_token and stop_token, ignore these.
            # This was a no-op ``pass``, so special tokens could be masked.
            if token in special_tokens:
                continue

            prob = random.random()
            if prob < 0.15:
                # Rescale to [0, 1) to pick among the three replacement modes.
                prob /= 0.15
                labels[i] = self.tokenizer.convert_token_to_id(token)

                if prob < 0.8:
                    # 80% random change to mask token
                    token = self.tokenizer.mask_token
                elif prob < 0.9:
                    # 10% chance to change to random token
                    token = self.tokenizer.convert_id_to_token(
                        random.randint(0, self.tokenizer.vocab_size - 1))
                # else: 10% chance to keep current token

                masked_tokens[i] = token

        return masked_tokens, labels


@registry.register_task('beta_lactamase')
class BetaModelingDataset(MaskedLanguageModelingDataset):
    """Masked language modeling over ``unilanguage/{split}_combined.fasta``.

    Examples are ``(masked_token_ids, input_mask, labels)`` triples.

    NOTE(review): the parent ``__init__`` validates ``split`` against
    ['train', 'valid', 'holdout'] and opens ``pfam/pfam_{split}.lmdb`` before
    ``self.data`` is replaced here, so the Pfam file must exist even though
    it is never read. Kept as-is for compatibility -- confirm whether the
    Pfam dependency is intended.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):
        super().__init__(data_path, split, tokenizer, in_memory)

        data_path = Path(data_path)
        data_file = f'unilanguage/{split}_combined.fasta'
        self.data = dataset_factory(data_path / data_file, in_memory)

    def collate_fn(self, batch: List[Any]) -> Dict[str, torch.Tensor]:
        """Pad a batch into 'input_ids', 'input_mask' and 'targets'."""
        input_ids, input_mask, lm_label_ids = tuple(zip(*batch))
        return self._collate_masked(input_ids, input_mask, lm_label_ids)

    def __getitem__(self, index):
        item = self.data[index]
        return self._mask_and_encode(item['primary'])


@registry.register_task('unilanguage')
class UniModelingDataset(MaskedLanguageModelingDataset):
    """Masked language modeling over the PF00144 labeled fasta file.

    The same fasta file is used for every ``split``. Examples are
    ``(masked_token_ids, input_mask, labels)`` triples.

    NOTE(review): the parent ``__init__`` validates ``split`` against
    ['train', 'valid', 'holdout'] and opens ``pfam/pfam_{split}.lmdb`` before
    ``self.data`` is replaced here, so the Pfam file must exist even though
    it is never read. Kept as-is for compatibility -- confirm whether the
    Pfam dependency is intended.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):
        super().__init__(data_path, split, tokenizer, in_memory)

        data_path = Path(data_path)
        data_file = 'unilanguage/PF00144_full_length_sequences_labeled.fasta'
        self.data = dataset_factory(data_path / data_file, in_memory)

    def collate_fn(self, batch: List[Any]) -> Dict[str, torch.Tensor]:
        """Pad a batch into 'input_ids', 'input_mask' and 'targets'."""
        input_ids, input_mask, lm_label_ids = tuple(zip(*batch))
        return self._collate_masked(input_ids, input_mask, lm_label_ids)

    def __getitem__(self, index):
        item = self.data[index]
        return self._mask_and_encode(item['primary'])


@registry.register_task('language_modeling')
class LanguageModelingDataset(Dataset):
    """Creates the Language Modeling Pfam Dataset

    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'valid', 'holdout'], specifies which data file to load.
        tokenizer (Union[str, TAPETokenizer]): Tokenizer instance, or a vocab
            name used to build one. Default: 'iupac'.
        in_memory (bool, optional): Whether to load the full dataset into memory.
            Default: False.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):
        super().__init__()
        if split not in ('train', 'valid', 'holdout'):
            raise ValueError(
                f"Unrecognized split: {split}. "
                f"Must be one of ['train', 'valid', 'holdout']")
        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        data_path = Path(data_path)
        data_file = f'pfam/pfam_{split}.lmdb'
        self.data = dataset_factory(data_path / data_file, in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        token_ids = self.tokenizer.encode(item['primary'])
        input_mask = np.ones_like(token_ids)

        return token_ids, input_mask, item['clan'], item['family']

    def collate_fn(self, batch: List[Any]) -> Dict[str, torch.Tensor]:
        """Pad a batch into 'input_ids', 'input_mask' and 'targets'.

        The targets are the input ids themselves, padded with -1 instead of
        0. The clan and family entries are not part of the output.
        """
        input_ids, input_mask, _clan, _family = tuple(zip(*batch))

        torch_inputs = torch.from_numpy(pad_sequences(input_ids, 0))
        input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
        # ignore_index is -1
        torch_labels = torch.from_numpy(pad_sequences(input_ids, -1))

        return {'input_ids': torch_inputs,
                'input_mask': input_mask,
                'targets': torch_labels}


@registry.register_task('fluorescence')
class FluorescenceDataset(Dataset):
    """Regression dataset read from ``fluorescence/fluorescence_{split}.lmdb``.

    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'valid', 'test'].
        tokenizer (Union[str, TAPETokenizer]): Tokenizer instance, or a vocab
            name used to build one. Default: 'iupac'.
        in_memory (bool, optional): Forwarded to the underlying dataset.
            Default: False.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):

        if split not in ('train', 'valid', 'test'):
            raise ValueError(f"Unrecognized split: {split}. "
                             f"Must be one of ['train', 'valid', 'test']")
        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        data_path = Path(data_path)
        data_file = f'fluorescence/fluorescence_{split}.lmdb'
        self.data = dataset_factory(data_path / data_file, in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        item = self.data[index]
        token_ids = self.tokenizer.encode(item['primary'])
        input_mask = np.ones_like(token_ids)
        return token_ids, input_mask, float(item['log_fluorescence'][0])

    def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
        """Pad a batch; 'targets' is a float tensor with a trailing unit dim."""
        input_ids, input_mask, fluorescence_true_value = tuple(zip(*batch))
        input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
        input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
        fluorescence_true_value = torch.FloatTensor(fluorescence_true_value)  # type: ignore
        fluorescence_true_value = fluorescence_true_value.unsqueeze(1)

        return {'input_ids': input_ids,
                'input_mask': input_mask,
                'targets': fluorescence_true_value}


@registry.register_task('stability')
class StabilityDataset(Dataset):
    """Regression dataset read from ``stability/stability_{split}.lmdb``.

    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'valid', 'test'].
        tokenizer (Union[str, TAPETokenizer]): Tokenizer instance, or a vocab
            name used to build one. Default: 'iupac'.
        in_memory (bool, optional): Forwarded to the underlying dataset.
            Default: False.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):

        if split not in ('train', 'valid', 'test'):
            raise ValueError(f"Unrecognized split: {split}. "
                             f"Must be one of ['train', 'valid', 'test']")
        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        data_path = Path(data_path)
        data_file = f'stability/stability_{split}.lmdb'

        self.data = dataset_factory(data_path / data_file, in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        item = self.data[index]
        token_ids = self.tokenizer.encode(item['primary'])
        input_mask = np.ones_like(token_ids)
        return token_ids, input_mask, float(item['stability_score'][0])

    def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
        """Pad a batch; 'targets' is a float tensor with a trailing unit dim."""
        input_ids, input_mask, stability_true_value = tuple(zip(*batch))
        input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
        input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
        stability_true_value = torch.FloatTensor(stability_true_value)  # type: ignore
        stability_true_value = stability_true_value.unsqueeze(1)

        return {'input_ids': input_ids,
                'input_mask': input_mask,
                'targets': stability_true_value}


@registry.register_task('remote_homology', num_labels=1195)
class RemoteHomologyDataset(Dataset):
    """Classification dataset read from
    ``remote_homology/remote_homology_{split}.lmdb``; the target is each
    record's 'fold_label'.

    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'valid', 'test_fold_holdout',
            'test_family_holdout', 'test_superfamily_holdout'].
        tokenizer (Union[str, TAPETokenizer]): Tokenizer instance, or a vocab
            name used to build one. Default: 'iupac'.
        in_memory (bool, optional): Forwarded to the underlying dataset.
            Default: False.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):

        if split not in ('train', 'valid', 'test_fold_holdout',
                         'test_family_holdout', 'test_superfamily_holdout'):
            raise ValueError(f"Unrecognized split: {split}. Must be one of "
                             f"['train', 'valid', 'test_fold_holdout', "
                             f"'test_family_holdout', 'test_superfamily_holdout']")
        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        data_path = Path(data_path)
        data_file = f'remote_homology/remote_homology_{split}.lmdb'
        self.data = dataset_factory(data_path / data_file, in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        item = self.data[index]
        token_ids = self.tokenizer.encode(item['primary'])
        input_mask = np.ones_like(token_ids)
        return token_ids, input_mask, item['fold_label']

    def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
        """Pad a batch; 'targets' is a long tensor of fold labels."""
        input_ids, input_mask, fold_label = tuple(zip(*batch))
        input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
        input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
        fold_label = torch.LongTensor(fold_label)  # type: ignore

        return {'input_ids': input_ids,
                'input_mask': input_mask,
                'targets': fold_label}


@registry.register_task('contact_prediction')
class ProteinnetDataset(Dataset):
    """Contact-map dataset read from ``proteinnet/proteinnet_{split}.lmdb``.

    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'train_unfiltered', 'valid', 'test'].
        tokenizer (Union[str, TAPETokenizer]): Tokenizer instance, or a vocab
            name used to build one. Default: 'iupac'.
        in_memory (bool, optional): Forwarded to the underlying dataset.
            Default: False.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):

        if split not in ('train', 'train_unfiltered', 'valid', 'test'):
            raise ValueError(f"Unrecognized split: {split}. Must be one of "
                             f"['train', 'train_unfiltered', 'valid', 'test']")

        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        data_path = Path(data_path)
        data_file = f'proteinnet/proteinnet_{split}.lmdb'
        self.data = dataset_factory(data_path / data_file, in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        """Return ``(token_ids, input_mask, contact_map, protein_length)``.

        ``contact_map`` is 1 where the pairwise distance computed from
        item['tertiary'] is below 8.0 and 0 otherwise. Entries are set to -1
        where either position is not flagged in item['valid_mask'], or where
        the two positions are fewer than 6 apart in sequence.
        """
        item = self.data[index]
        protein_length = len(item['primary'])
        token_ids = self.tokenizer.encode(item['primary'])
        input_mask = np.ones_like(token_ids)

        valid_mask = item['valid_mask']
        contact_map = np.less(squareform(pdist(item['tertiary'])), 8.0).astype(np.int64)

        yind, xind = np.indices(contact_map.shape)
        invalid_mask = ~(valid_mask[:, None] & valid_mask[None, :])
        invalid_mask |= np.abs(yind - xind) < 6
        contact_map[invalid_mask] = -1

        return token_ids, input_mask, contact_map, protein_length

    def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
        """Pad a batch; contact maps are padded with -1."""
        input_ids, input_mask, contact_labels, protein_length = tuple(zip(*batch))
        input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
        input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
        contact_labels = torch.from_numpy(pad_sequences(contact_labels, -1))
        protein_length = torch.LongTensor(protein_length)  # type: ignore

        return {'input_ids': input_ids,
                'input_mask': input_mask,
                'targets': contact_labels,
                'protein_length': protein_length}


@registry.register_task('secondary_structure', num_labels=3)
class SecondaryStructureDataset(Dataset):
    """Per-residue labelling dataset read from
    ``secondary_structure/secondary_structure_{split}.lmdb``; the labels are
    each record's 'ss3' entry.

    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'valid', 'casp12', 'ts115', 'cb513'].
        tokenizer (Union[str, TAPETokenizer]): Tokenizer instance, or a vocab
            name used to build one. Default: 'iupac'.
        in_memory (bool, optional): Forwarded to the underlying dataset.
            Default: False.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False):

        if split not in ('train', 'valid', 'casp12', 'ts115', 'cb513'):
            raise ValueError(f"Unrecognized split: {split}. Must be one of "
                             f"['train', 'valid', 'casp12', "
                             f"'ts115', 'cb513']")
        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        data_path = Path(data_path)
        data_file = f'secondary_structure/secondary_structure_{split}.lmdb'
        self.data = dataset_factory(data_path / data_file, in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        item = self.data[index]
        token_ids = self.tokenizer.encode(item['primary'])
        input_mask = np.ones_like(token_ids)

        # pad with -1s because of cls/sep tokens
        labels = np.asarray(item['ss3'], np.int64)
        labels = np.pad(labels, (1, 1), 'constant', constant_values=-1)

        return token_ids, input_mask, labels

    def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
        """Pad a batch; labels are padded with -1."""
        input_ids, input_mask, ss_label = tuple(zip(*batch))
        input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
        input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
        ss_label = torch.from_numpy(pad_sequences(ss_label, -1))

        output = {'input_ids': input_ids,
                  'input_mask': input_mask,
                  'targets': ss_label}

        return output


@registry.register_task('trrosetta')
class TRRosettaDataset(Dataset):
    """Dataset of MSAs with binned inter-residue geometry targets.

    Files are read from ``<data_path>/trrosetta/npz``, restricted to the
    names listed in ``<data_path>/trrosetta/{split}_files.txt``.

    Args:
        data_path (Union[str, Path]): Path to tape data root.
        split (str): One of ['train', 'valid'].
        tokenizer (Union[str, TAPETokenizer]): Tokenizer instance, or a vocab
            name used to build one. Default: 'iupac'.
        in_memory (bool, optional): Forwarded to NPZDataset. Default: False.
        max_seqlen (int, optional): Sequences longer than this are randomly
            cropped to this length; values <= 0 disable cropping.
            Default: 300.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 tokenizer: Union[str, TAPETokenizer] = 'iupac',
                 in_memory: bool = False,
                 max_seqlen: int = 300):
        if split not in ('train', 'valid'):
            raise ValueError(
                f"Unrecognized split: {split}. "
                f"Must be one of ['train', 'valid']")
        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        data_path = Path(data_path)
        data_path = data_path / 'trrosetta'
        split_files = (data_path / f'{split}_files.txt').read_text().split()
        self.data = NPZDataset(data_path / 'npz', in_memory, split_files=split_files)

        # Bin edges passed to np.digitize in __getitem__.
        self._dist_bins = np.arange(2, 20.1, 0.5)
        self._dihedral_bins = (15 + np.arange(-180, 180, 15)) / 180 * np.pi
        self._planar_bins = (15 + np.arange(0, 180, 15)) / 180 * np.pi
        self._split = split
        self.max_seqlen = max_seqlen
        self.msa_cutoff = 0.8
        self.penalty_coeff = 4.5

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(msa, dist_bins, omega_bins, theta_bins, phi_bins)``.

        For the train split the MSA is randomly subsampled; for the valid
        split it is truncated to its first 20000 rows. Positions where the
        stored distance is 0 get bin 0 in every target, and the diagonal of
        ``dist_bins`` is set to -1.
        """
        item = self.data[index]

        msa = item['msa']
        dist = item['dist6d']
        omega = item['omega6d']
        theta = item['theta6d']
        phi = item['phi6d']

        if self._split == 'train':
            msa = self._subsample_msa(msa)
        elif self._split == 'valid':
            msa = msa[:20000]  # runs out of memory if msa is way too big
        msa, dist, omega, theta, phi = self._slice_long_sequences(
            msa, dist, omega, theta, phi)

        mask = dist == 0

        dist_bins = np.digitize(dist, self._dist_bins)
        omega_bins = np.digitize(omega, self._dihedral_bins) + 1
        theta_bins = np.digitize(theta, self._dihedral_bins) + 1
        phi_bins = np.digitize(phi, self._planar_bins) + 1

        dist_bins[mask] = 0
        omega_bins[mask] = 0
        theta_bins[mask] = 0
        phi_bins[mask] = 0

        dist_bins[np.diag_indices_from(dist_bins)] = -1

        return msa, dist_bins, omega_bins, theta_bins, phi_bins

    def _slice_long_sequences(self, msa, dist, omega, theta, phi):
        """Randomly crop all inputs to a common window of ``max_seqlen``.

        The MSA is cropped along its second axis and the pairwise maps along
        both axes. Inputs are returned unchanged when ``max_seqlen`` is not
        positive or the sequence already fits.
        """
        seqlen = msa.shape[1]
        if self.max_seqlen > 0 and seqlen > self.max_seqlen:
            start = np.random.randint(seqlen - self.max_seqlen + 1)
            end = start + self.max_seqlen

            msa = msa[:, start:end]
            dist = dist[start:end, start:end]
            omega = omega[start:end, start:end]
            theta = theta[start:end, start:end]
            phi = phi[start:end, start:end]

        return msa, dist, omega, theta, phi

    def _subsample_msa(self, msa):
        """Return a random subset of MSA rows that always includes row 0.

        MSAs with fewer than 10 rows are returned unchanged. Otherwise the
        number of extra rows is ``int(10 ** u - 10)`` with ``u`` uniform on
        ``[1, log10(num_alignments)]``, capped at 20000; if that number is
        not positive only the first row is returned.
        """
        num_alignments, seqlen = msa.shape

        if num_alignments < 10:
            return msa

        # The bounds are spelled out: the former single-argument call set
        # low=log10(n) with the default high=1.0 (low >= high), which numpy
        # documents as undefined. The sampled interval is unchanged.
        exponent = np.random.uniform(1.0, np.log10(num_alignments))
        num_sample = int(10 ** exponent - 10)

        if num_sample <= 0:
            return msa[0][None, :]
        elif num_sample > 20000:
            num_sample = 20000

        # Draw from rows 1..n-1 so the first row is never duplicated.
        indices = np.random.choice(
            msa.shape[0] - 1, size=num_sample, replace=False) + 1
        indices = np.pad(indices, [1, 0], 'constant')  # add the sequence back in
        return msa[indices]

    def collate_fn(self, batch):
        """Pad a batch into 'msa1hot', 'dist', 'omega', 'theta' and 'phi'.

        MSAs are one-hot encoded over 21 classes and padded with 0 as float;
        ``dist`` is padded with -1 and the angle bins with 0.
        """
        msa, dist_bins, omega_bins, theta_bins, phi_bins = tuple(zip(*batch))
        msa1hot = pad_sequences(
            [F.one_hot(torch.LongTensor(msa_), 21) for msa_ in msa], 0, torch.float)
        dist_bins = torch.LongTensor(pad_sequences(dist_bins, -1))
        omega_bins = torch.LongTensor(pad_sequences(omega_bins, 0))
        theta_bins = torch.LongTensor(pad_sequences(theta_bins, 0))
        phi_bins = torch.LongTensor(pad_sequences(phi_bins, 0))

        return {'msa1hot': msa1hot,
                'dist': dist_bins,
                'omega': omega_bins,
                'theta': theta_bins,
                'phi': phi_bins}

    def featurize(self, msa):
        """Build the combined 1D + 2D feature tensor for one MSA.

        The 1D features are tiled along rows and along columns, concatenated
        with the 2D features on the last axis, and the result is permuted so
        the feature axis comes first.
        """
        msa = torch.LongTensor(msa)
        msa1hot = F.one_hot(msa, 21).float()

        seqlen = msa1hot.size(1)

        weights = self.reweight(msa1hot)
        features_1d = self.extract_features_1d(msa1hot, weights)
        features_2d = self.extract_features_2d(msa1hot, weights)

        features = torch.cat((
            features_1d.unsqueeze(1).repeat(1, seqlen, 1),
            features_1d.unsqueeze(0).repeat(seqlen, 1, 1),
            features_2d), -1)

        features = features.permute(2, 0, 1)

        return features

    def reweight(self, msa1hot):
        """Weight each MSA row by the inverse count of rows similar to it.

        Two rows count as similar when their one-hot overlap exceeds
        ``seqlen * self.msa_cutoff``.
        """
        seqlen = msa1hot.size(1)
        id_min = seqlen * self.msa_cutoff
        id_mtx = torch.tensordot(msa1hot, msa1hot, [[1, 2], [1, 2]])
        id_mask = id_mtx > id_min
        weights = 1.0 / id_mask.float().sum(-1)
        return weights

    def extract_features_1d(self, msa1hot, weights):
        """Return per-position features of shape ``(seqlen, 42)``.

        These are the first 20 one-hot channels of the first row, the
        weighted per-position symbol frequencies (21 channels), and their
        entropy (1 channel).
        """
        seqlen = msa1hot.size(1)
        f1d_seq = msa1hot[0, :, :20]

        # msa2pssm
        beff = weights.sum()
        # 1e-9 keeps the log below finite for symbols that never occur.
        f_i = (weights[:, None, None] * msa1hot).sum(0) / beff + 1e-9
        h_i = (-f_i * f_i.log()).sum(1, keepdim=True)
        f1d_pssm = torch.cat((f_i, h_i), dim=1)

        f1d = torch.cat((f1d_seq, f1d_pssm), dim=1)
        f1d = f1d.view(seqlen, 42)
        return f1d

    def extract_features_2d(self, msa1hot, weights):
        """Return pairwise features of shape ``(seqlen, seqlen, 442)``.

        With a single row the features are all zeros. Otherwise they are the
        441 entries of the regularized inverse covariance for each position
        pair, plus one APC-corrected coupling score.
        """
        num_alignments = msa1hot.size(0)
        seqlen = msa1hot.size(1)
        num_symbols = 21

        if num_alignments == 1:
            # No alignments, predict from sequence alone
            f2d_dca = torch.zeros(seqlen, seqlen, 442, dtype=torch.float)
        else:
            # fast_dca

            # covariance
            x = msa1hot.view(num_alignments, seqlen * num_symbols)
            num_points = weights.sum() - weights.mean().sqrt()
            mean = (x * weights[:, None]).sum(0, keepdim=True) / num_points
            x = (x - mean) * weights[:, None].sqrt()
            cov = torch.matmul(x.transpose(-1, -2), x) / num_points

            # inverse covariance
            reg = torch.eye(seqlen * num_symbols) * self.penalty_coeff / weights.sum().sqrt()
            cov_reg = cov + reg
            inv_cov = torch.inverse(cov_reg)

            x1 = inv_cov.view(seqlen, num_symbols, seqlen, num_symbols)
            x2 = x1.permute(0, 2, 1, 3)
            features = x2.reshape(seqlen, seqlen, num_symbols * num_symbols)

            # The last symbol channel is left out of the coupling norm and
            # the diagonal is zeroed.
            x3 = (x1[:, :-1, :, :-1] ** 2).sum((1, 3)).sqrt() * (1 - torch.eye(seqlen))
            apc = x3.sum(0, keepdim=True) * x3.sum(1, keepdim=True) / x3.sum()
            contacts = (x3 - apc) * (1 - torch.eye(seqlen))

            f2d_dca = torch.cat([features, contacts[:, :, None]], dim=2)

        return f2d_dca