import torch.nn as nn
from torch.distributions.normal import Normal

from .constants import *


class Encoder(nn.Module):
    '''
    Encoder Class
    Maps an input image to the parameters (mean, stddev) of a diagonal
    Gaussian distribution over the latent space.
    Values:
        im_chan: the number of channels of the input image, a scalar
        output_chan: the dimension of the latent vector z, a scalar
        hidden_dim: the inner dimension, a scalar
    '''

    def __init__(self, im_chan=3, output_chan=Z_DIM, hidden_dim=ENC_HIDDEN_DIM):
        super().__init__()
        self.z_dim = output_chan
        # The final block outputs 2 * z_dim channels: the first z_dim are the
        # mean and the last z_dim the (log-)stddev of the latent distribution.
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, hidden_dim * 4),
            self.make_disc_block(hidden_dim * 4, hidden_dim * 8),
            self.make_disc_block(hidden_dim * 8, output_chan * 2, final_layer=True),
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to an
        encoder block of the VAE: a convolution, a batchnorm (except for in
        the last layer), and an activation (except for in the last layer).
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: whether we're on the final layer (affects activation and batchnorm)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        # Final layer: raw conv output; the split into mean/stddev happens
        # in forward, so no normalization or activation is applied here.
        return nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size, stride),
        )

    def forward(self, image):
        '''
        Function for completing a forward pass of the Encoder: Given an image
        tensor, returns the mean and standard deviation of the corresponding
        latent distribution.
        Parameters:
            image: an image tensor with dimensions (batch_size, im_chan, im_height, im_width)
        Returns:
            mean: tensor of shape (batch_size, z_dim)
            stddev: tensor of shape (batch_size, z_dim), strictly positive
        '''
        disc_pred = self.disc(image)
        encoding = disc_pred.view(len(disc_pred), -1)
        # The second half of the encoding is treated as the log of the
        # standard deviation: exponentiating guarantees a positive stddev
        # and is numerically stable to learn.
        return encoding[:, :self.z_dim], encoding[:, self.z_dim:].exp()


class Decoder(nn.Module):
    '''
    Decoder Class
    Maps a latent vector z back to an image.
    Values:
        z_dim: the dimension of the noise vector, a scalar
        im_chan: the number of channels of the output image, a scalar
        hidden_dim: the inner dimension, a scalar
    '''

    def __init__(self, z_dim=Z_DIM, im_chan=3, hidden_dim=DEC_HIDDEN_DIM):
        super().__init__()
        self.z_dim = z_dim
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 16),
            self.make_gen_block(hidden_dim * 16, hidden_dim * 8, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 8, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4),
            self.make_gen_block(hidden_dim * 2, hidden_dim, kernel_size=4),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a
        Decoder block of the VAE: a transposed convolution, a batchnorm
        (except for in the last layer), and an activation.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: whether we're on the final layer (affects activation and batchnorm)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        # Final layer: Sigmoid squashes outputs into [0, 1] to match
        # image pixel intensities.
        return nn.Sequential(
            nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
            nn.Sigmoid(),
        )

    def forward(self, noise):
        '''
        Function for completing a forward pass of the Decoder: Given a noise
        vector, returns a generated image.
        Parameters:
            noise: a noise tensor with dimensions (batch_size, z_dim)
        '''
        # Reshape the flat latent vector to (batch, z_dim, 1, 1) so the
        # transposed convolutions can upsample it into an image.
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)


class VAE(nn.Module):
    '''
    VAE Class
    Composes the Encoder and Decoder into a full variational autoencoder.
    Values:
        z_dim: the dimension of the noise vector, a scalar
        im_chan: the number of channels of the output image, a scalar
            (defaults to 3 for RGB images)
    '''

    def __init__(self, z_dim=Z_DIM, im_chan=3):
        super().__init__()
        self.z_dim = z_dim
        self.encode = Encoder(im_chan, z_dim)
        self.decode = Decoder(z_dim, im_chan)

    def forward(self, images):
        '''
        Function for completing a forward pass of the VAE: Given an image
        tensor, encodes it into a latent distribution, samples from that
        distribution with the reparameterization trick, and decodes the
        sample back into an image.
        Parameters:
            images: an image tensor with dimensions (batch_size, im_chan, im_height, im_width)
        Returns:
            decoding: the autoencoded image
            q_dist: the z-distribution of the encoding
        '''
        q_mean, q_stddev = self.encode(images)
        q_dist = Normal(q_mean, q_stddev)
        # rsample() uses the reparameterization trick, so gradients can
        # flow through the sampling step back into the encoder.
        z_sample = q_dist.rsample()
        decoding = self.decode(z_sample)
        return decoding, q_dist