File size: 13,000 Bytes
212111c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 |
import typing
import os
import logging
import argparse
import warnings
import inspect
try:
import apex # noqa: F401
APEX_FOUND = True
except ImportError:
APEX_FOUND = False
from .registry import registry
from . import training
from . import utils
CallbackList = typing.Sequence[typing.Callable]
OutputDict = typing.Dict[str, typing.List[typing.Any]]
logger = logging.getLogger(__name__)
warnings.filterwarnings( # Ignore pytorch warning about loss gathering
'ignore', message='Was asked to gather along dimension 0', module='torch.nn.parallel')
def create_base_parser() -> argparse.ArgumentParser:
    """Create the parent parser shared by every TAPE command-line entry point.

    Returns:
        An ``argparse.ArgumentParser`` constructed with ``add_help=False`` so
        it can be passed via ``parents=`` to the task-specific parsers. It
        holds the options common to training, evaluation, and embedding.
    """
    parser = argparse.ArgumentParser(description='Parent parser for tape functions',
                                     add_help=False)
    parser.add_argument('model_type', help='Base model class to run')
    parser.add_argument('--model_config_file', default=None, type=utils.check_is_file,
                        help='Config file for model')
    parser.add_argument('--vocab_file', default=None,
                        help='Pretrained tokenizer vocab file')
    parser.add_argument('--output_dir', default='./results', type=str)
    parser.add_argument('--no_cuda', action='store_true', help='CPU-only flag')
    parser.add_argument('--seed', default=42, type=int, help='Random seed to use')
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='Local rank of process in distributed training. '
                             'Set by launch script.')
    # Typo fix: help text previously read "Tokenizes to use ...".
    parser.add_argument('--tokenizer', choices=['iupac', 'unirep'],
                        default='iupac',
                        help='Tokenizer to use on the amino acid sequences')
    parser.add_argument('--num_workers', default=8, type=int,
                        help='Number of workers to use for multi-threaded data loading')
    # Both the string and the integer spelling of each level are accepted so
    # the flag works from shell scripts (strings) and from programmatic
    # callers that pass the logging-module constants (ints).
    parser.add_argument('--log_level', default=logging.INFO,
                        choices=['DEBUG', 'INFO', 'WARN', 'WARNING', 'ERROR',
                                 logging.DEBUG, logging.INFO, logging.WARNING,
                                 logging.ERROR],
                        help="log level for the experiment")
    parser.add_argument('--debug', action='store_true', help='Run in debug mode')
    return parser
def create_train_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Build the argument parser for the training entry point.

    Args:
        base_parser: Shared parent parser produced by ``create_base_parser``.

    Returns:
        Parser exposing every training-specific argument.
    """
    train_parser = argparse.ArgumentParser(description='Run Training on the TAPE datasets',
                                           parents=[base_parser])
    add = train_parser.add_argument
    add('task', choices=list(registry.task_name_mapping.keys()),
        help='TAPE Task to train/eval on')
    add('--learning_rate', default=1e-4, type=float, help='Learning rate')
    add('--batch_size', default=1024, type=int, help='Batch size')
    add('--data_dir', default='./data', type=utils.check_is_dir,
        help='Directory from which to load task data')
    add('--num_train_epochs', default=10, type=int,
        help='Number of training epochs')
    add('--num_steps_per_epoch', default=-1, type=int,
        help='Number of steps per epoch')
    add('--num_log_iter', default=20, type=int,
        help='Number of training steps per log iteration')
    add('--fp16', action='store_true', help='Whether to use fp16 weights')
    add('--warmup_steps', default=10000, type=int,
        help='Number of learning rate warmup steps')
    add('--gradient_accumulation_steps', default=1, type=int,
        help='Number of forward passes to make for each backwards pass')
    add('--loss_scale', default=0, type=int,
        help='Loss scaling. Only used during fp16 training.')
    add('--max_grad_norm', default=1.0, type=float, help='Maximum gradient norm')
    add('--exp_name', default=None, type=str,
        help='Name to give to this experiment')
    add('--from_pretrained', default=None, type=str,
        help='Directory containing config and pretrained model weights')
    add('--log_dir', default='./logs', type=str)
    add('--eval_freq', type=int, default=1,
        help='Frequency of eval pass. A value <= 0 means the eval pass is not run')
    # save_freq is either an int (every N epochs) or the literal string
    # 'improvement'; utils.int_or_str performs the conversion.
    add('--save_freq', default='improvement', type=utils.int_or_str,
        help="How often to save the model during training. Either an integer "
             "frequency or the string 'improvement'")
    add('--patience', default=-1, type=int,
        help='How many epochs without improvement to wait before ending training')
    add('--resume_from_checkpoint', action='store_true',
        help='whether to resume training from the checkpoint')
    add('--val_check_frac', default=1.0, type=float,
        help='Fraction of validation to check')
    return train_parser
def create_eval_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Build the argument parser for the evaluation entry point.

    Args:
        base_parser: Shared parent parser produced by ``create_base_parser``.

    Returns:
        Parser exposing the evaluation-specific arguments.
    """
    eval_parser = argparse.ArgumentParser(description='Run Eval on the TAPE Datasets',
                                          parents=[base_parser])
    add = eval_parser.add_argument
    add('task', choices=list(registry.task_name_mapping.keys()),
        help='TAPE Task to train/eval on')
    # Positional (required) here, unlike the optional flag in the train parser.
    add('from_pretrained', type=str,
        help='Directory containing config and pretrained model weights')
    add('--batch_size', default=1024, type=int, help='Batch size')
    add('--data_dir', default='./data', type=utils.check_is_dir,
        help='Directory from which to load task data')
    add('--metrics', default=[], nargs='*',
        help=f'Metrics to run on the result. '
             f'Choices: {list(registry.metric_name_mapping.keys())}')
    add('--split', default='test', type=str, help='Which split to run on')
    return eval_parser
