import logging
import time
import typing
from collections import defaultdict

import tabulate
import torch
from detectron2.engine import AMPTrainer
from torch import nn

logger = logging.getLogger("detectron2")


def parameter_count(model: nn.Module, trainable_only: bool = False) -> typing.DefaultDict[str, int]:
    """
    Count parameters of a model and its submodules.

    Args:
        model: a torch module
        trainable_only: if True, skip parameters whose ``requires_grad`` is False.

    Returns:
        dict (str-> int): the key is either a parameter name or a module name.
        The value is the number of elements in the parameter, or in all
        parameters of the module. The key "" corresponds to the total
        number of parameters of the model. If no parameter is counted, the
        returned dict is empty (the "" key is absent).
    """
    r = defaultdict(int)
    for name, prm in model.named_parameters():
        if trainable_only and not prm.requires_grad:
            continue
        size = prm.numel()
        parts = name.split(".")
        # Credit the parameter to every dotted prefix of its name, from ""
        # (the whole model) down to the full parameter name itself.
        for k in range(len(parts) + 1):
            r[".".join(parts[:k])] += size
    return r


def parameter_count_table(
    model: nn.Module, max_depth: int = 3, trainable_only: bool = False
) -> str:
    """
    Format the parameter count of the model (and its submodules or parameters)
    in a nice table. It looks like this:

    ::

        | name                            | #elements or shape   |
        |:--------------------------------|:---------------------|
        | model                           | 37.9M                |
        |  backbone                       |  31.5M               |
        |   backbone.fpn_lateral3         |   0.1M               |
        |    backbone.fpn_lateral3.weight |    (256, 512, 1, 1)  |
        |    backbone.fpn_lateral3.bias   |    (256,)            |
        |   backbone.fpn_output3          |   0.6M               |
        |    backbone.fpn_output3.weight  |    (256, 256, 3, 3)  |
        |    backbone.fpn_output3.bias    |    (256,)            |
        |   backbone.fpn_lateral4         |   0.3M               |
        |    backbone.fpn_lateral4.weight |    (256, 1024, 1, 1) |
        |    backbone.fpn_lateral4.bias   |    (256,)            |
        |   backbone.fpn_output4          |   0.6M               |
        |    backbone.fpn_output4.weight  |    (256, 256, 3, 3)  |
        |    backbone.fpn_output4.bias    |    (256,)            |
        |   backbone.fpn_lateral5         |   0.5M               |
        |    backbone.fpn_lateral5.weight |    (256, 2048, 1, 1) |
        |    backbone.fpn_lateral5.bias   |    (256,)            |
        |   backbone.fpn_output5          |   0.6M               |
        |    backbone.fpn_output5.weight  |    (256, 256, 3, 3)  |
        |    backbone.fpn_output5.bias    |    (256,)            |
        |   backbone.top_block            |   5.3M               |
        |    backbone.top_block.p6        |    4.7M              |
        |    backbone.top_block.p7        |    0.6M              |
        |   backbone.bottom_up            |   23.5M              |
        |    backbone.bottom_up.stem      |    9.4K              |
        |    backbone.bottom_up.res2      |    0.2M              |
        |    backbone.bottom_up.res3      |    1.2M              |
        |    backbone.bottom_up.res4      |    7.1M              |
        |    backbone.bottom_up.res5      |    14.9M             |
        |    ......                       |    .....             |

    Args:
        model: a torch module
        max_depth (int): maximum depth to recursively print submodules or
            parameters (names with up to ``max_depth - 1`` dots are listed).
        trainable_only (bool): if True, only parameters with
            ``requires_grad=True`` are counted. Shapes are still looked up
            over all parameters.

    Returns:
        str: the table to be printed
    """
    count: typing.DefaultDict[str, int] = parameter_count(model, trainable_only)
    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
    param_shape: typing.Dict[str, typing.Tuple] = {
        k: tuple(v.shape) for k, v in model.named_parameters()
    }
    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
    table: typing.List[typing.Tuple] = []

    def format_size(x: int) -> str:
        # Human-readable count; note the G/M/K switch happens at 1e8/1e5/1e2,
        # so e.g. 150M is rendered as "0.1G".
        if x > 1e8:
            return "{:.1f}G".format(x / 1e9)
        if x > 1e5:
            return "{:.1f}M".format(x / 1e6)
        if x > 1e2:
            return "{:.1f}K".format(x / 1e3)
        return str(x)

    def fill(lvl: int, prefix: str) -> None:
        # Depth-first: emit every entry at nesting level `lvl` under `prefix`,
        # each immediately followed by its own children.
        if lvl >= max_depth:
            return
        for name, v in count.items():
            if name.count(".") == lvl and name.startswith(prefix):
                indent = " " * (lvl + 1)
                if name in param_shape:
                    # Leaf parameter: show its shape rather than its size.
                    table.append((indent + name, indent + str(param_shape[name])))
                else:
                    table.append((indent + name, indent + format_size(v)))
                    fill(lvl + 1, name + ".")

    # `count` has no "" key when nothing was counted (parameter-free model, or
    # trainable_only=True on a fully frozen model); report 0 instead of raising.
    table.append(("model", format_size(count.pop("", 0))))
    fill(0, "")

    # PRESERVE_WHITESPACE is a module-global switch in tabulate; keep the
    # indentation for this call only and always restore the previous value.
    old_ws = tabulate.PRESERVE_WHITESPACE
    tabulate.PRESERVE_WHITESPACE = True
    try:
        tab = tabulate.tabulate(table, headers=["name", "#elements or shape"], tablefmt="pipe")
    finally:
        tabulate.PRESERVE_WHITESPACE = old_ws
    return tab


