File size: 3,727 Bytes
1cc0005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch as T
import torch.nn as nn
from Networks.network import Network

class UNet5_Teacher(Network):
    """3-D U-Net (four pooling stages) that predicts a sigmoid map from an image plus a segbox.

    ``image`` and ``segbox`` are concatenated along dim 1 and passed through a
    four-level encoder (``analysis_0`` .. ``analysis_3``), a bottleneck
    (``bridge``) and a mirrored decoder (``synthesis_3`` .. ``synthesis_0``)
    with skip connections. The last layer is a sigmoid, so the single-channel
    output lies in (0, 1) voxel-wise.

    Level ``k`` uses ``int(base * expansion ** k)`` feature maps.

    NOTE(review): the name suggests the teacher of a teacher/student pair;
    nothing in this class depends on that - confirm with callers.
    """

    # Four 2x max-pools (analysis_1..3 and bridge) are undone by 2x transposed
    # convs; the skip concatenations only line up when every spatial dim is a
    # multiple of 2 ** 4.
    _SPATIAL_MULTIPLE = 16

    def __init__(self, base, expansion):
        """Create all layers.

        Args:
            base: feature-map count at the finest level (level 0).
            expansion: per-level multiplicative growth of the feature-map count.
        """
        super().__init__()
        self._build(base, expansion)

    def forward(self, image, segbox):
        """Run the network on ``image`` conditioned on ``segbox``.

        Args:
            image: input volume. Assumed batched ``(N, C, D, H, W)`` because the
                inputs are concatenated along dim 1 (the channel axis of batched
                Conv3d input) - TODO confirm.
            segbox: volume concatenated to ``image`` along dim 1. The combined
                channel count must be 2 (``analysis_0`` expects 2 input
                channels); presumably one channel each - verify with callers.

        Returns:
            Single-channel tensor of sigmoid outputs with the same spatial size
            as the inputs.

        Raises:
            TypeError: if ``segbox`` is None.
            ValueError: if any spatial dimension is not a multiple of 16.
        """
        if segbox is None:
            # train_step defaults segbox to None; report that directly instead
            # of torch.cat's opaque "got NoneType" error (same exception type).
            raise TypeError("UNet5_Teacher requires a segbox tensor; got None")

        x_new = T.cat((image, segbox), dim=1)
        self._check_spatial_size(x_new)

        layer_0, layer_1, layer_2, layer_3 = self._analysis(x_new)
        return self._synthesis(layer_0, layer_1, layer_2, layer_3, self._bridge(layer_3))

    def train_step(self, image, segment, criterion, segbox = None):
        """Forward ``image``/``segbox`` and return ``criterion(output, segment)``.

        ``segbox`` keeps its ``None`` default only for signature compatibility;
        it is required, and ``forward`` raises TypeError when it is omitted.
        """
        output = self.forward(image, segbox)
        return criterion(output, segment)

    def _check_spatial_size(self, x):
        """Raise ValueError unless the last three dims of ``x`` are multiples of 16."""
        spatial = tuple(x.shape[-3:])
        bad = [s for s in spatial if s % self._SPATIAL_MULTIPLE]
        if bad:
            raise ValueError(
                f"spatial dimensions {spatial} must all be multiples of "
                f"{self._SPATIAL_MULTIPLE} (four 2x pooling stages); "
                f"offending sizes: {bad}"
            )

    def _analysis(self, x):
        """Encoder: return the feature maps of levels 0..3 (each half the size of the previous)."""
        layer_0 = self.analysis_0(x)
        layer_1 = self.analysis_1(layer_0)
        layer_2 = self.analysis_2(layer_1)
        layer_3 = self.analysis_3(layer_2)

        return layer_0, layer_1, layer_2, layer_3

    def _bridge(self, layer_3):
        """Bottleneck: pool, convolve, and upsample back to ``layer_3``'s size."""
        return self.bridge(layer_3)

    def _synthesis(self, l0, l1, l2, l3, l4):
        """Decoder: fuse each encoder level with the upsampled deeper features.

        ``l0``..``l3`` are the encoder outputs, ``l4`` the bridge output. Each
        stage concatenates the skip tensor with the deeper features along dim 1.
        """
        c_3 = T.cat((l3, l4), dim=1)
        c_2 = T.cat((l2, self.synthesis_3(c_3)), dim=1)
        c_1 = T.cat((l1, self.synthesis_2(c_2)), dim=1)
        c_0 = T.cat((l0, self.synthesis_1(c_1)), dim=1)

        return self.synthesis_0(c_0)

    @staticmethod
    def _double_conv(in_ch, out_ch):
        """Return the layers [Conv3d 3x3x3, LeakyReLU, Conv3d 3x3x3, LeakyReLU] as a flat list.

        The list is spliced into each stage's nn.Sequential rather than nested,
        so state_dict keys (e.g. ``analysis_1.1.weight``) match the original
        flat layout and old checkpoints still load.
        """
        return [
            nn.Conv3d(in_ch, out_ch, 3, 1, 1),
            nn.LeakyReLU(),
            nn.Conv3d(out_ch, out_ch, 3, 1, 1),
            nn.LeakyReLU(),
        ]

    def _build(self, base, expansion):
        """Instantiate every stage; widths are ``int(base * expansion ** level)``.

        Stages are created in the same order as before, so parameter order and
        weight-initialisation RNG consumption are unchanged.
        """
        fl_0, fl_1, fl_2, fl_3 = (int(base * expansion ** k) for k in range(4))

        # Encoder. 2 input channels = image and segbox concatenated in forward().
        self.analysis_0 = nn.Sequential(*self._double_conv(2, fl_0))
        self.analysis_1 = nn.Sequential(nn.MaxPool3d(2), *self._double_conv(fl_0, fl_1))
        self.analysis_2 = nn.Sequential(nn.MaxPool3d(2), *self._double_conv(fl_1, fl_2))
        self.analysis_3 = nn.Sequential(nn.MaxPool3d(2), *self._double_conv(fl_2, fl_3))

        # Bottleneck: one more pooling level, then upsample back to level 3.
        self.bridge = nn.Sequential(
            nn.MaxPool3d(2),
            *self._double_conv(fl_3, fl_3),
            nn.ConvTranspose3d(fl_3, fl_3, 2, 2, 0),
        )

        # Decoder. Input widths are doubled by the skip concatenation.
        self.synthesis_3 = nn.Sequential(
            *self._double_conv(fl_3 + fl_3, fl_2),
            nn.ConvTranspose3d(fl_2, fl_2, 2, 2, 0),
        )
        self.synthesis_2 = nn.Sequential(
            *self._double_conv(fl_2 + fl_2, fl_1),
            nn.ConvTranspose3d(fl_1, fl_1, 2, 2, 0),
        )
        self.synthesis_1 = nn.Sequential(
            *self._double_conv(fl_1 + fl_1, fl_0),
            nn.ConvTranspose3d(fl_0, fl_0, 2, 2, 0),
        )

        # Output head. NOTE: the 3x3x3 conv to 1 channel is followed directly by
        # a 1x1x1 conv with no nonlinearity in between, so the second conv is
        # mathematically redundant; it is kept so existing checkpoints load.
        self.synthesis_0 = nn.Sequential(
            *self._double_conv(fl_0 + fl_0, fl_0),
            nn.Conv3d(fl_0, 1, 3, 1, 1),
            nn.Conv3d(1, 1, 1, 1, 0),
            nn.Sigmoid(),
        )