Fabrice-TIERCELIN committed on
Commit 54d7586
1 Parent(s): fce9266

Upload __init__.py

sgm/modules/autoencoding/__init__.py ADDED
@@ -0,0 +1,53 @@
+ from abc import abstractmethod
+ from typing import Any, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from ...modules.distributions.distributions import DiagonalGaussianDistribution
+
+
+ class AbstractRegularizer(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:  # returns (regularized z, log dict)
+         raise NotImplementedError()
+
+     @abstractmethod
+     def get_trainable_parameters(self) -> Any:
+         raise NotImplementedError()
+
+
+ class DiagonalGaussianRegularizer(AbstractRegularizer):
+     def __init__(self, sample: bool = True):
+         super().__init__()
+         self.sample = sample  # sample from the posterior if True, else take its mode
+
+     def get_trainable_parameters(self) -> Any:
+         yield from ()  # this regularizer has no trainable parameters
+
+     def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+         log = dict()
+         posterior = DiagonalGaussianDistribution(z)
+         if self.sample:
+             z = posterior.sample()
+         else:
+             z = posterior.mode()
+         kl_loss = posterior.kl()
+         kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]  # mean KL over the batch
+         log["kl_loss"] = kl_loss
+         return z, log
+
+
+ def measure_perplexity(predicted_indices, num_centroids):
+     # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
+     # Eval cluster perplexity: when perplexity == num_embeddings, all clusters are used exactly equally.
+     encodings = (
+         F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
+     )
+     avg_probs = encodings.mean(0)
+     perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+     cluster_use = torch.sum(avg_probs > 0)
+     return perplexity, cluster_use
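
For context, a minimal usage sketch of the two entry points in this file (not part of the commit). It assumes the module is importable as sgm.modules.autoencoding per the path above, and that DiagonalGaussianDistribution chunks the channel dimension of its input into mean and log-variance halves, as in the sgm codebase; the tensor shapes and the 512-centroid codebook are illustrative.

import torch
from sgm.modules.autoencoding import DiagonalGaussianRegularizer, measure_perplexity

# Fake encoder output: 8 channels = 4 posterior means + 4 log-variances,
# since DiagonalGaussianDistribution splits the channel dim in half.
z = torch.randn(4, 8, 16, 16)

regularizer = DiagonalGaussianRegularizer(sample=True)
z_reg, log = regularizer(z)  # z_reg: (4, 4, 16, 16); log["kl_loss"]: scalar tensor
print(z_reg.shape, log["kl_loss"].item())

# Perplexity of (random) codebook assignments over 512 centroids:
# values near 512 mean uniform codebook usage; values near 1 mean collapse.
indices = torch.randint(0, 512, (4, 256))
perplexity, cluster_use = measure_perplexity(indices, num_centroids=512)
print(perplexity.item(), cluster_use.item())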