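# Provenance (residue from the hosting site's file viewer, kept as a comment
# so the module parses): author Kurokabe, commit 3be620b ("Upload 84 files"),
# file size 4.01 kB, malware scan result: "No virus".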
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf
# Per-sample encoder input shape (batch axis excluded). VAE.get_encoder wraps
# every layer in TimeDistributed, which iterates over the leading axis (20
# steps); each step is a 64x64 map with a single channel.
input_shape = (20,) + (64, 64) + (1,)
class Sampling(keras.layers.Layer):
    """Draw ``z ~ N(z_mean, exp(z_log_var))`` via the reparameterization trick.

    Called on ``[z_mean, z_log_var]`` (two tensors of the same shape) and
    returns a tensor of that same shape,
    ``z_mean + exp(0.5 * z_log_var) * eps`` with ``eps ~ N(0, 1)``, so
    gradients flow back into both inputs.
    """

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Use the full dynamic shape so that non-batch axes whose static size
        # is unknown (None) are supported, and match z_mean's dtype so the
        # layer also works under mixed precision.
        epsilon = tf.keras.backend.random_normal(
            shape=tf.shape(z_mean), dtype=z_mean.dtype
        )
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def compute_output_shape(self, input_shape):
        # The output has the shape of z_mean, the first of the two inputs.
        return input_shape[0]
class VAE(keras.Model):
    """Variational autoencoder applied step-wise over the leading input axis.

    The encoder runs the same 2-D conv stack on every step of the leading axis
    of the module-level ``input_shape`` (via ``TimeDistributed``). It produces
    a Gaussian posterior ``(mu, logvar)`` of width ``num_embeddings`` for each
    step. The decoder maps each step's latent sample back to a 64x64x1 map
    with values in (0, 1).

    Args:
        latent_dim: Channel count of the encoder's final 1x1 conv, the feature
            map that is flattened before the mu/logvar heads.
        num_embeddings: Width of the per-step latent vector (units of the
            mu/logvar Dense heads).
        beta: Weight of the KL term in the total training loss.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, latent_dim: int = 32, num_embeddings: int = 128, beta: float = 0.5, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings
        self.beta = beta
        # The encoder must be built first: get_decoder reads its output shape.
        self.encoder = self.get_encoder()
        self.decoder = self.get_decoder()
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # Exposing the trackers here lets fit()/evaluate() reset them at the
        # start of every epoch. Without it, the reported values would be
        # running means over every batch seen since construction.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def get_encoder(self):
        """Build the encoder model: ``input_shape`` -> ``[mu, logvar, z]``.

        Each output has shape ``(batch, input_shape[0], num_embeddings)``.
        """
        encoder_inputs = keras.Input(shape=input_shape)
        x = layers.TimeDistributed(
            layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")
        )(encoder_inputs)
        x = layers.TimeDistributed(
            layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")
        )(x)
        x = layers.TimeDistributed(layers.Conv2D(self.latent_dim, 1, padding="same"))(x)
        x = layers.TimeDistributed(layers.Flatten())(x)
        mu = layers.TimeDistributed(layers.Dense(self.num_embeddings))(x)
        logvar = layers.TimeDistributed(layers.Dense(self.num_embeddings))(x)
        z = Sampling()([mu, logvar])
        return keras.Model(encoder_inputs, [mu, logvar, z], name="encoder")

    def get_decoder(self):
        """Build the decoder model mapping per-step latents to 64x64x1 maps.

        The 16x16 seed is upsampled twice by stride-2 transposed convs, giving
        64x64. The final sigmoid keeps outputs in (0, 1), which is what the
        probability-space binary cross-entropy in ``train_step`` expects.
        """
        latent_inputs = keras.Input(shape=self.encoder.output[2].shape[1:])
        x = layers.TimeDistributed(layers.Dense(16 * 16 * 32, activation="relu"))(latent_inputs)
        x = layers.TimeDistributed(layers.Reshape((16, 16, 32)))(x)
        x = layers.TimeDistributed(
            layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")
        )(x)
        x = layers.TimeDistributed(
            layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")
        )(x)
        decoder_outputs = layers.TimeDistributed(
            layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")
        )(x)
        return keras.Model(latent_inputs, decoder_outputs, name="decoder")

    @staticmethod
    def _sum_over_non_batch_axes(t):
        """Sum ``t`` over every axis except the leading batch axis."""
        return tf.reduce_sum(t, axis=tf.range(1, tf.rank(t)))

    def train_step(self, data):
        """Run one optimisation step on a batch of ``(x, y)`` pairs.

        The loss is the per-sample binary cross-entropy between ``y`` and the
        decoded latent sample, plus ``beta`` times the per-sample KL divergence
        from ``N(mu, exp(logvar))`` to ``N(0, I)``. Both terms are summed over
        all non-batch axes and then averaged over the batch.
        """
        x, y = data
        with tf.GradientTape() as tape:
            mu, logvar, z = self.encoder(x)
            reconstruction = self.decoder(z)
            # binary_crossentropy averages the trailing (channel) axis; the
            # remaining time/height/width axes are summed per sample.
            reconstruction_loss = tf.reduce_mean(
                self._sum_over_non_batch_axes(
                    tf.keras.losses.binary_crossentropy(y, reconstruction)
                )
            )
            kl_loss = -0.5 * (1 + logvar - tf.square(mu) - tf.exp(logvar))
            # Sum over both the time and embedding axes, not just time.
            kl_loss = tf.reduce_mean(self._sum_over_non_batch_axes(kl_loss))
            total_loss = reconstruction_loss + self.beta * kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def call(self, inputs, training=False, mask=None):
        """Encode ``inputs``, sample a latent, and return its decoding.

        The latent is sampled rather than taken as the mean, so repeated calls
        on the same input can return different reconstructions.
        """
        _, _, z = self.encoder(inputs)
        return self.decoder(z)