|
from typing import Dict, Type, Callable, Optional, Union |
|
from torch.utils.data import Dataset |
|
from .models.modeling_utils import ProteinModel |
|
from pathlib import Path |
|
|
|
# Type alias for parameters that accept a filesystem path either as a plain
# string or as a pathlib.Path.
PathType = Union[str, Path]
|
|
|
|
|
def convert_model_args(model_args):
    """Parse ``key=value`` strings into a dict, converting numeric values.

    Each value is converted to ``int`` if possible, otherwise to ``float``,
    otherwise it is kept as the original string.

    Args:
        model_args: Iterable of strings of the form ``"key=value"``. Only the
            first ``"="`` separates key from value, so values may themselves
            contain ``"="``.

    Returns:
        Dict mapping each key to its (possibly converted) value. If a key
        occurs more than once, the last occurrence wins.

    Raises:
        ValueError: If an entry does not contain ``"="``.
    """
    d = {}
    for e in model_args:
        # partition splits on the first "=" only; sep is "" when "=" is absent.
        k, sep, v = e.partition("=")
        if not sep:
            raise ValueError(f"model arg '{e}' is not of the form key=value")
        try:
            v = int(v)
        except ValueError:
            try:
                v = float(v)
            except ValueError:
                pass  # not numeric: keep the raw string value
        d[k] = v
    return d
|
|
|
|
|
class TAPETaskSpec:
    """Specification of a single TAPE task.

    Attributes
    ----------
    name (str):
        The name of the TAPE task
    dataset (Type[Dataset]):
        The dataset class used in the TAPE task
    num_labels (int):
        Number of labels used if this is a classification task (-1 by default,
        i.e. when it is not a classification task).
    models (Dict[str, Type[ProteinModel]]):
        The model classes that can be used for this task, keyed by model
        name. Default: {}.
    """

    def __init__(self,
                 name: str,
                 dataset: Type[Dataset],
                 num_labels: int = -1,
                 models: Optional[Dict[str, Type[ProteinModel]]] = None):
        self.name = name
        self.dataset = dataset
        self.num_labels = num_labels
        # Create a fresh dict per instance instead of sharing a mutable default.
        self.models = models if models is not None else {}

    def register_model(self, model_name: str, model_cls: Optional[Type[ProteinModel]] = None):
        """Register ``model_cls`` under ``model_name`` for this task.

        If ``model_cls`` is omitted, a decorator is returned that registers
        the decorated class under ``model_name`` instead.

        Returns:
            The registered class, or a decorator when ``model_cls`` is None.

        Raises:
            KeyError: If ``model_name`` is already registered for this task.
        """
        if model_cls is None:
            # Decorator form: @task_spec.register_model('name')
            return lambda model_cls: self.register_model(model_name, model_cls)
        if model_name in self.models:
            raise KeyError(
                f"A model with name '{model_name}' is already registered for this task")
        self.models[model_name] = model_cls
        return model_cls

    def get_model(self, model_name: str) -> Type[ProteinModel]:
        """Return the model class registered under ``model_name``.

        Raises:
            KeyError: If no model with that name is registered for this task.
                The message lists the model names that are available.
        """
        try:
            return self.models[model_name]
        except KeyError:
            raise KeyError(
                f"No model named '{model_name}' is registered for task '{self.name}'. "
                f"Available models: {sorted(self.models)}") from None
