irotem98 commited on
Commit
13a7b44
1 Parent(s): ea9c429

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +10 -11
model.py CHANGED
@@ -52,17 +52,16 @@ class MobileVision(nn.Module):
52
  self.projection2 = nn.Linear(4096, 1536).to(DEVICE).to(DTYPE)
53
 
54
  def forward(self, x):
55
- with torch.no_grad():
56
- resized_img = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)
57
- out1 = self.vision(resized_img)
58
- x = split_chessboard(x, 2)
59
- x = self.vision(x)
60
- x = merge_chessboard(x, 2)
61
- x = F.interpolate(x, size=(8, 8), mode='area')
62
- x = torch.cat([out1, x], dim=1)
63
-
64
- x = x.reshape(x.size(0), x.size(1), -1)
65
- x = x.permute(0, 2, 1)
66
 
67
  x = self.projection(x)
68
  x = self.projection2(x)
 
52
  self.projection2 = nn.Linear(4096, 1536).to(DEVICE).to(DTYPE)
53
 
54
  def forward(self, x):
55
+ resized_img = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)
56
+ out1 = self.vision(resized_img)
57
+ x = split_chessboard(x, 2)
58
+ x = self.vision(x)
59
+ x = merge_chessboard(x, 2)
60
+ x = F.interpolate(x, size=(8, 8), mode='area')
61
+ x = torch.cat([out1, x], dim=1)
62
+
63
+ x = x.reshape(x.size(0), x.size(1), -1)
64
+ x = x.permute(0, 2, 1)
 
65
 
66
  x = self.projection(x)
67
  x = self.projection2(x)