File size: 4,652 Bytes
3be620b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from typing import List
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow_addons.layers import GroupNormalization
from .layers import ResnetBlock, AttentionBlock, Downsample


# @tf.keras.utils.register_keras_serializable()
class Encoder(layers.Layer):
    """Stack of residual, attention and downsampling blocks ending in a conv.

    The layer applies, in order: an input convolution, ``len(channels_multiplier)``
    levels of ``ResnetBlock`` (optionally followed by ``AttentionBlock``) with a
    ``Downsample`` between consecutive levels, a middle section
    (ResnetBlock -> AttentionBlock -> ResnetBlock), and finally group
    normalization, swish and an output convolution with ``z_channels`` filters.
    """

    def __init__(
        self,
        *,
        channels: int,
        channels_multiplier: List[int],
        num_res_blocks: int,
        attention_resolution: List[int],
        resolution: int,
        z_channels: int,
        dropout: float,
        **kwargs
    ):
        """Build all sub-layers of the encoder.

        All arguments are keyword-only and required (there are no defaults).

        Args:
            channels (int): Number of filters of the input convolution; also the
                base value that each entry of `channels_multiplier` scales.
            channels_multiplier (List[int]): One multiplier per level. Level `i`
                outputs `channels * channels_multiplier[i]` channels. Must not
                be empty.
            num_res_blocks (int): Number of ResnetBlock at each level.
            attention_resolution (List[int]): An AttentionBlock is added after
                each ResnetBlock of a level whose current resolution is in this
                collection.
            resolution (int): The starting resolution. It is halved (integer
                division) after every level except the last one.
            z_channels (int): Number of filters of the output convolution.
            dropout (float): The dropout value forwarded to each ResnetBlock.
            **kwargs: Forwarded to `layers.Layer.__init__`.

        Raises:
            ValueError: If `channels_multiplier` is empty.
        """
        # Without at least one level, `block_in` below would never be bound and
        # building the middle blocks would fail with an opaque error.
        if len(channels_multiplier) == 0:
            raise ValueError(
                "`channels_multiplier` must contain at least one element."
            )

        super().__init__(**kwargs)

        self.channels = channels
        self.channels_multiplier = channels_multiplier
        self.num_resolutions = len(channels_multiplier)
        self.num_res_blocks = num_res_blocks
        self.attention_resolution = attention_resolution
        self.resolution = resolution
        self.z_channels = z_channels
        self.dropout = dropout

        self.conv_in = layers.Conv2D(
            self.channels, kernel_size=3, strides=1, padding="same"
        )

        current_resolution = resolution

        # Input multiplier of level i is the output multiplier of level i - 1;
        # the first level is fed by `conv_in`, hence the leading 1.
        in_channels_multiplier = (1,) + tuple(channels_multiplier)

        # Flat, ordered list of every layer applied before the middle section.
        self.downsampling_list = []

        for i_level in range(self.num_resolutions):
            block_in = channels * in_channels_multiplier[i_level]
            block_out = channels * channels_multiplier[i_level]
            for _ in range(self.num_res_blocks):
                self.downsampling_list.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                # Only the first block of a level changes the channel count.
                block_in = block_out

                if current_resolution in attention_resolution:
                    self.downsampling_list.append(AttentionBlock(block_in))

            # No downsampling after the last level.
            if i_level != self.num_resolutions - 1:
                self.downsampling_list.append(Downsample(block_in))
                current_resolution = current_resolution // 2

        # middle
        self.mid = {}
        self.mid["block_1"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        self.mid["attn_1"] = AttentionBlock(block_in)
        self.mid["block_2"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )

        # end
        self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
        self.conv_out = layers.Conv2D(
            z_channels,
            kernel_size=3,
            strides=1,
            padding="same",
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created.

        Returns:
            dict: The base layer config extended with every keyword argument
            of `__init__`. Sequence arguments are copied into plain lists.
        """
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
                # Plain lists: the stored attributes may have been wrapped by
                # Keras' attribute tracking — NOTE(review): confirm.
                "channels_multiplier": list(self.channels_multiplier),
                "num_res_blocks": self.num_res_blocks,
                "attention_resolution": list(self.attention_resolution),
                "resolution": self.resolution,
                "z_channels": self.z_channels,
                "dropout": self.dropout,
            }
        )
        return config

    def call(self, inputs, training=True, mask=None):
        """Run `inputs` through every sub-layer in construction order.

        Args:
            inputs: Tensor fed to the input convolution.
            training: Not referenced directly in this method.
                NOTE(review): sub-layers presumably pick up the training mode
                from the Keras call context, and the default of True means a
                call without this argument runs in training mode — confirm
                this is intended.
            mask: Unused.

        Returns:
            The output of the final convolution (`z_channels` filters).
        """
        h = self.conv_in(inputs)
        for downsampling in self.downsampling_list:
            h = downsampling(h)

        # middle
        h = self.mid["block_1"](h)
        h = self.mid["attn_1"](h)
        h = self.mid["block_2"](h)

        # end
        h = self.norm_out(h)
        h = keras.activations.swish(h)
        h = self.conv_out(h)
        return h