|
|
|
|
|
class Registry:
    r"""Class for registry object which acts as the
    central repository for TAPE.

    All state lives in class attributes and every method is a classmethod,
    so all ``Registry`` instances share the same tasks and metrics.
    """

    task_name_mapping: Dict[str, TAPETaskSpec] = {}
    metric_name_mapping: Dict[str, Callable] = {}

    @classmethod
    def register_task(cls,
                      task_name: str,
                      num_labels: int = -1,
                      dataset: Optional[Type[Dataset]] = None,
                      models: Optional[Dict[str, Type[ProteinModel]]] = None):
        """ Register a new TAPE task. This creates a new TAPETaskSpec.

        Args:
            task_name (str): The name of the TAPE task.
            num_labels (int): Number of labels used if this is a classification task. If this
                is not a classification task, simply leave the default as -1.
            dataset (Type[Dataset]): The dataset used in the TAPE task.
            models (Optional[Dict[str, Type[ProteinModel]]]): The set of model classes that
                can be used for this task. If you do not pass this argument, you can register
                models to the task later by using `registry.register_task_model`. Default: {}.

        Returns:
            The ``dataset`` class when it is given; otherwise a decorator that
            registers the decorated class as the task's dataset.

        Raises:
            KeyError: If a task named ``task_name`` is already registered.

        Examples:

            There are two ways of registering a new task. First, one can define the task by
            simply declaring all the components, and then calling the register method, like so:

                class SecondaryStructureDataset(Dataset):
                    ...

                class ProteinBertForSequenceToSequenceClassification():
                    ...

                registry.register_task(
                    'secondary_structure', 3, SecondaryStructureDataset,
                    {'transformer': ProteinBertForSequenceToSequenceClassification})

            This will register a new task, 'secondary_structure', with a single model. More
            models can be added with `registry.register_task_model`. Alternatively, this can be
            used as a decorator:

                @registry.register_task('secondary_structure', 3)
                class SecondaryStructureDataset(Dataset):
                    ...

                @registry.register_task_model('secondary_structure', 'transformer')
                class ProteinBertForSequenceToSequenceClassification():
                    ...

            These two pieces of code are exactly equivalent, in terms of the resulting registry
            state.
        """
        if dataset is None:
            # Decorator form: the decorated class becomes the task's dataset.
            return lambda dataset: cls.register_task(task_name, num_labels, dataset, models)
        # TAPETaskSpec turns models=None into a fresh empty dict.
        task_spec = TAPETaskSpec(task_name, dataset, num_labels, models)
        return cls.register_task_spec(task_name, task_spec).dataset

    @classmethod
    def register_task_spec(cls, task_name: str, task_spec: Optional[TAPETaskSpec] = None):
        """ Registers a task_spec directly. If you find it easier to actually create a
        TAPETaskSpec manually, and then register it, feel free to use this method,
        but otherwise it is likely easier to use `registry.register_task`.

        Returns:
            ``task_spec`` when it is given; otherwise a decorator that registers
            the object it is applied to.

        Raises:
            KeyError: If a task named ``task_name`` is already registered.
        """
        if task_spec is None:
            return lambda task_spec: cls.register_task_spec(task_name, task_spec)
        if task_name in cls.task_name_mapping:
            raise KeyError(f"A task with name '{task_name}' is already registered")
        cls.task_name_mapping[task_name] = task_spec
        return task_spec

    @classmethod
    def register_task_model(cls,
                            task_name: str,
                            model_name: str,
                            model_cls: Optional[Type[ProteinModel]] = None):
        r"""Register a specific model to a task with the provided model name.
        The task must already be in the registry - you cannot register a
        model to an unregistered task.

        Args:
            task_name (str): Name of task to which to register the model.
            model_name (str): Name of model to use when registering task, this
                is the name that you will use to refer to the model on the
                command line.
            model_cls (Type[ProteinModel]): The model to register.

        Returns:
            ``model_cls`` when it is given; otherwise a decorator that registers
            the decorated class.

        Raises:
            KeyError: If ``task_name`` is not registered, or ``model_name`` is
                already registered for that task.

        Examples:

            As with `registry.register_task`, this can both be used as a regular
            python function, and as a decorator. For example this:

                class ProteinBertForSequenceToSequenceClassification():
                    ...
                registry.register_task_model(
                    'secondary_structure', 'transformer',
                    ProteinBertForSequenceToSequenceClassification)

            and as a decorator:

                @registry.register_task_model('secondary_structure', 'transformer')
                class ProteinBertForSequenceToSequenceClassification():
                    ...

            are both equivalent.
        """
        if task_name not in cls.task_name_mapping:
            raise KeyError(
                f"Tried to register a task model for an unregistered task: {task_name}. "
                f"Make sure to register the task {task_name} first.")
        return cls.task_name_mapping[task_name].register_model(model_name, model_cls)

    @classmethod
    def register_metric(cls, name: str) -> Callable[[Callable], Callable]:
        r"""Register a metric to registry with key 'name'

        Registering a second metric under an existing name overwrites the first.

        Args:
            name: Key with which the metric will be registered.

        Usage::
            from tape.registry import registry

            @registry.register_metric('mse')
            def mean_squared_error(inputs, outputs):
                ...
        """

        def wrap(fn: Callable) -> Callable:
            assert callable(fn), "All metrics must be callable"
            cls.metric_name_mapping[name] = fn
            return fn

        return wrap

    @classmethod
    def get_task_spec(cls, name: str) -> TAPETaskSpec:
        """Return the TAPETaskSpec registered under ``name`` (KeyError if absent)."""
        return cls.task_name_mapping[name]

    @classmethod
    def get_metric(cls, name: str) -> Callable:
        """Return the metric registered under ``name`` (KeyError if absent)."""
        return cls.metric_name_mapping[name]

    @classmethod
    def get_task_model(cls,
                       model_name: str,
                       task_name: str,
                       config_file: Optional[PathType] = None,
                       load_dir: Optional[PathType] = None,
                       model_args=None) -> ProteinModel:
        """ Create a TAPE task model, either from scratch or from a pretrained model.

        ``load_dir`` takes precedence. When it is given, the model is loaded
        with ``from_pretrained``, and ``config_file`` and ``model_args`` are
        NOT used. Otherwise a config is built from ``config_file`` (or the
        default config), overridden by ``model_args``, and used to instantiate
        the model.

        Args:
            model_name (str): Which type of model to create (e.g. transformer, unirep, ...)
            task_name (str): The TAPE task for which to create a model
            config_file (str, optional): A json config file that specifies hyperparameters
            load_dir (str, optional): A save directory for a pretrained model
            model_args (Iterable[str], optional): ``key=value`` overrides for existing
                config attributes. Values are parsed by ``convert_model_args``. An
                integral value is accepted for a float-valued attribute.

        Returns:
            model (ProteinModel): A TAPE task model

        Raises:
            KeyError: If the task or model name is not registered.
            ValueError: If a model arg is not a config attribute, or its type does
                not match the type of the attribute's current value.
        """
        task_spec = cls.get_task_spec(task_name)
        model_cls = task_spec.get_model(model_name)

        if load_dir is not None:
            # NOTE: model_args are not applied to pretrained models.
            return model_cls.from_pretrained(load_dir, num_labels=task_spec.num_labels)

        config_class = model_cls.config_class
        if config_file is not None:
            config = config_class.from_json_file(config_file)
        else:
            config = config_class()

        if model_args:
            for key, value in convert_model_args(model_args).items():
                if key not in config.__dict__:
                    raise ValueError(f"model arg {key} not in config")
                default = config.__dict__[key]
                # convert_model_args yields int for integral strings such as "0";
                # accept those for float-valued options.
                if type(default) is float and type(value) is int:
                    value = float(value)
                if type(default) is not type(value):
                    raise ValueError(
                        f"model arg {key} is of type {type(value).__name__}, "
                        f"but the default is of type {type(default).__name__}")
                setattr(config, key, value)

        # Set after the overrides, so num_labels always comes from the task spec.
        config.num_labels = task_spec.num_labels
        return model_cls(config)
|
|
|
|
|
# Module-level registry instance, used as e.g. `from tape.registry import registry`.
# Registry keeps all of its state in class attributes, so every instance shares it.
registry = Registry()
|
|