import torch
from torch import nn


class Loss_VAE(nn.Module):
    """Variational auto-encoder loss: summed reconstruction error plus a KL term.

    The reconstruction part is a sum-reduced mean squared error. The
    regularization part is ``-0.5 * sum(1 + log_var - mu^2 - exp(log_var))``,
    i.e. the closed-form KL divergence between a diagonal Gaussian
    parameterized by ``mu`` / ``log_var`` and a standard normal prior.
    """

    def __init__(self):
        super().__init__()
        # 'sum' reduction keeps the reconstruction term on the same scale as
        # the (summed) KL term below.
        self.mse = nn.MSELoss(reduction='sum')

    def forward(self, recon_x, x, mu, log_var):
        """Return the scalar VAE loss.

        Args:
            recon_x: reconstructed input produced by the decoder.
            x: original input; compared element-wise against ``recon_x``.
            mu: latent mean tensor.
            log_var: latent log-variance tensor (exponentiated to obtain the
                variance).

        Returns:
            Scalar tensor equal to the summed squared error plus the KL term.
        """
        mse = self.mse(recon_x, x)
        kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return mse + kld


def DiceScore(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    include_background: bool = True,
) -> torch.Tensor:
    """Compute the Dice score per batch item and per class.

    Args:
        y_pred: input data to compute, typically a segmentation model output.
            It must be in one-hot format with the batch as the first dim,
            example shape: [16, 3, 32, 32]. The values should be binarized.
        y: ground truth to compute the Dice score against. It must be in
            one-hot format with the batch as the first dim. The values should
            be binarized.
        include_background: whether to include the first channel (assumed to
            be the background class) in the computation. When False and the
            inputs have more than one channel, channel 0 is dropped from both
            tensors before scoring. Defaults to True.

    Returns:
        Dice scores per batch and per class, shape [batch_size, num_classes]
        (``num_classes - 1`` columns when the background channel is dropped).
        A class that is empty in both ``y`` and ``y_pred`` scores 1.

    Raises:
        ValueError: when `y_pred` and `y` have different shapes.
    """
    if y.shape != y_pred.shape:
        raise ValueError("y_pred and y should have same shapes.")

    y = y.float()
    y_pred = y_pred.float()

    # Honor include_background: drop channel 0 from both tensors. A
    # single-channel input is left untouched so the result is never empty.
    if not include_background and y_pred.dim() > 1 and y_pred.shape[1] > 1:
        y = y[:, 1:]
        y_pred = y_pred[:, 1:]

    # Reduce only the spatial dimensions (not batch nor channels).
    reduce_axis = list(range(2, y_pred.dim()))
    intersection = torch.sum(y * y_pred, dim=reduce_axis)

    y_o = torch.sum(y, dim=reduce_axis)
    y_pred_o = torch.sum(y_pred, dim=reduce_axis)
    denominator = y_o + y_pred_o

    # Where both prediction and ground truth are empty the ratio is 0/0;
    # report a perfect score of 1 for those entries instead of NaN.
    return torch.where(
        denominator > 0,
        (2.0 * intersection) / denominator,
        torch.tensor(1.0, device=denominator.device),
    )