def create_embed_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Build the argument parser for the embedding entry point.

    Args:
        base_parser: Shared parent parser produced by ``create_base_parser``.

    Returns:
        Parser exposing the embedding-specific arguments, with the ``task``
        attribute pre-set to ``'embed'``.
    """
    embed_parser = argparse.ArgumentParser(
        description='Embed a set of proteins with a pretrained model',
        parents=[base_parser])
    add = embed_parser.add_argument
    add('data_file', type=str, help='File containing set of proteins to embed')
    add('out_file', type=str, help='Name of output file')
    add('from_pretrained', type=str,
        help='Directory containing config and pretrained model weights')
    add('--batch_size', default=1024, type=int, help='Batch size')
    add('--full_sequence_embed', action='store_true',
        help='If true, saves an embedding at every amino acid position '
             'in the sequence. Note that this can take a large amount '
             'of disk space.')
    # Embedding is not a registry task, so hard-wire ``task`` for downstream code.
    embed_parser.set_defaults(task='embed')
    return embed_parser
def create_distributed_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Add the multiprocessing-launch options used for distributed training.

    Args:
        base_parser: Shared parent parser produced by ``create_base_parser``.

    Returns:
        Parser (``add_help=False`` so further parsers can inherit from it)
        with node/process-count and rendezvous-address options.
    """
    parser = argparse.ArgumentParser(add_help=False, parents=[base_parser])
    # typing.Optional arguments for the launch helper
    parser.add_argument("--nnodes", type=int, default=1,
                        help="The number of nodes to use for distributed "
                             "training")
    parser.add_argument("--node_rank", type=int, default=0,
                        help="The rank of the node for multi-node distributed "
                             "training")
    parser.add_argument("--nproc_per_node", type=int, default=1,
                        help="The number of processes to launch on each node, "
                             "for GPU training, this is recommended to be set "
                             "to the number of GPUs in your system so that "
                             "each process can be bound to a single GPU.")
    parser.add_argument("--master_addr", default="127.0.0.1", type=str,
                        help="Master node (rank 0)'s address, should be either "
                             "the IP address or the hostname of node 0, for "
                             "single node multi-proc training, the "
                             "--master_addr can simply be 127.0.0.1")
    # Typo fix: help text previously read "communciation".
    parser.add_argument("--master_port", default=29500, type=int,
                        help="Master node (rank 0)'s free port that needs to "
                             "be used for communication during distributed "
                             "training")
    return parser
