HuggingDavid commited on
Commit
d40ba1d
1 Parent(s): 0f7c18e

Upload with huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +11 -7
  2. model.pt +2 -2
app.py CHANGED
@@ -8,19 +8,23 @@ from dotenv import load_dotenv
8
  from torch import nn
9
  import torch.nn.functional as F
10
 
11
- class SimpleLenet(nn.Module):
12
  def __init__(self, args=None):
13
  super().__init__()
14
  self.conv1 = nn.Conv2d(1, 6, 5, padding=2) # -> 6 channels, 28x28
15
- self.pool = nn.MaxPool2d(2) # -> 6 channels, 14x14
16
- self.conv2 = nn.Conv2d(6, 120, 14) #-> 120 channels, 1x1
17
- self.fc1 = nn.Linear(120, 10)
18
- self.fc2 = nn.Linear(10, 10)
 
 
19
 
20
  def __call__(self, x):
21
  xx = F.relu(self.conv1(x))
22
- xx = F.relu(self.pool(xx))
23
  xx = F.relu(self.conv2(xx))
 
 
24
  xx = xx.flatten(1)
25
  xx = F.relu(self.fc1(xx))
26
  return self.fc2(xx)
@@ -30,7 +34,7 @@ load_dotenv()
30
  hf_writer = gr.HuggingFaceDatasetSaver(os.getenv('HF_TOKEN'), "simple-mnist-flagging")
31
 
32
  def load_model():
33
- model = SimpleLenet()
34
  model.load_state_dict(torch.load('model.pt'))
35
  model.eval()
36
  return model
 
8
  from torch import nn
9
  import torch.nn.functional as F
10
 
11
+ class Lenet(nn.Module):
12
  def __init__(self, args=None):
13
  super().__init__()
14
  self.conv1 = nn.Conv2d(1, 6, 5, padding=2) # -> 6 channels, 28x28
15
+ self.pool1 = nn.MaxPool2d(2) # -> 6 channels, 14x14
16
+ self.conv2 = nn.Conv2d(6, 16, 5) #-> 16 images, 10x10
17
+ self.pool2 = nn.MaxPool2d(2) # -> 16 channels, 5x5
18
+ self.conv3 = nn.Conv2d(16, 120, 5) #-> 16 images, 1x1
19
+ self.fc1 = nn.Linear(120, 84)
20
+ self.fc2 = nn.Linear(84, 10)
21
 
22
  def __call__(self, x):
23
  xx = F.relu(self.conv1(x))
24
+ xx = F.relu(self.pool1(xx))
25
  xx = F.relu(self.conv2(xx))
26
+ xx = F.relu(self.pool2(xx))
27
+ xx = F.relu(self.conv3(xx))
28
  xx = xx.flatten(1)
29
  xx = F.relu(self.fc1(xx))
30
  return self.fc2(xx)
 
34
  hf_writer = gr.HuggingFaceDatasetSaver(os.getenv('HF_TOKEN'), "simple-mnist-flagging")
35
 
36
  def load_model():
37
+ model = Lenet()
38
  model.load_state_dict(torch.load('model.pt'))
39
  model.eval()
40
  return model
model.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:922abc05756421cbdffc98dc27a3eef6abf476761170e77656e7bb5477a45b10
3
- size 573263
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72ef238f2653e7e2d135b3ecee92ecc02ba463f99dd1871dd9735692a60b3a10
3
+ size 249671