def cycle(iterable):
    """
    Yield items from ``iterable`` endlessly, re-iterating it on every pass.

    Unlike ``itertools.cycle``, items are not cached: each pass iterates
    ``iterable`` afresh, so a data loader is re-iterated rather than having
    every batch from the first pass kept in memory.

    If a whole pass yields nothing (an empty iterable, or a one-shot iterator
    that is already exhausted), the generator ends instead of spinning forever
    without producing a value.
    """
    while True:
        produced = False
        for x in iterable:
            produced = True
            yield x
        if not produced:
            return


class MattingTrainer(AMPTrainer):
    """
    AMP trainer that draws batches from an endlessly re-iterated data loader
    and logs parameter-count tables when constructed.
    """

    def __init__(self, model, data_loader, optimizer, grad_scaler=None):
        # Forward the caller's grad_scaler. It used to be replaced by None
        # here, so a user-supplied scaler was silently ignored; passing None
        # still leaves the choice to AMPTrainer.
        super().__init__(model, data_loader, optimizer, grad_scaler=grad_scaler)
        self.data_loader_iter = iter(cycle(self.data_loader))

        # print model parameters
        logger.info("All parameters: \n" + parameter_count_table(model))
        logger.info(
            "Trainable parameters: \n"
            + parameter_count_table(model, trainable_only=True, max_depth=8)
        )

    def run_step(self):
        """
        Implement the AMP training logic.

        One iteration: fetch a batch, run the forward pass under ``autocast``,
        scale the summed losses and backpropagate through ``self.grad_scaler``,
        write metrics, then step the optimizer via the scaler and update it.
        The model may return either a single loss tensor or a dict of losses.
        """
        assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
        from torch.cuda.amp import autocast

        # matting pass
        start = time.perf_counter()
        data = next(self.data_loader_iter)
        data_time = time.perf_counter() - start

        with autocast():
            loss_dict = self.model(data)
            if isinstance(loss_dict, torch.Tensor):
                losses = loss_dict
                loss_dict = {"total_loss": loss_dict}
            else:
                losses = sum(loss_dict.values())

        self.optimizer.zero_grad()
        self.grad_scaler.scale(losses).backward()

        self._write_metrics(loss_dict, data_time)

        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()