import logging

import torch
from torch import Tensor, nn

logger = logging.getLogger(__name__)


class Normalizer(nn.Module):
    """Normalize tensors with scalar running statistics.

    The module keeps an exponential moving average (EMA) of a scalar mean and
    a scalar variance in two 0-dim buffers. Both buffers are initialised to
    NaN, which is used as the "no statistics seen yet" marker; until the first
    update the module behaves as an identity (mean 0, std 1).

    Args:
        momentum: Weight given to the newest batch statistic in the EMA.
        eps: Constant added to the running variance before taking the square
            root in ``running_std``.
    """

    def __init__(self, momentum=0.01, eps=1e-9):
        super().__init__()
        self.momentum = momentum
        self.eps = eps

        # "unsafe" because these hold NaN until the first update; read them
        # through ``running_mean`` / ``running_std`` instead.
        self.running_mean_unsafe: Tensor
        self.running_var_unsafe: Tensor
        self.register_buffer("running_mean_unsafe", torch.full([], torch.nan))
        self.register_buffer("running_var_unsafe", torch.full([], torch.nan))

    @property
    def started(self) -> bool:
        """Whether the running mean has been set (i.e. is no longer NaN)."""
        return not bool(torch.isnan(self.running_mean_unsafe))

    @property
    def running_mean(self) -> Tensor:
        """Running mean, or zero if no statistics have been recorded yet."""
        if not self.started:
            return torch.zeros_like(self.running_mean_unsafe)
        return self.running_mean_unsafe

    @property
    def running_std(self) -> Tensor:
        """Running standard deviation, or one if no statistics exist yet."""
        if not self.started:
            return torch.ones_like(self.running_var_unsafe)
        return (self.running_var_unsafe + self.eps).sqrt()

    @torch.no_grad()
    def _ema(self, a: Tensor, x: Tensor) -> Tensor:
        """Blend the previous value ``a`` with the new observation ``x``."""
        return (1 - self.momentum) * a + self.momentum * x

    @torch.no_grad()
    def update_(self, x: Tensor) -> None:
        """Fold the statistics of ``x`` into the running mean and variance.

        The first non-empty batch initialises the statistics directly; later
        batches are blended in with ``_ema``. The whole update runs without
        gradient tracking so the buffers never hold on to an autograd graph.
        Empty batches are ignored.

        Args:
            x: Batch of values; all elements are reduced to scalar statistics.
        """
        # An empty batch has NaN mean/variance, which would poison the
        # running statistics for every later step.
        if x.numel() == 0:
            return

        x = x.detach()
        batch_mean = x.mean()

        if not self.started:
            # The default (unbiased) variance of a single element is NaN and
            # would never recover through the EMA; treat it as zero spread.
            if x.numel() > 1:
                batch_var = x.var()
            else:
                batch_var = torch.zeros_like(batch_mean)
            self.running_mean_unsafe = batch_mean
            self.running_var_unsafe = batch_var
            return

        self.running_mean_unsafe = self._ema(self.running_mean_unsafe, batch_mean)
        # The squared deviation is measured against the already-updated mean.
        sq_dev = (x - self.running_mean).pow(2).mean()
        self.running_var_unsafe = self._ema(self.running_var_unsafe, sq_dev)

    def forward(self, x: Tensor, update: bool = True) -> Tensor:
        """Normalize ``x`` with the running statistics.

        Args:
            x: Tensor to normalize.
            update: When true and the module is in training mode, the running
                statistics are updated from ``x`` before normalizing.

        Returns:
            ``(x - running_mean) / running_std``.
        """
        if self.training and update:
            self.update_(x)

        # Python-float snapshot of the current statistics, refreshed per call.
        self.stats = dict(
            mean=self.running_mean.item(),
            std=self.running_std.item(),
        )
        return (x - self.running_mean) / self.running_std

    def inverse(self, x: Tensor) -> Tensor:
        """Undo ``forward``: map normalized values back to the original scale."""
        return x * self.running_std + self.running_mean