import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
    Activation,
    BatchNormalization,
    Conv2D,
    Conv2DTranspose,
    Dense,
    LeakyReLU,
    LSTMCell,
)
from tensorflow.keras.losses import Loss

# from tensorflow_probability.python.layers.dense_variational import (
#     DenseReparameterization,
# )
# import tensorflow_probability as tfp


class KLCriterion(Loss):
    """KL( N(mu_1, sigma2_1) || N(mu_2, sigma2_2) ) between two diagonal Gaussians."""

    def call(self, y_true, y_pred):
        (mu1, logvar1), (mu2, logvar2) = y_true, y_pred

        sigma1 = tf.exp(tf.math.multiply(logvar1, 0.5))
        sigma2 = tf.exp(tf.math.multiply(logvar2, 0.5))

        kld = (
            tf.math.log(sigma2 / sigma1)
            + (tf.exp(logvar1) + tf.square(mu1 - mu2)) / (2 * tf.exp(logvar2))
            - 0.5
        )
        # Note: the divisor 22 is a hard-coded normalization constant carried
        # over from the original implementation (presumably its batch size).
        return tf.reduce_sum(kld) / 22


class Encoder(Model):
    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        self.c1 = Sequential(
            [
                Conv2D(64, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c2 = Sequential(
            [
                Conv2D(128, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c3 = Sequential(
            [
                Conv2D(256, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c4 = Sequential(
            [
                Conv2D(512, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c5 = Sequential(
            [
                Conv2D(self.dim, kernel_size=4, strides=1, padding="valid"),
                BatchNormalization(),
                Activation("tanh"),
            ]
        )

    def call(self, x):
        h1 = self.c1(x)
        h2 = self.c2(h1)
        h3 = self.c3(h2)
        h4 = self.c4(h3)
        h5 = self.c5(h4)
        # Return the flattened latent vector plus all intermediate feature
        # maps, which the decoder consumes as skip connections.
        return tf.reshape(h5, (-1, self.dim)), [h1, h2, h3, h4, h5]


class Decoder(Model):
    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        self.upc1 = Sequential(
            [
                Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc2 = Sequential(
            [
                Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc3 = Sequential(
            [
                Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc4 = Sequential(
            [
                Conv2DTranspose(64, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc5 = Sequential(
            [
                Conv2DTranspose(1, kernel_size=4, strides=2, padding="same"),
                Activation("sigmoid"),
            ]
        )

    def call(self, inputs):
        vec, skip = inputs
        d1 = self.upc1(tf.reshape(vec, (-1, 1, 1, self.dim)))
        d2 = self.upc2(tf.concat([d1, skip[3]], axis=-1))
        d3 = self.upc3(tf.concat([d2, skip[2]], axis=-1))
        d4 = self.upc4(tf.concat([d3, skip[1]], axis=-1))
        output = self.upc5(tf.concat([d4, skip[0]], axis=-1))
        return output

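# --- Usage sketch (illustrative, not part of the original model) ----------
# A minimal shape check for the encoder/decoder pair. The conv stack (four
# stride-2 convolutions followed by a 4x4 "valid" convolution) implies
# 64x64 single-channel frames; that input resolution is an assumption here.
def _demo_encoder_decoder_shapes():
    encoder = Encoder(dim=128)
    decoder = Decoder(dim=128)
    frames = tf.random.uniform([4, 64, 64, 1])  # (batch, height, width, channels)
    latent, skips = encoder(frames)             # latent: (4, 128); skips: 5 feature maps
    recon = decoder([latent, skips])            # (4, 64, 64, 1), sigmoid output
    return latent.shape, recon.shape
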
class MyLSTM(Model):
    def __init__(self, input_shape, hidden_size, output_size, n_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embed = Dense(hidden_size, input_dim=input_shape)
        # self.lstm = Sequential(
        #     [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm"
        # )
        # self.lstm = self.create_lstm(hidden_size, n_layers)
        # Note: a single LSTMCell is shared across the n_layers loop in call(),
        # rather than one cell per layer as the commented-out variant intended.
        self.lstm = LSTMCell(hidden_size)
        self.out = Dense(output_size)

    def init_hidden(self, batch_size):
        hidden = []
        for i in range(self.n_layers):
            hidden.append(
                (
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                )
            )
        # Assign via __dict__ so Keras does not track the recurrent state
        # tensors as model weights.
        self.__dict__["hidden"] = hidden

    def build(self, input_shape):
        self.init_hidden(input_shape[0])

    def call(self, inputs):
        h_in = self.embed(inputs)
        for i in range(self.n_layers):
            _, self.hidden[i] = self.lstm(h_in, self.hidden[i])
            h_in = self.hidden[i][0]
        return self.out(h_in)


class MyGaussianLSTM(Model):
    def __init__(self, input_shape, hidden_size, output_size, n_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embed = Dense(hidden_size, input_dim=input_shape)
        # self.lstm = Sequential(
        #     [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm"
        # )
        self.lstm = LSTMCell(hidden_size)
        self.mu_net = Dense(output_size)
        self.logvar_net = Dense(output_size)
        # self.out = Sequential(
        #     [
        #         tf.keras.layers.Dense(
        #             tfp.layers.MultivariateNormalTriL.params_size(output_size),
        #             activation=None,
        #         ),
        #         tfp.layers.MultivariateNormalTriL(output_size),
        #     ]
        # )

    def reparameterize(self, mu, logvar: tf.Tensor):
        # Reparameterization trick: z = mu + sigma * eps with sigma = exp(logvar / 2).
        sigma = tf.math.exp(logvar * 0.5)
        eps = tf.random.normal(tf.shape(sigma))
        return tf.add(tf.math.multiply(eps, sigma), mu)

    def init_hidden(self, batch_size):
        hidden = []
        for i in range(self.n_layers):
            hidden.append(
                (
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                )
            )
        self.__dict__["hidden"] = hidden

    def build(self, input_shape):
        self.init_hidden(input_shape[0])

    def call(self, inputs):
        h_in = self.embed(inputs)
        for i in range(self.n_layers):
            _, self.hidden[i] = self.lstm(h_in, self.hidden[i])
            h_in = self.hidden[i][0]
        mu = self.mu_net(h_in)
        logvar = self.logvar_net(h_in)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

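# --- Usage sketch (illustrative, not part of the original model) ----------
# How the recurrent modules are stepped one frame at a time: the hidden state
# is created in `build` and kept on the object, so repeated calls advance the
# same state. The sizes below are assumptions made for this demo only.
def _demo_gaussian_lstm_step():
    posterior = MyGaussianLSTM(
        input_shape=128, hidden_size=256, output_size=10, n_layers=1
    )
    frame_features = tf.random.normal([4, 128])  # encoded features for one frame
    z, mu, logvar = posterior(frame_features)    # z sampled via the reparameterization trick
    return z.shape, mu.shape, logvar.shape       # each (4, 10)
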
class P2P(Model):
    def __init__(
        self,
        channels: int = 1,
        g_dim: int = 128,
        z_dim: int = 10,
        rnn_size: int = 256,
        prior_rnn_layers: int = 1,
        posterior_rnn_layers: int = 1,
        predictor_rnn_layers: int = 1,
        skip_prob: float = 0.5,
        n_past: int = 1,
        last_frame_skip: bool = False,
        beta: float = 0.0001,
        weight_align: float = 0.1,
        weight_cpc: float = 100,
    ):
        super().__init__()
        self.channels = channels
        self.g_dim = g_dim
        self.z_dim = z_dim
        self.rnn_size = rnn_size
        self.prior_rnn_layers = prior_rnn_layers
        self.posterior_rnn_layers = posterior_rnn_layers
        self.predictor_rnn_layers = predictor_rnn_layers
        self.skip_prob = skip_prob
        self.n_past = n_past
        self.last_frame_skip = last_frame_skip
        self.beta = beta
        self.weight_align = weight_align
        self.weight_cpc = weight_cpc

        self.frame_predictor = MyLSTM(
            self.g_dim + self.z_dim + 1 + 1,
            self.rnn_size,
            self.g_dim,
            self.predictor_rnn_layers,
        )
        self.prior = MyGaussianLSTM(
            self.g_dim + self.g_dim + 1 + 1,
            self.rnn_size,
            self.z_dim,
            self.prior_rnn_layers,
        )
        self.posterior = MyGaussianLSTM(
            self.g_dim + self.g_dim + 1 + 1,
            self.rnn_size,
            self.z_dim,
            self.posterior_rnn_layers,
        )
        self.encoder = Encoder(self.g_dim, self.channels)
        self.decoder = Decoder(self.g_dim, self.channels)

        # criterions
        self.mse_criterion = tf.keras.losses.MeanSquaredError()
        self.kl_criterion = KLCriterion()
        self.align_criterion = tf.keras.losses.MeanSquaredError()

        # optimizers
        self.frame_predictor_optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0001  # , beta_1=0.9, beta_2=0.999, epsilon=1e-8
        )
        self.posterior_optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0001  # , beta_1=0.9, beta_2=0.999, epsilon=1e-8
        )
        self.prior_optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0001  # , beta_1=0.9, beta_2=0.999, epsilon=1e-8
        )
        self.encoder_optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0001  # , beta_1=0.9, beta_2=0.999, epsilon=1e-8
        )
        self.decoder_optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0001  # , beta_1=0.9, beta_2=0.999, epsilon=1e-8
        )

    def get_global_descriptor(self, x, start_ix=0, cp_ix=None):
        """Encode the control-point frame x[:, cp_ix] into the global descriptor h_cp."""
        if cp_ix is None:
            cp_ix = x.shape[1] - 1

        x_cp = x[:, cp_ix, ...]
        h_cp = self.encoder(x_cp)[0]  # index 0: latent vector; index 1: skip features

        return x_cp, h_cp

    def call(self, x, start_ix=0, cp_ix=-1):
        batch_size = x.shape[0]

        with tf.GradientTape(persistent=True) as tape:
            mse_loss = 0
            kld_loss = 0
            cpc_loss = 0
            align_loss = 0

            seq_len = x.shape[1]
            start_ix = 0
            cp_ix = seq_len - 1
            x_cp, global_z = self.get_global_descriptor(
                x, start_ix, cp_ix
            )  # here global_z is h_cp

            skip_prob = self.skip_prob

            prev_i = 0
            max_skip_count = seq_len * skip_prob
            skip_count = 0
            probs = np.random.uniform(low=0, high=1, size=seq_len - 1)

            for i in range(1, seq_len):
                # Randomly skip frames (up to max_skip_count), but never the
                # first generated frame or the control-point frame.
                if (
                    probs[i - 1] <= skip_prob
                    and i >= self.n_past
                    and skip_count < max_skip_count
                    and i != 1
                    and i != cp_ix
                ):
                    skip_count += 1
                    continue

                time_until_cp = tf.fill([batch_size, 1], (cp_ix - i + 1) / cp_ix)
                delta_time = tf.fill([batch_size, 1], (i - prev_i) / cp_ix)
                prev_i = i

                h = self.encoder(x[:, i - 1, ...])
                h_target = self.encoder(x[:, i, ...])[0]

                if self.last_frame_skip or i <= self.n_past:
                    h, skip = h
                else:
                    h = h[0]

                # Control Point Aware
                h_cpaw = tf.concat([h, global_z, time_until_cp, delta_time], axis=1)
                h_target_cpaw = tf.concat(
                    [h_target, global_z, time_until_cp, delta_time], axis=1
                )

                zt, mu, logvar = self.posterior(h_target_cpaw)
                zt_p, mu_p, logvar_p = self.prior(h_cpaw)

                concat = tf.concat([h, zt, time_until_cp, delta_time], axis=1)
                h_pred = self.frame_predictor(concat)
                x_pred = self.decoder([h_pred, skip])

                if i == cp_ix:  # the generated cp-frame should be exactly x_cp
                    h_pred_p = self.frame_predictor(
                        tf.concat([h, zt_p, time_until_cp, delta_time], axis=1)
                    )
                    x_pred_p = self.decoder([h_pred_p, skip])
                    cpc_loss = self.mse_criterion(x_pred_p, x_cp)

                if i > 1:
                    # Align the predicted features with the encoded features
                    # (h is already unpacked to a tensor above).
                    align_loss += self.align_criterion(h, h_pred)

                mse_loss += self.mse_criterion(x_pred, x[:, i, ...])
                kld_loss += self.kl_criterion((mu, logvar), (mu_p, logvar_p))

            # backward
            loss = mse_loss + kld_loss * self.beta + align_loss * self.weight_align
            prior_loss = kld_loss + cpc_loss * self.weight_cpc

        var_list_frame_predictor = self.frame_predictor.trainable_variables
        var_list_posterior = self.posterior.trainable_variables
        var_list_prior = self.prior.trainable_variables
        var_list_encoder = self.encoder.trainable_variables
        var_list_decoder = self.decoder.trainable_variables

        # mse: frame_predictor + decoder
        # align: frame_predictor + encoder
        # kld: posterior + prior + encoder
        var_list_without_prior = (
            var_list_frame_predictor
            + var_list_posterior
            + var_list_encoder
            + var_list_decoder
        )

        gradients_without_prior = tape.gradient(loss, var_list_without_prior)
        gradients_prior = tape.gradient(prior_loss, var_list_prior)

        self.update_model_without_prior(
            gradients_without_prior, var_list_without_prior
        )
        self.update_prior(gradients_prior, var_list_prior)

        del tape

        return (
            mse_loss / seq_len,
            kld_loss / seq_len,
            cpc_loss / seq_len,
            align_loss / seq_len,
        )

    def p2p_generate(
        self,
        x,
        len_output,
        eval_cp_ix,
        start_ix=0,
        cp_ix=-1,
        model_mode="full",
        skip_frame=False,
        init_hidden=True,
    ):
        batch_size, num_frames, h, w, channels = x.shape
        dim_shape = (h, w, channels)

        gen_seq = [x[:, 0, ...]]
        x_in = x[:, 0, ...]

        seq_len = x.shape[1]
        cp_ix = seq_len - 1
        x_cp, global_z = self.get_global_descriptor(
            x, cp_ix=cp_ix
        )  # here global_z is h_cp

        skip_prob = self.skip_prob

        prev_i = 0
        max_skip_count = seq_len * skip_prob
        skip_count = 0
        probs = np.random.uniform(0, 1, len_output - 1)

        for i in range(1, len_output):
            if (
                probs[i - 1] <= skip_prob
                and i >= self.n_past
                and skip_count < max_skip_count
                and i != 1
                and i != (len_output - 1)
                and skip_frame
            ):
                skip_count += 1
                gen_seq.append(tf.zeros_like(x_in))
                continue

            time_until_cp = tf.fill(
                [batch_size, 1], (eval_cp_ix - i + 1) / eval_cp_ix
            )
            delta_time = tf.fill([batch_size, 1], (i - prev_i) / eval_cp_ix)
            prev_i = i

            h = self.encoder(x_in)

            if self.last_frame_skip or i == 1 or i < self.n_past:
                h, skip = h
            else:
                h, _ = h

            h_cpaw = tf.concat([h, global_z, time_until_cp, delta_time], axis=1)

            if i < self.n_past:
                # Conditioning frames: feed ground-truth frames through the
                # networks to warm up the recurrent states.
                h_target = self.encoder(x[:, i, ...])[0]
                h_target_cpaw = tf.concat(
                    [h_target, global_z, time_until_cp, delta_time], axis=1
                )

                zt, _, _ = self.posterior(h_target_cpaw)
                zt_p, _, _ = self.prior(h_cpaw)

                if model_mode == "posterior" or model_mode == "full":
                    self.frame_predictor(
                        tf.concat([h, zt, time_until_cp, delta_time], axis=1)
                    )
                elif model_mode == "prior":
                    self.frame_predictor(
                        tf.concat([h, zt_p, time_until_cp, delta_time], axis=1)
                    )

                x_in = x[:, i, ...]
                gen_seq.append(x_in)
            else:
                # Generation: sample z from the chosen model and decode the
                # predicted features into the next frame.
                if i < num_frames:
                    h_target = self.encoder(x[:, i, ...])[0]
                    h_target_cpaw = tf.concat(
                        [h_target, global_z, time_until_cp, delta_time], axis=1
                    )
                else:
                    h_target_cpaw = h_cpaw

                zt, _, _ = self.posterior(h_target_cpaw)
                zt_p, _, _ = self.prior(h_cpaw)

                if model_mode == "posterior":
                    h = self.frame_predictor(
                        tf.concat([h, zt, time_until_cp, delta_time], axis=1)
                    )
                elif model_mode == "prior" or model_mode == "full":
                    h = self.frame_predictor(
                        tf.concat([h, zt_p, time_until_cp, delta_time], axis=1)
                    )

                x_in = self.decoder([h, skip])
                gen_seq.append(x_in)

        return tf.stack(gen_seq, axis=1)

    def update_model_without_prior(self, gradients, var_list):
        """Apply each sub-model's slice of the combined gradients with its own optimizer.

        `gradients` and `var_list` are expected in the order they are
        concatenated in `call`: frame_predictor, posterior, encoder, decoder,
        so each optimizer updates only its own sub-model's variables.
        """
        sub_models = [
            (self.frame_predictor_optimizer, self.frame_predictor),
            (self.posterior_optimizer, self.posterior),
            (self.encoder_optimizer, self.encoder),
            (self.decoder_optimizer, self.decoder),
        ]
        start = 0
        for optimizer, sub_model in sub_models:
            end = start + len(sub_model.trainable_variables)
            optimizer.apply_gradients(zip(gradients[start:end], var_list[start:end]))
            start = end

    def update_prior(self, gradients, var_list):
        self.prior_optimizer.apply_gradients(zip(gradients, var_list))
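
# --- Smoke test (illustrative, not part of the original model) ------------
# A hedged end-to-end sketch: one optimization step on random 64x64 grayscale
# clips, then conditional generation. Sequence length, batch size and the
# 64x64 resolution are assumptions made for this demo only.
if __name__ == "__main__":
    model = P2P(channels=1)
    clips = tf.random.uniform([2, 5, 64, 64, 1])  # (batch, time, H, W, C)

    # One training step; returns per-sequence averaged loss terms.
    mse, kld, cpc, align = model(clips)
    print("mse", float(mse), "kld", float(kld), "cpc", float(cpc), "align", float(align))

    # Generate 8 frames conditioned on the first frame and the control point.
    generated = model.p2p_generate(clips, len_output=8, eval_cp_ix=7)
    print("generated sequence shape:", generated.shape)  # (2, 8, 64, 64, 1)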