import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import (
    LSTM,
    Activation,
    BatchNormalization,
    Conv2D,
    Conv2DTranspose,
    Dense,
    Flatten,
    Input,
    Layer,
    LeakyReLU,
    Reshape,
    TimeDistributed,
)

SEQ_LEN = 20


class Sampling(Layer):
    """Uses (z_mean, z_log_var) to sample z, the latent vector."""

    def call(self, inputs):
        # Reparameterization trick: z = mu + sigma * epsilon. Drawing epsilon
        # with the full shape of z_mean lets this layer handle both (batch,
        # dim) and (batch, time, dim) inputs, so it does not need to be
        # wrapped in TimeDistributed (which does not accept a list of inputs).
        z_mean, z_log_var = inputs
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def compute_output_shape(self, input_shape):
        return input_shape[0]


class P2P(Model):
    def __init__(
        self,
        channels: int = 1,
        g_dim: int = 128,
        z_dim: int = 10,
        rnn_size: int = 256,
        prior_rnn_layers: int = 1,
        posterior_rnn_layers: int = 1,
        predictor_rnn_layers: int = 1,
        skip_prob: float = 0.1,
        n_past: int = 1,
        last_frame_skip: bool = False,
        beta: float = 0.0001,
        weight_align: float = 0.1,
        weight_cpc: float = 100.0,
    ):
        super().__init__()
        # Model parameters
        self.channels = channels
        self.g_dim = g_dim
        self.z_dim = z_dim
        self.rnn_size = rnn_size
        self.prior_rnn_layers = prior_rnn_layers
        self.posterior_rnn_layers = posterior_rnn_layers
        self.predictor_rnn_layers = predictor_rnn_layers

        # Training parameters
        self.skip_prob = skip_prob
        self.n_past = n_past
        self.last_frame_skip = last_frame_skip
        self.beta = beta
        self.weight_align = weight_align
        self.weight_cpc = weight_cpc

        # NOTE: frame_predictor, prior, and the skip/alignment/CPC weights are
        # built and stored here but not yet wired into train_step below.
        self.frame_predictor = self.build_lstm()
        self.prior = self.build_gaussian_lstm()
        self.posterior = self.build_gaussian_lstm()
        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()

        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    # region Model building
    def build_lstm(self):
        inputs = Input(shape=(SEQ_LEN, self.g_dim + self.z_dim + 1))
        embed = TimeDistributed(Dense(self.rnn_size))(inputs)
        lstm = LSTM(self.rnn_size, return_sequences=True)(embed)
        output = TimeDistributed(Dense(self.g_dim))(lstm)
        return Model(inputs=inputs, outputs=output, name="frame_predictor")

    def build_gaussian_lstm(self):
        inputs = Input(shape=(SEQ_LEN, self.g_dim))
        embed = TimeDistributed(Dense(self.rnn_size))(inputs)
        lstm = LSTM(self.rnn_size, return_sequences=True)(embed)
        mu = TimeDistributed(Dense(self.z_dim))(lstm)
        logvar = TimeDistributed(Dense(self.z_dim))(lstm)
        z = Sampling()([mu, logvar])
        return Model(inputs=inputs, outputs=[mu, logvar, z])

    def build_encoder(self):
        # Two conditioning frames in, shape (2, 64, 64, channels).
        inputs = Input(shape=(2, 64, 64, self.channels))
        h = TimeDistributed(Conv2D(64, kernel_size=4, strides=2, padding="same"))(
            inputs
        )
        h = BatchNormalization()(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = TimeDistributed(Conv2D(128, kernel_size=4, strides=2, padding="same"))(h)
        h = BatchNormalization()(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = TimeDistributed(Conv2D(256, kernel_size=4, strides=2, padding="same"))(h)
        h = BatchNormalization()(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = TimeDistributed(Conv2D(512, kernel_size=4, strides=2, padding="same"))(h)
        h = BatchNormalization()(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Flatten()(h)
        # Alternative per-frame VAE head (unused):
        # mu = Dense(self.g_dim)(h)
        # logvar = Dense(self.g_dim)(h)
        # z = Sampling()([mu, logvar])

        # Map the flattened conditioning frames to a SEQ_LEN-step feature
        # sequence and run the posterior LSTM over it.
        lstm_input = Dense(self.g_dim * SEQ_LEN)(h)
        lstm_input = Reshape((SEQ_LEN, self.g_dim))(lstm_input)
        mu, logvar, z = self.posterior(lstm_input)
        return Model(inputs=inputs, outputs=[mu, logvar, z], name="encoder")

    def build_decoder(self):
        latent_inputs = Input(shape=(SEQ_LEN, self.z_dim))
        # Project each timestep's latent vector to a 1x1x512 feature map.
        x = Dense(1 * 1 * 512, activation="relu")(latent_inputs)
        x = Reshape((SEQ_LEN, 1, 1, 512))(x)
        x = TimeDistributed(
            Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid")
        )(x)  # 1x1 -> 4x4
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = TimeDistributed(
            Conv2DTranspose(256, kernel_size=4, strides=2, padding="same")
        )(x)  # 4x4 -> 8x8
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = TimeDistributed(
            Conv2DTranspose(128, kernel_size=4, strides=2, padding="same")
        )(x)  # 8x8 -> 16x16
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = TimeDistributed(
            Conv2DTranspose(64, kernel_size=4, strides=2, padding="same")
        )(x)  # 16x16 -> 32x32
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = TimeDistributed(
            Conv2DTranspose(self.channels, kernel_size=4, strides=2, padding="same")
        )(x)  # 32x32 -> 64x64
        x = Activation("sigmoid")(x)
        return Model(inputs=latent_inputs, outputs=x, name="decoder")

    # endregion

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs, training=None, mask=None):
        z_mean, z_log_var, z = self.encoder(inputs)
        return self.decoder(z)

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(x)
            reconstruction = self.decoder(z)
            # Per-pixel BCE summed over the time and spatial axes, then
            # averaged over the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(y, reconstruction),
                    axis=(1, 2, 3),
                )
            )
            # KL divergence summed over the time and latent axes, then
            # averaged over the batch.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=(1, 2)))
            total_loss = reconstruction_loss + self.beta * kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def test_step(self, data):
        # Mirrors train_step so validation losses are computed on the same
        # scale. (The original scaled the reconstruction loss by 28 * 28, a
        # leftover from the Keras MNIST VAE example; frames here are 64x64.)
        x, y = data
        z_mean, z_log_var, z = self.encoder(x)
        reconstruction = self.decoder(z)
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.keras.losses.binary_crossentropy(y, reconstruction),
                axis=(1, 2, 3),
            )
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=(1, 2)))
        total_loss = reconstruction_loss + self.beta * kl_loss
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
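

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test that
# compiles the model and fits it on random data shaped to match the encoder
# input (two conditioning frames) and the decoder target (a SEQ_LEN-frame
# sequence). The batch size, learning rate, and epoch count below are
# illustrative assumptions, not values from the source.
if __name__ == "__main__":
    model = P2P()
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3))

    x = np.random.rand(4, 2, 64, 64, 1).astype("float32")  # start/end frames
    y = np.random.rand(4, SEQ_LEN, 64, 64, 1).astype("float32")  # target sequence
    model.fit(x, y, epochs=1, batch_size=2, validation_data=(x, y))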