GANime / ganime /model /moving_vae.py
Kurokabe's picture
Upload 84 files
3be620b
raw
history blame
No virus
4.52 kB
from tensorflow.keras import Model
import tensorflow as tf
import tensorflow_probability as tfp
class MovingVAE(Model):
def __init__(self, input_shape, encoded_size=64, base_depth=32):
super().__init__()
self.encoded_size = encoded_size
self.base_depth = base_depth
self.prior = tfp.distributions.Independent(
tfp.distributions.Normal(loc=tf.zeros(encoded_size), scale=1),
reinterpreted_batch_ndims=1,
)
self.encoder = tf.keras.Sequential(
[
tf.keras.layers.InputLayer(input_shape=input_shape),
tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
tf.keras.layers.Conv3D(
self.base_depth,
5,
strides=1,
padding="same",
activation=tf.nn.leaky_relu,
),
tf.keras.layers.Conv3D(
self.base_depth,
5,
strides=2,
padding="same",
activation=tf.nn.leaky_relu,
),
tf.keras.layers.Conv3D(
2 * self.base_depth,
5,
strides=1,
padding="same",
activation=tf.nn.leaky_relu,
),
tf.keras.layers.Conv3D(
2 * self.base_depth,
5,
strides=2,
padding="same",
activation=tf.nn.leaky_relu,
),
# tf.keras.layers.Conv3D(4 * encoded_size, 7, strides=1,
# padding='valid', activation=tf.nn.leaky_relu),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(
tfp.layers.MultivariateNormalTriL.params_size(self.encoded_size),
activation=None,
),
tfp.layers.MultivariateNormalTriL(
self.encoded_size,
activity_regularizer=tfp.layers.KLDivergenceRegularizer(self.prior),
),
]
)
self.decoder = tf.keras.Sequential(
[
tf.keras.layers.InputLayer(input_shape=[self.encoded_size]),
tf.keras.layers.Reshape([1, 1, 1, self.encoded_size]),
tf.keras.layers.Conv3DTranspose(
self.base_depth,
(5, 4, 4),
strides=1,
padding="valid",
activation=tf.nn.leaky_relu,
),
tf.keras.layers.Conv3DTranspose(
2 * self.base_depth,
(5, 4, 4),
strides=(1, 2, 2),
padding="same",
activation=tf.nn.leaky_relu,
),
tf.keras.layers.Conv3DTranspose(
2 * self.base_depth,
(5, 4, 4),
strides=2,
padding="same",
activation=tf.nn.leaky_relu,
),
tf.keras.layers.Conv3DTranspose(
self.base_depth,
(5, 4, 4),
strides=(1, 2, 2),
padding="same",
activation=tf.nn.leaky_relu,
),
tf.keras.layers.Conv3DTranspose(
self.base_depth,
(5, 4, 4),
strides=2,
padding="same",
activation=tf.nn.leaky_relu,
),
tf.keras.layers.Conv3DTranspose(
self.base_depth,
(5, 4, 4),
strides=1,
padding="same",
activation=tf.nn.leaky_relu,
),
tf.keras.layers.Conv2D(
filters=1, kernel_size=5, strides=1, padding="same", activation=None
),
tf.keras.layers.Flatten(),
tfp.layers.IndependentBernoulli(
input_shape, tfp.distributions.Bernoulli.logits
),
]
)
self.model = tf.keras.Model(
inputs=self.encoder.inputs, outputs=self.decoder(self.encoder.outputs[0])
)
def call(self, inputs):
return self.model(inputs)