DuyTa commited on
Commit
5c3971f
1 Parent(s): 24a5ff0

Delete loss

Browse files
loss/__init__.py DELETED
File without changes
loss/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (124 Bytes)
 
loss/__pycache__/loss.cpython-39.pyc DELETED
Binary file (2.16 kB)
 
loss/loss.py DELETED
@@ -1,55 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
- class Loss_VAE(nn.Module):
5
- def __init__(self):
6
- super().__init__()
7
- self.mse = nn.MSELoss(reduction='sum')
8
-
9
- def forward(self, recon_x, x, mu, log_var):
10
- mse = self.mse(recon_x, x)
11
- kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
12
- loss = mse + kld
13
- return loss
14
-
15
-
16
- def DiceScore(
17
- y_pred: torch.Tensor,
18
- y: torch.Tensor,
19
- include_background: bool = True,
20
- ) -> torch.Tensor:
21
- """Computes Dice score metric from full size Tensor and collects average.
22
- Args:
23
- y_pred: input data to compute, typical segmentation model output.
24
- It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
25
- should be binarized.
26
- y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch.
27
- The values should be binarized.
28
- include_background: whether to skip Dice computation on the first channel of
29
- the predicted output. Defaults to True.
30
- Returns:
31
- Dice scores per batch and per class, (shape [batch_size, num_classes]).
32
- Raises:
33
- ValueError: when `y_pred` and `y` have different shapes.
34
- """
35
-
36
- y = y.float()
37
- y_pred = y_pred.float()
38
-
39
- if y.shape != y_pred.shape:
40
- raise ValueError("y_pred and y should have same shapes.")
41
-
42
- # reducing only spatial dimensions (not batch nor channels)
43
- n_len = len(y_pred.shape)
44
- reduce_axis = list(range(2, n_len))
45
- intersection = torch.sum(y * y_pred, dim=reduce_axis)
46
-
47
- y_o = torch.sum(y, reduce_axis)
48
- y_pred_o = torch.sum(y_pred, dim=reduce_axis)
49
- denominator = y_o + y_pred_o
50
-
51
- return torch.where(
52
- denominator > 0,
53
- (2.0 * intersection) / denominator,
54
- torch.tensor(float("1"), device=y_o.device),
55
- )