# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01a_losses.ipynb.
# %% ../nbs/01a_losses.ipynb 2
from __future__ import annotations
from .imports import *
from .torch_imports import *
from .torch_core import *
from .layers import *
# %% auto 0
__all__ = ['BaseLoss', 'CrossEntropyLossFlat', 'FocalLoss', 'FocalLossFlat', 'BCEWithLogitsLossFlat', 'BCELossFlat',
           'MSELossFlat', 'L1LossFlat', 'LabelSmoothingCrossEntropy', 'LabelSmoothingCrossEntropyFlat', 'DiceLoss']
# %% ../nbs/01a_losses.ipynb 5
class BaseLoss():
"Same as `loss_cls`, but flattens input and target."
activation=decodes=noops
def __init__(self,
loss_cls, # Uninitialized PyTorch-compatible loss
*args,
axis:int=-1, # Class axis
flatten:bool=True, # Flatten `inp` and `targ` before calculating loss
floatify:bool=False, # Convert `targ` to `float`
is_2d:bool=True, # Whether `flatten` keeps one or two channels when applied
**kwargs
):
store_attr("axis,flatten,floatify,is_2d")
self.func = loss_cls(*args,**kwargs)
functools.update_wrapper(self, self.func)
def __repr__(self) -> str: return f"FlattenedLoss of {self.func}"
@property
def reduction(self) -> str: return self.func.reduction
@reduction.setter
def reduction(self, v:str):
"Sets the reduction style (typically 'mean', 'sum', or 'none')"
self.func.reduction = v
def _contiguous(self, x:Tensor) -> TensorBase:
"Move `self.axis` to the last dimension and ensure tensor is contigous for `Tensor` otherwise just return"
return TensorBase(x.transpose(self.axis,-1).contiguous()) if isinstance(x,torch.Tensor) else x
    def __call__(self,
        inp:Tensor|MutableSequence, # Predictions from a `Learner`
        targ:Tensor|MutableSequence, # Actual y label
        **kwargs
    ) -> TensorBase: # `loss_cls` calculated on `inp` and `targ`
        inp,targ = map(self._contiguous, (inp,targ))
        if self.floatify and targ.dtype!=torch.float16: targ = targ.float()
        if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
        if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
        return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
    def to(self, device:torch.device):
        "Move the loss function to a specified `device`"
        if isinstance(self.func, nn.Module): self.func.to(device)
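
# Illustrative sketch (not part of the exported API): how `BaseLoss` wraps an
# uninitialized PyTorch loss and flattens `inp`/`targ` before calling it. The
# shapes and the `_demo_*` helper name are assumptions made for demonstration only.
def _demo_base_loss():
    loss_func = BaseLoss(nn.CrossEntropyLoss, axis=-1)  # equivalent in spirit to `CrossEntropyLossFlat`
    inp  = torch.randn(16, 10)                          # (batch, classes) raw logits
    targ = torch.randint(0, 10, (16,))                  # integer class indices
    loss_func.reduction = 'none'                        # the `reduction` property forwards to the wrapped loss
    return loss_func(inp, targ)                         # per-item losses of shape (16,) after flattening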
# %% ../nbs/01a_losses.ipynb 8
@delegates()
class CrossEntropyLossFlat(BaseLoss):
"Same as `nn.CrossEntropyLoss`, but flattens input and target."
y_int = True # y interpolation
@use_kwargs_dict(keep=True, weight=None, ignore_index=-100, reduction='mean')
def __init__(self,
*args,
axis:int=-1, # Class axis
**kwargs
):
super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)
def decodes(self, x:Tensor) -> Tensor:
"Converts model output to target format"
return x.argmax(dim=self.axis)
def activation(self, x:Tensor) -> Tensor:
"`nn.CrossEntropyLoss`'s fused activation function applied to model output"
return F.softmax(x, dim=self.axis)
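
# Illustrative sketch (not part of the exported API): `CrossEntropyLossFlat` on a
# segmentation-style output, where the class axis is dim 1. Shapes and the helper
# name are assumptions for demonstration only.
def _demo_cross_entropy_loss_flat():
    loss_func = CrossEntropyLossFlat(axis=1)       # class axis 1, as for per-pixel class logits
    output = torch.randn(4, 5, 8, 8)               # (batch, classes, height, width)
    target = torch.randint(0, 5, (4, 8, 8))        # integer class index per pixel
    loss   = loss_func(output, target)             # inp/targ are flattened internally before `nn.CrossEntropyLoss`
    probs  = loss_func.activation(output)          # softmax over the class axis
    preds  = loss_func.decodes(output)             # argmax over the class axis -> (4, 8, 8)
    return loss, probs, preds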
# %% ../nbs/01a_losses.ipynb 13
class FocalLoss(Module):
    y_int=True # Targets are integer class indices
    def __init__(self,
        gamma:float=2.0, # Focusing parameter. Higher values down-weight easy examples' contribution to loss
        weight:Tensor=None, # Manual rescaling weight given to each class
        reduction:str='mean' # PyTorch reduction to apply to the output
    ):
        "Applies Focal Loss: https://arxiv.org/pdf/1708.02002.pdf"
        store_attr()
    def forward(self, inp:Tensor, targ:Tensor) -> Tensor:
        "Applies focal loss based on https://arxiv.org/pdf/1708.02002.pdf"
        ce_loss = F.cross_entropy(inp, targ, weight=self.weight, reduction="none")
        p_t = torch.exp(-ce_loss)
        loss = (1 - p_t)**self.gamma * ce_loss
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss
class FocalLossFlat(BaseLoss):
"""
Same as CrossEntropyLossFlat but with focal paramter, `gamma`. Focal loss is introduced by Lin et al.
https://arxiv.org/pdf/1708.02002.pdf. Note the class weighting factor in the paper, alpha, can be
implemented through pytorch `weight` argument passed through to F.cross_entropy.
"""
    y_int = True # Targets are integer class indices
    @use_kwargs_dict(keep=True, weight=None, reduction='mean')
    def __init__(self,
        *args,
        gamma:float=2.0, # Focusing parameter. Higher values down-weight easy examples' contribution to loss
        axis:int=-1, # Class axis
        **kwargs
    ):
        super().__init__(FocalLoss, *args, gamma=gamma, axis=axis, **kwargs)
    def decodes(self, x:Tensor) -> Tensor:
        "Converts model output to target format"
        return x.argmax(dim=self.axis)
    def activation(self, x:Tensor) -> Tensor:
        "`F.cross_entropy`'s fused activation function applied to model output"
        return F.softmax(x, dim=self.axis)
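
# Illustrative sketch (not part of the exported API): `FocalLossFlat` used like
# `CrossEntropyLossFlat`; `gamma=0` recovers plain cross entropy, while larger `gamma`
# down-weights well-classified examples. Shapes and the helper name are assumptions.
def _demo_focal_loss_flat():
    output = torch.randn(2, 3, 16, 16)             # (batch, classes, height, width) logits
    mask   = torch.randint(0, 3, (2, 16, 16))      # integer class mask
    focal  = FocalLossFlat(gamma=2., axis=1)(output, mask)
    plain  = FocalLossFlat(gamma=0., axis=1)(output, mask)   # matches CrossEntropyLossFlat(axis=1)
    return focal, plain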
# %% ../nbs/01a_losses.ipynb 16
@delegates()
class BCEWithLogitsLossFlat(BaseLoss):
"Same as `nn.BCEWithLogitsLoss`, but flattens input and target."
@use_kwargs_dict(keep=True, weight=None, reduction='mean', pos_weight=None)
def __init__(self,
*args,
axis:int=-1, # Class axis
floatify:bool=True, # Convert `targ` to `float`
thresh:float=0.5, # The threshold on which to predict
**kwargs
):
if kwargs.get('pos_weight', None) is not None and kwargs.get('flatten', None) is True:
raise ValueError("`flatten` must be False when using `pos_weight` to avoid a RuntimeError due to shape mismatch")
if kwargs.get('pos_weight', None) is not None: kwargs['flatten'] = False
super().__init__(nn.BCEWithLogitsLoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
self.thresh = thresh
def decodes(self, x:Tensor) -> Tensor:
"Converts model output to target format"
return x>self.thresh
def activation(self, x:Tensor) -> Tensor:
"`nn.BCEWithLogitsLoss`'s fused activation function applied to model output"
return torch.sigmoid(x)
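
# Illustrative sketch (not part of the exported API): `BCEWithLogitsLossFlat` for a
# multi-label setup. `decodes` thresholds probabilities, so it is applied here to the
# output of `activation`. Shapes and the helper name are assumptions.
def _demo_bce_with_logits_loss_flat():
    loss_func = BCEWithLogitsLossFlat(thresh=0.5)
    output = torch.randn(8, 4)                     # (batch, n_labels) raw logits
    target = torch.randint(0, 2, (8, 4))           # multi-hot labels; `floatify` casts them to float
    loss   = loss_func(output, target)
    preds  = loss_func.decodes(loss_func.activation(output))   # boolean (8, 4) prediction mask
    return loss, preds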
# %% ../nbs/01a_losses.ipynb 18
@use_kwargs_dict(weight=None, reduction='mean')
def BCELossFlat(
    *args,
    axis:int=-1, # Class axis
    floatify:bool=True, # Convert `targ` to `float`
    **kwargs
):
    "Same as `nn.BCELoss`, but flattens input and target."
    return BaseLoss(nn.BCELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
# %% ../nbs/01a_losses.ipynb 20
@use_kwargs_dict(reduction='mean')
def MSELossFlat(
    *args,
    axis:int=-1, # Class axis
    floatify:bool=True, # Convert `targ` to `float`
    **kwargs
):
    "Same as `nn.MSELoss`, but flattens input and target."
    return BaseLoss(nn.MSELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
# %% ../nbs/01a_losses.ipynb 23
@use_kwargs_dict(reduction='mean')
def L1LossFlat(
    *args,
    axis=-1, # Class axis
    floatify=True, # Convert `targ` to `float`
    **kwargs
):
    "Same as `nn.L1Loss`, but flattens input and target."
    return BaseLoss(nn.L1Loss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
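
# Illustrative sketch (not part of the exported API): the flattened elementwise losses
# accept outputs and targets whose shapes differ only by trailing unit dimensions,
# since both are viewed as 1D before the underlying loss is called. Shapes and the
# helper name are assumptions for demonstration only.
def _demo_flat_elementwise_losses():
    output = torch.randn(8, 1)                     # e.g. a single regression head
    target = torch.randn(8)                        # unflattened target of shape (8,)
    mse = MSELossFlat()(output, target)
    mae = L1LossFlat()(output, target)
    bce = BCELossFlat()(torch.sigmoid(output), torch.randint(0, 2, (8,)))  # `nn.BCELoss` expects probabilities
    return mse, mae, bce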
# %% ../nbs/01a_losses.ipynb 24
class LabelSmoothingCrossEntropy(Module):
    y_int = True # Targets are integer class indices
    def __init__(self,
        eps:float=0.1, # The weight for the interpolation formula
        weight:Tensor=None, # Manual rescaling weight given to each class passed to `F.nll_loss`
        reduction:str='mean' # PyTorch reduction to apply to the output
    ):
        store_attr()
    def forward(self, output:Tensor, target:Tensor) -> Tensor:
        "Apply `F.log_softmax` to the output, then blend the smoothed loss (divided by the number of classes `c`) with `F.nll_loss`"
        c = output.size()[1]
        log_preds = F.log_softmax(output, dim=1)
        if self.reduction=='sum': loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=1) # We divide by `c` at the return line, so sum here rather than mean
            if self.reduction=='mean': loss = loss.mean()
        return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target.long(), weight=self.weight, reduction=self.reduction)
    def activation(self, out:Tensor) -> Tensor:
        "`F.log_softmax`'s fused activation function applied to model output"
        return F.softmax(out, dim=-1)
    def decodes(self, out:Tensor) -> Tensor:
        "Converts model output to target format"
        return out.argmax(dim=-1)
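
# Illustrative sketch (not part of the exported API): with `eps=0` the smoothed loss
# reduces to plain cross entropy, since the blended `loss*eps/c` term vanishes.
# Shapes and the helper name are assumptions for demonstration only.
def _demo_label_smoothing_cross_entropy():
    output = torch.randn(16, 10)                   # (batch, classes) logits
    target = torch.randint(0, 10, (16,))           # integer class indices
    smoothed = LabelSmoothingCrossEntropy(eps=0.1)(output, target)
    plain    = LabelSmoothingCrossEntropy(eps=0.)(output, target)   # equals nn.CrossEntropyLoss()(output, target)
    return smoothed, plain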
# %% ../nbs/01a_losses.ipynb 27
@delegates()
class LabelSmoothingCrossEntropyFlat(BaseLoss):
"Same as `LabelSmoothingCrossEntropy`, but flattens input and target."
y_int = True
@use_kwargs_dict(keep=True, eps=0.1, reduction='mean')
def __init__(self,
*args,
axis:int=-1, # Class axis
**kwargs
):
super().__init__(LabelSmoothingCrossEntropy, *args, axis=axis, **kwargs)
def activation(self, out:Tensor) -> Tensor:
"`LabelSmoothingCrossEntropy`'s fused activation function applied to model output"
return F.softmax(out, dim=-1)
def decodes(self, out:Tensor) -> Tensor:
"Converts model output to target format"
return out.argmax(dim=-1)
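
# Illustrative sketch (not part of the exported API): the flat variant moves the class
# axis to the end and flattens, so the wrapped `LabelSmoothingCrossEntropy` always sees
# (N, classes) inputs. Shapes and the helper name are assumptions.
def _demo_label_smoothing_cross_entropy_flat():
    loss_func = LabelSmoothingCrossEntropyFlat(eps=0.1, axis=1)   # class axis 1, segmentation-style
    output = torch.randn(2, 3, 8, 8)               # (batch, classes, height, width)
    mask   = torch.randint(0, 3, (2, 8, 8))        # integer class mask
    return loss_func(output, mask)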
# %% ../nbs/01a_losses.ipynb 30
class DiceLoss:
"Dice loss for segmentation"
def __init__(self,
axis:int=1, # Class axis
smooth:float=1e-6, # Helps with numerical stabilities in the IoU division
reduction:str="sum", # PyTorch reduction to apply to the output
square_in_union:bool=False # Squares predictions to increase slope of gradients
):
store_attr()
def __call__(self, pred:Tensor, targ:Tensor) -> Tensor:
"One-hot encodes targ, then runs IoU calculation then takes 1-dice value"
targ = self._one_hot(targ, pred.shape[self.axis])
pred, targ = TensorBase(pred), TensorBase(targ)
assert pred.shape == targ.shape, 'input and target dimensions differ, DiceLoss expects non one-hot targs'
pred = self.activation(pred)
sum_dims = list(range(2, len(pred.shape)))
inter = torch.sum(pred*targ, dim=sum_dims)
union = (torch.sum(pred**2+targ, dim=sum_dims) if self.square_in_union
else torch.sum(pred+targ, dim=sum_dims))
dice_score = (2. * inter + self.smooth)/(union + self.smooth)
loss = 1- dice_score
if self.reduction == 'mean':
loss = loss.mean()
elif self.reduction == 'sum':
loss = loss.sum()
return loss
@staticmethod
def _one_hot(
x:Tensor, # Non one-hot encoded targs
classes:int, # The number of classes
axis:int=1 # The axis to stack for encoding (class dimension)
) -> Tensor:
"Creates one binary mask per class"
return torch.stack([torch.where(x==c, 1, 0) for c in range(classes)], axis=axis)
def activation(self, x:Tensor) -> Tensor:
"Activation function applied to model output"
return F.softmax(x, dim=self.axis)
def decodes(self, x:Tensor) -> Tensor:
"Converts model output to target format"
return x.argmax(dim=self.axis)
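
# Illustrative sketch (not part of the exported API): `DiceLoss` one-hot encodes the
# integer mask internally and reduces over the spatial dimensions per class. Shapes
# and the helper name are assumptions for demonstration only.
def _demo_dice_loss():
    loss_func = DiceLoss(reduction='mean')         # average the per-item, per-class Dice losses
    output = torch.randn(2, 3, 8, 8)               # (batch, classes, height, width) logits
    mask   = torch.randint(0, 3, (2, 8, 8))        # integer class mask (not one-hot)
    loss   = loss_func(output, mask)
    preds  = loss_func.decodes(output)             # argmax over the class axis -> (2, 8, 8)
    return loss, preds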