File size: 1,934 Bytes
3be620b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from typing import List
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, Sequential
from tensorflow.keras import layers


class NLayerDiscriminator(Model):
    """PatchGAN discriminator adapted from Pix2Pix.

    Reference implementation:
    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py

    Unlike the reference, which downsamples with stride-2 convolutions, this
    version downsamples with ``AveragePooling2D(pool_size=2)`` followed by a
    stride-1 convolution. The final layer is a linear (no activation)
    single-filter convolution, so the output is a one-channel map of raw
    per-patch scores.

    Args:
        input_channels: Number of channels of the input images. Currently
            unused (Keras infers the input channel count on the first call);
            kept so existing callers passing it keep working.
        filters: Number of filters of the first convolution. Later blocks use
            ``filters * min(2**n, 8)`` filters.
        n_layers: Number of feature blocks after the stem. The first
            ``n_layers - 1`` of them halve the spatial size (rounding down).
    """

    def __init__(self, input_channels: int = 3, filters: int = 64, n_layers: int = 3):
        super().__init__()

        kernel_size = 4
        negative_slope = 0.2  # LeakyReLU slope used by every block, as in Pix2Pix

        # Stem: conv + LeakyReLU, with no normalisation and no downsampling.
        self.sequence = [
            layers.Conv2D(filters, kernel_size=kernel_size, padding="same"),
            layers.LeakyReLU(alpha=negative_slope),
        ]

        # Downsampling blocks: average pooling halves H and W, then a stride-1
        # conv whose width grows as 2**n, capped at 8x `filters`.
        # use_bias=False because the following BatchNormalization has its own
        # learned offset, which would make a conv bias redundant.
        for n in range(1, n_layers):
            filters_mult = min(2**n, 8)
            self.sequence += [
                layers.AveragePooling2D(pool_size=2),
                layers.Conv2D(
                    filters * filters_mult,
                    kernel_size=kernel_size,
                    strides=1,
                    padding="same",
                    use_bias=False,
                ),
                layers.BatchNormalization(),
                layers.LeakyReLU(alpha=negative_slope),
            ]

        # One more conv/BN/LeakyReLU block at the final resolution (no pooling).
        filters_mult = min(2**n_layers, 8)
        self.sequence += [
            layers.Conv2D(
                filters * filters_mult,
                kernel_size=kernel_size,
                strides=1,
                padding="same",
                use_bias=False,
            ),
            layers.BatchNormalization(),
            layers.LeakyReLU(alpha=negative_slope),
        ]

        # Projection to a single channel of per-patch scores (linear output).
        self.sequence += [
            layers.Conv2D(1, kernel_size=kernel_size, strides=1, padding="same")
        ]

    def call(self, inputs, training=True, mask=None):
        """Apply the layer stack to ``inputs``.

        Args:
            inputs: Batch of images; presumably (batch, H, W, C) in
                channels-last layout (Conv2D's default) -- TODO confirm.
            training: Forwarded explicitly to every BatchNormalization layer:
                True uses batch statistics (and updates the moving averages),
                False uses the stored moving averages. The default is True, so
                a call without an explicit ``training`` value and without an
                enclosing Keras training context runs BatchNorm in training
                mode.
            mask: Unused.

        Returns:
            A single-channel map of raw per-patch discriminator scores.
        """
        h = inputs
        for layer in self.sequence:
            # Only BatchNormalization behaves differently between training and
            # inference here; pass the flag explicitly instead of relying on
            # Keras' implicit call-context propagation.
            if isinstance(layer, layers.BatchNormalization):
                h = layer(h, training=training)
            else:
                h = layer(h)
        return h