File size: 4,520 Bytes
3be620b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from tensorflow.keras import Model

import tensorflow as tf
import tensorflow_probability as tfp


class MovingVAE(Model):
    """Convolutional variational autoencoder for video-like volumes.

    Built from TensorFlow Probability layers:

    * ``encoder`` maps an input volume to a ``MultivariateNormalTriL``
      posterior over an ``encoded_size``-dimensional latent vector. The KL
      divergence from that posterior to ``prior`` is attached as an activity
      regularizer on the posterior layer.
    * ``decoder`` maps a latent vector to an ``IndependentBernoulli``
      distribution whose event shape is ``input_shape``.
    * ``model`` chains encoder and decoder; ``call`` runs it.

    NOTE(review): the decoder geometry is hard-coded. Starting from a 1x1x1
    seed, its transposed convolutions produce a (20, 64, 64) volume with one
    channel (81920 values). That volume is flattened and reshaped to
    ``input_shape``, so ``input_shape`` presumably has to be (20, 64, 64, 1).
    Confirm this before using other shapes.

    Args:
        input_shape: Shape of a single example, excluding the batch axis.
        encoded_size: Latent dimensionality.
        base_depth: Filter count of the shallowest convolutions; the deeper
            ones use ``2 * base_depth``.
    """

    def __init__(self, input_shape, encoded_size=64, base_depth=32):
        super().__init__()

        self.encoded_size = encoded_size
        self.base_depth = base_depth

        # Standard-normal prior over the latent vector. The latent axis is
        # reinterpreted as the event dimension.
        self.prior = tfp.distributions.Independent(
            tfp.distributions.Normal(loc=tf.zeros(encoded_size), scale=1),
            reinterpreted_batch_ndims=1,
        )

        # Plain locals (not attributes), so nothing new is added to the
        # model's tracked state.
        depth = self.base_depth
        latent = self.encoded_size
        act = tf.nn.leaky_relu

        # ---- Encoder ------------------------------------------------------
        # (filters, strides) of each 5x5x5, "same"-padded Conv3D, in order.
        encoder_conv_specs = (
            (depth, 1),
            (depth, 2),
            (2 * depth, 1),
            (2 * depth, 2),
        )
        encoder_stack = [
            tf.keras.layers.InputLayer(input_shape=input_shape),
            # Cast to float32 and shift down by 0.5. This presumably centres
            # inputs in [0, 1] around zero; the input range is not checked.
            tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
        ]
        for n_filters, step in encoder_conv_specs:
            encoder_stack.append(
                tf.keras.layers.Conv3D(
                    n_filters, 5, strides=step, padding="same", activation=act
                )
            )
        # NOTE(review): the Dense below is fully connected to every
        # flattened feature. For a (20, 64, 64, 1) input that is
        # 5 * 16 * 16 * (2 * base_depth) inputs times
        # params_size(encoded_size) outputs, roughly 176M weights at the
        # defaults.
        encoder_stack += [
            tf.keras.layers.Flatten(),
            # Emits the raw parameters (loc and lower-triangular scale) in
            # the layout MultivariateNormalTriL expects.
            tf.keras.layers.Dense(
                tfp.layers.MultivariateNormalTriL.params_size(latent),
                activation=None,
            ),
            tfp.layers.MultivariateNormalTriL(
                latent,
                activity_regularizer=tfp.layers.KLDivergenceRegularizer(self.prior),
            ),
        ]
        self.encoder = tf.keras.Sequential(encoder_stack)

        # ---- Decoder ------------------------------------------------------
        # (filters, strides, padding) of each Conv3DTranspose (kernel 5x4x4).
        # Spatial size, starting from the 1x1x1 seed:
        #   (5,4,4) -> (5,8,8) -> (10,16,16) -> (10,32,32) -> (20,64,64)
        #   -> (20,64,64)
        decoder_deconv_specs = (
            (depth, 1, "valid"),
            (2 * depth, (1, 2, 2), "same"),
            (2 * depth, 2, "same"),
            (depth, (1, 2, 2), "same"),
            (depth, 2, "same"),
            (depth, 1, "same"),
        )
        decoder_stack = [
            tf.keras.layers.InputLayer(input_shape=[latent]),
            # Treat the latent vector as a 1x1x1 volume with `latent`
            # channels.
            tf.keras.layers.Reshape([1, 1, 1, latent]),
        ]
        decoder_stack.extend(
            tf.keras.layers.Conv3DTranspose(
                n_filters, (5, 4, 4), strides=step, padding=pad, activation=act
            )
            for n_filters, step, pad in decoder_deconv_specs
        )
        decoder_stack += [
            # NOTE(review): this is a Conv2D applied to the 5-D output of the
            # Conv3DTranspose stack. It presumably relies on Keras conv layers
            # treating extra leading axes as batch axes, i.e. a per-frame 2-D
            # conv. Confirm this, rather than Conv3D, is intended.
            tf.keras.layers.Conv2D(
                filters=1, kernel_size=5, strides=1, padding="same", activation=None
            ),
            tf.keras.layers.Flatten(),
            # Reshapes the flat logits to `input_shape`. Converting this
            # output distribution to a tensor yields its logits.
            tfp.layers.IndependentBernoulli(
                input_shape, tfp.distributions.Bernoulli.logits
            ),
        ]
        self.decoder = tf.keras.Sequential(decoder_stack)

        # End-to-end graph: the encoder's output (by TFP's default, a sample
        # from the posterior when used as a tensor) feeds the decoder.
        self.model = tf.keras.Model(
            inputs=self.encoder.inputs,
            outputs=self.decoder(self.encoder.outputs[0]),
        )

    def call(self, inputs):
        """Run the full encoder/decoder graph on ``inputs``.

        Returns whatever ``self.model`` produces, namely the output of the
        decoder's ``IndependentBernoulli`` layer.
        """
        return self.model(inputs)