"""Utility functions to help setup the model, optimizer, distributed compute, etc.
"""
import typing
import logging
from pathlib import Path
import sys
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, RandomSampler, Dataset
from torch.utils.data.distributed import DistributedSampler
from ..optimization import AdamW
from ..registry import registry
from .utils import get_effective_batch_size
from ._sampler import BucketBatchSampler

logger = logging.getLogger(__name__)


def setup_logging(local_rank: int,
                  save_path: typing.Optional[Path] = None,
                  log_level: typing.Optional[typing.Union[str, int]] = None) -> None:
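    """Configure the root logger for this process.

    A console handler (and, if ``save_path`` is given, a file handler writing to
    ``save_path / 'log'``) is attached only when no handlers exist yet. Non-master
    ranks are raised to at least WARNING so only the master process logs at INFO.
    """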
if log_level is None:
level = logging.INFO
elif isinstance(log_level, str):
level = getattr(logging, log_level.upper())
elif isinstance(log_level, int):
level = log_level
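    # Non-master processes in distributed runs only log warnings and above,
    # which keeps console output from being duplicated once per process.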
if local_rank not in (-1, 0):
level = max(level, logging.WARN)
root_logger = logging.getLogger()
root_logger.setLevel(level)
formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%y/%m/%d %H:%M:%S")
if not root_logger.hasHandlers():
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(level)
console_handler.setFormatter(formatter)
root_logger.addHandler(console_handler)
if save_path is not None:
file_handler = logging.FileHandler(save_path / 'log')
file_handler.setLevel(level)
file_handler.setFormatter(formatter)
root_logger.addHandler(file_handler)


def setup_optimizer(model,
                    learning_rate: float) -> AdamW:
"""Create the AdamW optimizer for the given model with the specified learning rate. Based on
creation in the pytorch_transformers repository.
Args:
model (PreTrainedModel): The model for which to create an optimizer
learning_rate (float): Default learning rate to use when creating the optimizer
Returns:
optimizer (AdamW): An AdamW optimizer
"""
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
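    # Biases and LayerNorm/gamma/beta parameters are exempt from weight decay;
    # all remaining parameters get a weight decay of 0.01.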
optimizer_grouped_parameters = [
{
"params": [
p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
],
"weight_decay": 0.01,
},
{
"params": [
p for n, p in param_optimizer if any(nd in n for nd in no_decay)
],
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
return optimizer


def setup_dataset(task: str,
data_dir: typing.Union[str, Path],
split: str,
tokenizer: str) -> Dataset:
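    """Look up ``task`` in the registry and construct its dataset for the given
    data directory, split, and tokenizer name.
    """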
task_spec = registry.get_task_spec(task)
return task_spec.dataset(data_dir, split, tokenizer) # type: ignore


def setup_loader(dataset: Dataset,
batch_size: int,
local_rank: int,
n_gpu: int,
gradient_accumulation_steps: int,
num_workers: int) -> DataLoader:
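    """Build a DataLoader that buckets examples of similar length into batches,
    using a DistributedSampler when running under torch.distributed and a
    RandomSampler otherwise.
    """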
sampler = DistributedSampler(dataset) if local_rank != -1 else RandomSampler(dataset)
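    # get_effective_batch_size scales the requested batch size down for gradient
    # accumulation and the number of processes; multiplying by n_gpu lets
    # DataParallel split each batch across the local GPUs.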
batch_size = get_effective_batch_size(
batch_size, local_rank, n_gpu, gradient_accumulation_steps) * n_gpu
# WARNING: this will fail if the primary sequence is not the first thing the dataset returns
batch_sampler = BucketBatchSampler(
sampler, batch_size, False, lambda x: len(x[0]), dataset)
loader = DataLoader(
dataset,
num_workers=num_workers,
collate_fn=dataset.collate_fn, # type: ignore
batch_sampler=batch_sampler)
return loader


def setup_distributed(local_rank: int,
no_cuda: bool) -> typing.Tuple[torch.device, int, bool]:
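    """Pick the torch device, count usable GPUs, and initialize the NCCL process
    group when running in distributed mode with CUDA (``local_rank != -1``).
    Returns ``(device, n_gpu, is_master)``.
    """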
if local_rank != -1 and not no_cuda:
torch.cuda.set_device(local_rank)
device: torch.device = torch.device("cuda", local_rank)
n_gpu = 1
dist.init_process_group(backend="nccl")
elif not torch.cuda.is_available() or no_cuda:
device = torch.device("cpu")
n_gpu = 1
else:
device = torch.device("cuda")
n_gpu = torch.cuda.device_count()
is_master = local_rank in (-1, 0)
return device, n_gpu, is_master