MNIST-classifier / train.py
danielcd99's picture
Added requirements.txt
b4b5fb6
raw
history blame
No virus
2.61 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Select the compute device once: the first CUDA GPU when one is visible,
# otherwise the CPU. The chosen backend is echoed to stdout.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if has_cuda else "cpu")
print("GPU" if has_cuda else "CPU")
# MNIST dataset
# Number of samples per mini-batch for both loaders.
batch_size = 64

# Training split, stored under ./data and downloaded when not already present.
# Images are converted to tensors as they are loaded.
train_dataset = torchvision.datasets.MNIST(
    './data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Test split from the same root directory. No download flag is passed here.
test_dataset = torchvision.datasets.MNIST(
    './data',
    train=False,
    transform=transforms.ToTensor(),
)

# Data loader
# Training batches are reshuffled; test batches keep the dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)
# NEURAL NETWORK
class LeNet(nn.Module):
    """Small LeNet-style CNN: two conv/tanh/avg-pool stages plus a linear head.

    The head takes 12 * 4 * 4 flattened features, which is what the two
    5x5-conv + 2x2-pool stages produce for a 1x28x28 input, and returns
    10 raw scores per sample.
    """

    def __init__(self):
        super().__init__()

        # Layers are created in the same order as they run so that parameter
        # initialisation and the ``convs.N`` / ``linear.N`` state-dict keys
        # stay stable.
        first_stage = [
            nn.Conv2d(1, 4, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        second_stage = [
            nn.Conv2d(4, 12, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        self.convs = nn.Sequential(*first_stage, *second_stage)

        head = nn.Linear(12 * 4 * 4, 10)
        self.linear = nn.Sequential(head)

    def forward(self, x):
        """Return the class scores for a batch of images ``x``."""
        feature_maps = self.convs(x)
        # Keep the batch dimension, collapse everything else.
        flat = feature_maps.flatten(start_dim=1)
        return self.linear(flat)
# TRAIN PARAMETERS
# Number of full passes over the training loader.
num_epochs = 10
# Mini-batches per epoch.
n_steps = len(train_loader)

# Loss applied to the raw class scores returned by the network.
criterion = nn.CrossEntropyLoss()

# The network to be trained, placed on the selected device.
model_adam = LeNet()
model_adam = model_adam.to(device)

# NOTE(review): lr=0.05 is far above Adam's documented default of 1e-3;
# confirm this value is intentional.
optimizer = torch.optim.Adam(params=model_adam.parameters(), lr=0.05)
# TRAIN
def train(model, *, optim=None, loss_fn=None, loader=None, epochs=None):
    """Run the mini-batch training loop on ``model``, updating it in place.

    Every keyword argument is optional. When one is omitted, the matching
    module-level object is used (``optimizer``, ``criterion``,
    ``train_loader``, ``num_epochs``), so a plain ``train(model)`` call keeps
    working as before.

    Args:
        model: The ``nn.Module`` to train.
        optim: Optimizer that steps ``model``'s parameters. Pass one
            explicitly when training any model other than the one the
            module-level ``optimizer`` was built for; otherwise that other
            model's parameters are the ones being updated.
        loss_fn: Callable mapping ``(outputs, labels)`` to a scalar loss.
        loader: Iterable yielding ``(images, labels)`` batches.
        epochs: Number of full passes over ``loader``.

    Returns:
        None.
    """
    # Resolve defaults lazily so the module-level names are only touched
    # when the caller did not supply an override.
    optim = optimizer if optim is None else optim
    loss_fn = criterion if loss_fn is None else loss_fn
    loader = train_loader if loader is None else loader
    epochs = num_epochs if epochs is None else epochs

    # Send batches to wherever the model actually lives; only a module with
    # no parameters falls back to the module-level ``device``.
    first_param = next(model.parameters(), None)
    target = device if first_param is None else first_param.device

    # Make sure mode-dependent layers (dropout, batch-norm) are in training mode.
    model.train()

    for _ in range(epochs):
        for images, labels in loader:
            images = images.to(target)
            labels = labels.to(target)

            # Forward pass
            loss = loss_fn(model(images), labels)

            # Backward and optimize
            optim.zero_grad()
            loss.backward()
            optim.step()
# RUN TRAINING
# Fit the network before persisting it. Without this call the checkpoint
# below would contain the randomly initialised weights.
train(model_adam)

# SAVING MODEL
# Only the parameters (state dict) are written, not the module object, so
# loading requires constructing a LeNet first.
torch.save(model_adam.state_dict(), "model_mnist.pth")