Kurokabe's picture
Upload 84 files
3be620b
raw
history blame
No virus
4.65 kB
from typing import List
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow_addons.layers import GroupNormalization
from .layers import ResnetBlock, AttentionBlock, Downsample
# NOTE(review): decorator left disabled as in the original; enabling it would
# register the class globally for Keras serialization at import time.
# @tf.keras.utils.register_keras_serializable()
class Encoder(layers.Layer):
    """Convolutional encoder that maps an image to a latent feature map.

    The encoder is a stack of ``len(channels_multiplier)`` resolution levels.
    Each level applies ``num_res_blocks`` ResnetBlocks (optionally followed by
    an AttentionBlock when the current resolution is listed in
    ``attention_resolution``) and then halves the spatial resolution with a
    Downsample layer (except at the last level). A ResnetBlock/Attention/
    ResnetBlock "middle" section, group normalization, a swish activation and
    a final 3x3 convolution produce the ``z_channels``-channel output.
    """

    def __init__(
        self,
        *,
        channels: int,
        channels_multiplier: List[int],
        num_res_blocks: int,
        attention_resolution: List[int],
        resolution: int,
        z_channels: int,
        dropout: float,
        **kwargs
    ):
        """Build the encoder layer stack.

        The encoder is constituted of multiple levels (one per entry of
        `channels_multiplier`), each containing `num_res_blocks` ResnetBlocks.

        Args:
            channels (int): Number of channels produced by the first convolution.
            channels_multiplier (List[int]): Channel multiplier for each level
                (channels at level i = `channels` * `channels_multiplier[i]`).
            num_res_blocks (int): Number of ResnetBlocks at each level.
            attention_resolution (List[int]): An attention block is added at a
                level whenever the current spatial resolution is in this list.
            resolution (int): Spatial resolution of the input images.
            z_channels (int): Number of channels of the encoder output.
            dropout (float): Dropout ratio used inside each ResnetBlock.
            **kwargs: Forwarded to `keras.layers.Layer` (e.g. `name`, `dtype`).
        """
        super().__init__(**kwargs)
        # Keep every constructor argument so get_config() can serialize the layer.
        self.channels = channels
        self.channels_multiplier = channels_multiplier
        self.num_resolutions = len(channels_multiplier)
        self.num_res_blocks = num_res_blocks
        self.attention_resolution = attention_resolution
        self.resolution = resolution
        self.z_channels = z_channels
        self.dropout = dropout

        # Stem: project the input image to `channels` feature maps.
        self.conv_in = layers.Conv2D(
            self.channels, kernel_size=3, strides=1, padding="same"
        )

        current_resolution = resolution
        # Input multiplier of level i is the output multiplier of level i-1
        # (1 for the first level, fed by conv_in).
        in_channels_multiplier = (1,) + tuple(channels_multiplier)
        self.downsampling_list = []
        for i_level in range(self.num_resolutions):
            block_in = channels * in_channels_multiplier[i_level]
            block_out = channels * channels_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                self.downsampling_list.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                # Attend at this scale if the current resolution was requested.
                if current_resolution in attention_resolution:
                    self.downsampling_list.append(AttentionBlock(block_in))
            # Halve the resolution between levels (not after the last one).
            if i_level != self.num_resolutions - 1:
                self.downsampling_list.append(Downsample(block_in))
                current_resolution = current_resolution // 2

        # Middle: ResnetBlock -> Attention -> ResnetBlock at the lowest resolution.
        self.mid = {}
        self.mid["block_1"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        self.mid["attn_1"] = AttentionBlock(block_in)
        self.mid["block_2"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )

        # Head: normalize, swish (applied in call), project to z_channels.
        self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
        self.conv_out = layers.Conv2D(
            z_channels,
            kernel_size=3,
            strides=1,
            padding="same",
        )

    def get_config(self):
        """Return the config needed to re-instantiate this layer (Keras serialization)."""
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
                "channels_multiplier": self.channels_multiplier,
                "num_res_blocks": self.num_res_blocks,
                "attention_resolution": self.attention_resolution,
                "resolution": self.resolution,
                "z_channels": self.z_channels,
                "dropout": self.dropout,
            }
        )
        return config

    def call(self, inputs, training=True, mask=None):
        """Encode a batch of images.

        Args:
            inputs: Image tensor; presumably shaped (batch, height, width,
                channels) with height == width == `resolution` — TODO confirm
                against callers.
            training: Standard Keras training flag (Keras propagates it to
                sublayers automatically).
            mask: Unused; present for Keras API compatibility.

        Returns:
            A tensor with `z_channels` channels at the final (downsampled)
            resolution.
        """
        h = self.conv_in(inputs)
        for downsampling in self.downsampling_list:
            h = downsampling(h)
        # middle
        h = self.mid["block_1"](h)
        h = self.mid["attn_1"](h)
        h = self.mid["block_2"](h)
        # end
        h = self.norm_out(h)
        h = keras.activations.swish(h)
        h = self.conv_out(h)
        return h