import torch
import torch.nn as nn

from constants import *  # expected to define ANCHORS, Z_DIM, ENC_HIDDEN_DIM

"""
    Class for custom activation.
"""
class SymReLU(nn.Module):
    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input):
        # Clamp each element of the input to [-1, 1] (the inplace flag is currently not applied).
        return torch.min(torch.max(input, -torch.ones_like(input)), torch.ones_like(input))

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str
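
# A quick sanity check (illustrative, not part of the original code): SymReLU
# behaves like torch.nn.Hardtanh with its default [-1, 1] bounds, so the two
# should agree on any input, e.g.:
#
#     _x = torch.linspace(-2.0, 2.0, steps=9)
#     assert torch.allclose(SymReLU()(_x), nn.Hardtanh()(_x))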


"""
    Class implementing YOLO-Stamp architecture described in https://link.springer.com/article/10.1134/S1054661822040046.
"""
class YOLOStamp(nn.Module):
    def __init__(
            self,
            anchors=ANCHORS,
            in_channels=3,
    ):
        super().__init__()
        
        # Register anchors as a buffer so they follow the model across devices but are not trained.
        self.register_buffer('anchors', torch.tensor(anchors))

        self.act = SymReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm1 = nn.BatchNorm2d(num_features=8)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm2 = nn.BatchNorm2d(num_features=16)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm3 = nn.BatchNorm2d(num_features=16)
        self.conv4 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm4 = nn.BatchNorm2d(num_features=16)
        self.conv5 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm5 = nn.BatchNorm2d(num_features=16)
        self.conv6 = nn.Conv2d(in_channels=16, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm6 = nn.BatchNorm2d(num_features=24)
        self.conv7 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm7 = nn.BatchNorm2d(num_features=24)
        self.conv8 = nn.Conv2d(in_channels=24, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm8 = nn.BatchNorm2d(num_features=48)
        self.conv9 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm9 = nn.BatchNorm2d(num_features=48)
        self.conv10 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm10 = nn.BatchNorm2d(num_features=48)
        self.conv11 = nn.Conv2d(in_channels=48, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm11 = nn.BatchNorm2d(num_features=64)
        self.conv12 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        self.norm12 = nn.BatchNorm2d(num_features=256)
        self.conv13 = nn.Conv2d(in_channels=256, out_channels=len(anchors) * 5, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
    
    def forward(self, x, head=True):
        # Note: the `head` flag is accepted but not used in this forward pass.
        # Match the input dtype to the conv weights (useful under mixed precision).
        x = x.type(self.conv1.weight.dtype)
        x = self.act(self.pool(self.norm1(self.conv1(x))))
        x = self.act(self.pool(self.norm2(self.conv2(x))))
        x = self.act(self.pool(self.norm3(self.conv3(x))))
        x = self.act(self.pool(self.norm4(self.conv4(x))))
        x = self.act(self.pool(self.norm5(self.conv5(x))))
        x = self.act(self.norm6(self.conv6(x)))
        x = self.act(self.norm7(self.conv7(x)))
        x = self.act(self.pool(self.norm8(self.conv8(x))))
        x = self.act(self.norm9(self.conv9(x)))
        x = self.act(self.norm10(self.conv10(x)))
        x = self.act(self.norm11(self.conv11(x)))
        x = self.act(self.norm12(self.conv12(x)))
        x = self.conv13(x)
        nb, _, nh, nw = x.shape
        # Reshape the raw head output to (batch, grid_h, grid_w, num_anchors, 5).
        x = x.permute(0, 2, 3, 1).view(nb, nh, nw, self.anchors.shape[0], 5)
        return x
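
# Illustrative usage sketch (not part of the original file). It assumes ANCHORS
# is defined in constants.py; the 448x448 input is only an example resolution,
# chosen to be divisible by 64 since the six pooling stages downsample by 64.
#
#     model = YOLOStamp()
#     images = torch.randn(2, 3, 448, 448)
#     out = model(images)
#     # For each anchor at each grid cell: 4 box values plus 1 confidence score.
#     print(out.shape)  # torch.Size([2, 7, 7, len(ANCHORS), 5])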
    

class Encoder(torch.nn.Module):
    '''
    Encoder Class
    Values:
    im_chan: the number of channels of the input image, a scalar
    output_chan: the dimension of the latent encoding, a scalar
    hidden_dim: the inner dimension, a scalar
    '''

    def __init__(self, im_chan=3, output_chan=Z_DIM, hidden_dim=ENC_HIDDEN_DIM):
        super(Encoder, self).__init__()
        self.z_dim = output_chan
        self.disc = torch.nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, hidden_dim * 4),
            self.make_disc_block(hidden_dim * 4, hidden_dim * 8),
            self.make_disc_block(hidden_dim * 8, output_chan * 2, final_layer=True),
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to an encoder block of the VAE:
        a convolution, a batchnorm (except in the final layer), and an activation.
        Parameters:
        input_channels: how many channels the input feature representation has
        output_channels: how many channels the output feature representation should have
        kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
        stride: the stride of the convolution
        final_layer: whether we're on the final layer (affects activation and batchnorm)
        '''
        if not final_layer:
            return torch.nn.Sequential(
                torch.nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                torch.nn.BatchNorm2d(output_channels),
                torch.nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return torch.nn.Sequential(
                torch.nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        '''
        Function for completing a forward pass of the Encoder: given an image tensor,
        returns the mean and standard deviation of the latent distribution.
        Parameters:
        image: an image tensor with dimensions (batch_size, im_chan, height, width)
        '''
        disc_pred = self.disc(image)
        # Flatten the convolutional features; assuming the final feature map is 1x1,
        # this yields a vector of length 2 * z_dim per sample.
        encoding = disc_pred.view(len(disc_pred), -1)
        # By convention and for numerical stability, the second half of the encoding
        # is predicted in log space; exponentiating it yields a strictly positive
        # spread for the latent normal distribution.
        return encoding[:, :self.z_dim], encoding[:, self.z_dim:].exp()
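
# Illustrative usage sketch (not part of the original file). It assumes Z_DIM and
# ENC_HIDDEN_DIM are defined in constants.py; 96x96 is only an example input size,
# chosen so the five stride-2, kernel-4 convolutions shrink the feature map to 1x1
# and the flattened encoding has exactly 2 * Z_DIM values per sample.
#
#     encoder = Encoder()
#     images = torch.randn(4, 3, 96, 96)
#     mean, std = encoder(images)              # each of shape (4, Z_DIM)
#     z = mean + std * torch.randn_like(std)   # reparameterization trick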