File size: 2,422 Bytes
fd52b7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torch.nn as nn
from torch.distributions.kl import kl_divergence
from torch.distributions.normal import Normal
from torch.nn.functional import relu



class BatchHardTripletLoss(nn.Module):
    """Triplet loss with batch-hard mining.

    For every anchor in the batch the farthest same-label sample (hardest
    positive) and the closest different-label sample (hardest negative) are
    selected, and ``relu(d_ap - d_an + margin)`` is aggregated over anchors.
    """

    # Aggregation modes understood by ``forward``.
    _AGGREGATIONS = ('sum', 'mean', 'none')

    def __init__(self, margin=1., squared=False, agg='sum'):
        """
        Initialize the loss function.

        Args:
            margin: minimum gap enforced between the hardest-positive and
                hardest-negative distances.
            squared: if True use squared Euclidean distances, otherwise
                plain Euclidean distances.
            agg: how per-anchor losses are combined: 'sum' (default),
                'mean', or 'none' (return one loss value per anchor).

        Raises:
            ValueError: if ``agg`` is not one of the supported modes.
        """
        super().__init__()
        if agg not in self._AGGREGATIONS:
            raise ValueError(
                f"agg must be one of {self._AGGREGATIONS}, got {agg!r}")
        self.margin = margin
        self.squared = squared
        self.agg = agg
        # Added under the square root so its gradient stays finite at zero
        # distance (e.g. on the diagonal).
        self.eps = 1e-8

    def get_pairwise_distances(self, embeddings):
        """
        Compute the Euclidean distance for all possible pairs of embeddings.

        Args:
            embeddings: 2-D tensor with one embedding per row.

        Returns:
            Square matrix whose entry (i, j) is the (squared, if
            ``self.squared``) distance between rows i and j.
        """
        # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, with the squared norms
        # read off the diagonal of the Gram matrix.
        ab = embeddings.mm(embeddings.t())
        a_squared = ab.diag().unsqueeze(1)
        b_squared = ab.diag().unsqueeze(0)
        distances = a_squared - 2 * ab + b_squared
        # Rounding error can push tiny distances slightly below zero.
        distances = relu(distances)

        if not self.squared:
            distances = torch.sqrt(distances + self.eps)

        return distances

    def hardest_triplet_mining(self, dist_mat, labels):
        """
        Select the hardest positive and hardest negative for every anchor.

        Args:
            dist_mat: square pairwise distance matrix.
            labels: labels expandable to the shape of ``dist_mat``
                (one label per row of ``dist_mat``).

        Returns:
            Tuple ``(dist_ap, dist_an)`` of column tensors (one row per
            anchor): the largest distance to a same-label sample and the
            smallest distance to a different-label sample. ``dist_an`` is
            ``inf`` for anchors that have no negative in the batch.
        """
        assert dist_mat.dim() == 2, "dist_mat must be 2-D"
        assert dist_mat.size(0) == dist_mat.size(1), "dist_mat must be square"
        N = dist_mat.size(0)

        label_grid = labels.expand(N, N)
        is_pos = label_grid.eq(label_grid.t())

        # Distances are non-negative, so zeroing the negatives cannot change
        # the row-wise maximum over positives.
        dist_ap, _ = torch.max(
            dist_mat.masked_fill(~is_pos, 0.), 1, keepdim=True)

        # Positives (including the anchor itself) must be pushed to +inf, not
        # zero, otherwise the row-wise minimum would always pick a masked
        # entry and the negative term would vanish.
        dist_an, _ = torch.min(
            dist_mat.masked_fill(is_pos, float('inf')), 1, keepdim=True)

        return dist_ap, dist_an

    def forward(self, embeddings, labels):
        """
        Compute the batch-hard triplet loss.

        Args:
            embeddings: 2-D tensor with one embedding per row.
            labels: one label per embedding.

        Returns:
            Scalar tensor for ``agg`` 'sum' or 'mean'; a 1-D tensor with one
            value per anchor for ``agg`` 'none'. Anchors without any negative
            in the batch contribute zero.
        """
        distances = self.get_pairwise_distances(embeddings)
        dist_ap, dist_an = self.hardest_triplet_mining(distances, labels)

        per_anchor = relu(dist_ap - dist_an + self.margin)
        if self.agg == 'mean':
            return per_anchor.mean()
        if self.agg == 'none':
            return per_anchor.squeeze(1)
        return per_anchor.sum()


class VAELoss(nn.Module):
    """Sum of a KL term against a standard normal and a summed BCE term."""

    def __init__(self):
        super().__init__()
        # Binary cross-entropy summed over every element of the batch.
        self.reconstruction_loss = nn.BCELoss(reduction='sum')

    def kl_divergence_loss(self, q_dist):
        """
        Return KL(q_dist || N(0, 1)) summed over the last dimension.

        ``q_dist`` must expose ``mean`` and ``stddev`` and be accepted by
        ``kl_divergence`` together with a ``Normal`` prior (presumably it is
        itself a ``Normal`` produced by the encoder; confirm with caller).
        """
        prior = Normal(
            torch.zeros_like(q_dist.mean),
            torch.ones_like(q_dist.stddev),
        )
        elementwise_kl = kl_divergence(q_dist, prior)
        return elementwise_kl.sum(-1)

    def forward(self, output, target, encoding):
        """
        Combine the KL term of ``encoding`` with the reconstruction term.

        Args:
            output: reconstruction passed to ``nn.BCELoss`` as the input.
            target: reconstruction target passed to ``nn.BCELoss``.
            encoding: distribution handed to ``kl_divergence_loss``.

        Returns:
            Scalar tensor: total KL plus summed binary cross-entropy.
        """
        kl_term = self.kl_divergence_loss(encoding).sum()
        recon_term = self.reconstruction_loss(output, target)
        return kl_term + recon_term