File size: 1,934 Bytes
3be620b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from typing import List
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, Sequential
from tensorflow.keras import layers
class NLayerDiscriminator(Model):
    """PatchGAN discriminator in the style of Pix2Pix.

    Adapted from the PyTorch reference implementation:
    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py

    Layer stack (all convolutions use ``padding="same"``):

    * Conv2D(filters) + LeakyReLU(0.2)
    * for ``n`` in ``1 .. n_layers - 1``: AveragePooling2D(2), then a
      bias-free stride-1 Conv2D(filters * min(2**n, 8)) + BatchNorm +
      LeakyReLU(0.2)
    * bias-free stride-1 Conv2D(filters * min(2**n_layers, 8)) + BatchNorm +
      LeakyReLU(0.2)
    * Conv2D(1) with no activation

    Unlike the reference, downsampling is done by average pooling followed
    by a stride-1 convolution rather than by a stride-2 convolution. There
    are ``max(n_layers - 1, 0)`` pooling layers, each halving the spatial
    size.

    Note: ``call`` defaults to ``training=True`` (kept for backward
    compatibility). Calling the model without ``training=False`` therefore
    runs BatchNormalization in training mode. Pass ``training=False`` at
    inference time.
    """

    def __init__(
        self,
        input_channels: int = 3,
        filters: int = 64,
        n_layers: int = 3,
        kernel_size: int = 4,
    ):
        """Build the layer stack.

        Args:
            input_channels: Unused. Conv2D infers its input depth on the
                first call; the argument is kept for signature compatibility.
            filters: Number of filters in the first convolution. Deeper
                convolutions use ``filters`` times a multiplier capped at 8.
            n_layers: Controls depth; see the class docstring for the exact
                stack. Must be >= 0.
            kernel_size: Kernel size used by every convolution (default 4,
                as in the reference implementation).

        Raises:
            ValueError: If ``n_layers`` is negative.
        """
        super().__init__()
        if n_layers < 0:
            raise ValueError(f"n_layers must be >= 0, got {n_layers}")

        # Build into a plain local list and assign once at the end, so Keras
        # tracks the finished list in a single attribute assignment.
        sequence = [
            layers.Conv2D(filters, kernel_size=kernel_size, padding="same"),
            layers.LeakyReLU(alpha=0.2),
        ]
        for n in range(1, n_layers):
            # Halve the spatial size, then widen the feature maps.
            sequence.append(layers.AveragePooling2D(pool_size=2))
            sequence += self._conv_bn_lrelu(filters * min(2**n, 8), kernel_size)
        sequence += self._conv_bn_lrelu(
            filters * min(2**n_layers, 8), kernel_size
        )
        # Final projection to a one-channel map of unnormalized patch scores.
        sequence.append(
            layers.Conv2D(1, kernel_size=kernel_size, strides=1, padding="same")
        )
        self.sequence = sequence

    @staticmethod
    def _conv_bn_lrelu(out_filters: int, kernel_size: int) -> list:
        """Return [bias-free stride-1 Conv2D, BatchNormalization, LeakyReLU(0.2)]."""
        return [
            layers.Conv2D(
                out_filters,
                kernel_size=kernel_size,
                strides=1,
                padding="same",
                # The following BatchNorm has its own offset, so a conv bias
                # would be redundant.
                use_bias=False,
            ),
            layers.BatchNormalization(),
            layers.LeakyReLU(alpha=0.2),
        ]

    def call(self, inputs, training=True, mask=None):
        """Apply every layer in ``self.sequence`` in order.

        Args:
            inputs: Image batch. Presumably channels-last (Keras' default
                data format); TODO confirm against callers.
            training: Forwarded explicitly to the BatchNormalization layers.
                Defaults to ``True``, so pass ``False`` at inference.
            mask: Unused; accepted for Keras ``Model.call`` compatibility.

        Returns:
            The output of the final one-filter convolution (no activation).
        """
        h = inputs
        for layer in self.sequence:
            if isinstance(layer, layers.BatchNormalization):
                # Forward explicitly so BN honours `training` even when
                # `call` is invoked directly rather than through
                # `Model.__call__` (which would otherwise propagate it).
                h = layer(h, training=training)
            else:
                h = layer(h)
        return h
|