def create_model_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Attach a catch-all ``--model_args`` option to *base_parser*.

    Args:
        base_parser: Parser to inherit from via ``parents=``.

    Returns:
        Parser (``add_help=False``) with the extra ``--model_args`` option.
    """
    model_parser = argparse.ArgumentParser(add_help=False, parents=[base_parser])
    # REMAINDER swallows every token after --model_args, so arbitrary
    # model-specific flags can be forwarded without declaring them here.
    model_parser.add_argument('--model_args', nargs=argparse.REMAINDER, default=None)
    return model_parser
def run_train(args: typing.Optional[argparse.Namespace] = None, env=None) -> None:
    """Entry point for training; parses the CLI when *args* is ``None``.

    Args:
        args: Pre-parsed namespace. When ``None``, arguments are parsed
            from ``sys.argv`` using the train + model parsers.
        env: Optional mapping of environment variables to install in this
            process (passed by the distributed launch helper).

    Raises:
        ValueError: If ``gradient_accumulation_steps`` is < 1.
        ImportError: If fp16 or distributed training is requested but apex
            is not installed.
        RuntimeError: If ``training.run_train`` requires an argument that
            is absent from the namespace.
    """
    if env is not None:
        # Bug fix: the original ``os.environ = env`` rebound the name to a
        # plain dict, which detaches it from the real process environment
        # (no putenv propagation to children) and breaks any module holding
        # a reference to the original mapping. Mutate it in place instead.
        os.environ.clear()
        os.environ.update(env)
    if args is None:
        base_parser = create_base_parser()
        train_parser = create_train_parser(base_parser)
        model_parser = create_model_parser(train_parser)
        args = model_parser.parse_args()
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            f"Invalid gradient_accumulation_steps parameter: "
            f"{args.gradient_accumulation_steps}, should be >= 1")
    if (args.fp16 or args.local_rank != -1) and not APEX_FOUND:
        raise ImportError(
            "Please install apex from https://www.github.com/nvidia/apex "
            "to use distributed and fp16 training.")
    # Forward exactly the arguments training.run_train declares, failing
    # loudly if the CLI and the training signature have drifted apart.
    arg_dict = vars(args)
    arg_names = inspect.getfullargspec(training.run_train).args
    missing = set(arg_names) - set(arg_dict.keys())
    if missing:
        raise RuntimeError(f"Missing arguments: {missing}")
    train_args = {name: arg_dict[name] for name in arg_names}
    training.run_train(**train_args)
def run_eval(args: typing.Optional[argparse.Namespace] = None) -> typing.Dict[str, float]:
    """Entry point for evaluation; parses the CLI when *args* is ``None``.

    Args:
        args: Pre-parsed namespace. When ``None``, arguments are parsed
            from ``sys.argv`` using the eval + model parsers.

    Returns:
        The metric dictionary produced by ``training.run_eval``.

    Raises:
        ValueError: If no pretrained model is given, or a distributed rank
            is set (evaluation is single-process only).
        RuntimeError: If ``training.run_eval`` requires an argument absent
            from the namespace.
    """
    if args is None:
        cli = create_model_parser(create_eval_parser(create_base_parser()))
        args = cli.parse_args()
    if args.from_pretrained is None:
        raise ValueError("Must specify pretrained model")
    if args.local_rank != -1:
        raise ValueError("TAPE does not support distributed validation pass")
    namespace = vars(args)
    wanted = inspect.getfullargspec(training.run_eval).args
    missing = set(wanted) - set(namespace)
    if missing:
        raise RuntimeError(f"Missing arguments: {missing}")
    return training.run_eval(**{name: namespace[name] for name in wanted})
def run_embed(args: typing.Optional[argparse.Namespace] = None) -> None:
    """Entry point for embedding proteins with a pretrained model.

    Args:
        args: Pre-parsed namespace. When ``None``, arguments are parsed
            from ``sys.argv`` using the embed + model parsers.

    Raises:
        ValueError: If no pretrained model is given, or a distributed rank
            is set (embedding is single-process only).
        RuntimeError: If ``training.run_embed`` requires an argument absent
            from the namespace.
    """
    if args is None:
        cli = create_model_parser(create_embed_parser(create_base_parser()))
        args = cli.parse_args()
    if args.from_pretrained is None:
        raise ValueError("Must specify pretrained model")
    if args.local_rank != -1:
        raise ValueError("TAPE does not support distributed validation pass")
    namespace = vars(args)
    wanted = inspect.getfullargspec(training.run_embed).args
    missing = set(wanted) - set(namespace)
    if missing:
        raise RuntimeError(f"Missing arguments: {missing}")
    training.run_embed(**{name: namespace[name] for name in wanted})
def run_train_distributed(args: typing.Optional[argparse.Namespace] = None) -> None:
    """Launch training as a (possibly multi-node) process group.

    Args:
        args: Pre-parsed namespace. When ``None``, arguments are parsed
            from ``sys.argv`` using the distributed + train + model parsers.
    """
    if args is None:
        cli = create_model_parser(
            create_train_parser(
                create_distributed_parser(create_base_parser())))
        args = cli.parse_args()
    # Resolve the experiment name once in the parent process so that every
    # spawned worker shares it, avoiding barriers and communication when
    # getting the experiment name after launch.
    args.exp_name = utils.get_expname(args.exp_name, args.task, args.model_type)
    utils.launch_process_group(
        run_train, args, args.nproc_per_node, args.nnodes,
        args.node_rank, args.master_addr, args.master_port)
if __name__ == '__main__':
    # Script entry point: running this module directly starts the
    # distributed-training launcher, which parses the CLI itself.
    run_train_distributed()
|