Update model.py
Browse files
model.py
CHANGED
@@ -52,17 +52,16 @@ class MobileVision(nn.Module):
|
|
52 |
self.projection2 = nn.Linear(4096, 1536).to(DEVICE).to(DTYPE)
|
53 |
|
54 |
def forward(self, x):
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
x = x.permute(0, 2, 1)
|
66 |
|
67 |
x = self.projection(x)
|
68 |
x = self.projection2(x)
|
|
|
52 |
self.projection2 = nn.Linear(4096, 1536).to(DEVICE).to(DTYPE)
|
53 |
|
54 |
def forward(self, x):
|
55 |
+
resized_img = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)
|
56 |
+
out1 = self.vision(resized_img)
|
57 |
+
x = split_chessboard(x, 2)
|
58 |
+
x = self.vision(x)
|
59 |
+
x = merge_chessboard(x, 2)
|
60 |
+
x = F.interpolate(x, size=(8, 8), mode='area')
|
61 |
+
x = torch.cat([out1, x], dim=1)
|
62 |
+
|
63 |
+
x = x.reshape(x.size(0), x.size(1), -1)
|
64 |
+
x = x.permute(0, 2, 1)
|
|
|
65 |
|
66 |
x = self.projection(x)
|
67 |
x = self.projection2(x)
|