sadjava's picture
changed to pipelines
fd52b7f
raw
history blame contribute delete
No virus
3.58 kB
import copy
import csv
import os
import time
import numpy as np
import torch
from tqdm import tqdm
def _run_phase(model, criterion, loader, optimizer, metrics, phase, device,
               progress):
    """Run one pass over *loader* and return the mean loss and mean metrics.

    In the 'Train' phase the model is put in training mode, gradients are
    tracked and the optimizer is stepped after every batch. In any other
    phase the model is put in eval mode and gradients are disabled.

    Args:
        model: Module called as ``model(inputs)``; its return value is indexed
            with ``['out']`` to obtain the prediction tensor.
        criterion: Callable ``criterion(prediction, mask)`` returning a scalar
            loss tensor.
        loader: Iterable of samples, each a mapping with 'image' and 'mask'
            tensors.
        optimizer: Torch optimizer, used only in the 'Train' phase.
        metrics: Mapping of metric name to ``metric(y_true, y_pred)`` on
            flattened numpy arrays. The 'f1_score' entry receives boolean
            arrays (``y_true > 0``, ``y_pred > 0.1``); every other metric
            receives ``y_true`` cast to uint8 and the raw predictions.
        phase: 'Train' or 'Test'.
        device: Torch device the batches are moved to.
        progress: Callable wrapping *loader* (e.g. ``tqdm``).

    Returns:
        ``(mean_loss, metric_means)``. ``mean_loss`` is the average of the
        per-batch losses as a float. ``metric_means`` maps each metric name
        to the average of its per-batch values.

    Raises:
        ValueError: If *loader* yields no batches.
    """
    is_train = phase == 'Train'
    if is_train:
        model.train()  # Set model to training mode
    else:
        model.eval()  # Set model to evaluate mode

    batch_losses = []
    batch_metrics = {name: [] for name in metrics}
    for sample in progress(loader):
        inputs = sample['image'].to(device)
        masks = sample['mask'].to(device)
        if is_train:
            optimizer.zero_grad()
        # Track history only while training.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs['out'], masks)

        # Metrics are computed on detached copies, before the optimizer step,
        # mirroring the order of the original loop.
        y_pred = outputs['out'].detach().cpu().numpy().ravel()
        y_true = masks.detach().cpu().numpy().ravel()
        for name, metric in metrics.items():
            if name == 'f1_score':
                # Use a classification threshold of 0.1
                value = metric(y_true > 0, y_pred > 0.1)
            else:
                value = metric(y_true.astype('uint8'), y_pred)
            batch_metrics[name].append(value)

        # backward + optimize only if in training phase
        if is_train:
            loss.backward()
            optimizer.step()
        batch_losses.append(loss.item())

    if not batch_losses:
        raise ValueError(f'dataloader for phase {phase!r} yielded no batches')
    metric_means = {name: float(np.mean(values))
                    for name, values in batch_metrics.items()}
    return float(np.mean(batch_losses)), metric_means


def train_model(model, criterion, dataloaders, optimizer, metrics, bpath,
                num_epochs, *, progress=None):
    """Train *model* and return it with the lowest-Test-loss weights loaded.

    Each epoch runs a 'Train' pass followed by a 'Test' pass. The epoch's
    mean losses and mean metrics are appended as one row to
    ``<bpath>/log.csv``. That file is recreated with a header row when
    training starts.

    Args:
        model: Module called as ``model(inputs)`` that returns a mapping with
            an 'out' prediction tensor. It is moved to CUDA if available,
            otherwise to the CPU.
        criterion: Callable ``criterion(prediction, mask)`` returning a scalar
            loss tensor.
        dataloaders: Mapping with 'Train' and 'Test' keys. Each value is an
            iterable of samples holding 'image' and 'mask' tensors.
        optimizer: Torch optimizer over the model's parameters.
        metrics: Mapping of metric name to ``metric(y_true, y_pred)``. See
            ``_run_phase`` for the arguments each metric receives.
        bpath: Directory in which ``log.csv`` is written.
        num_epochs: Number of epochs to run.
        progress: Optional callable that wraps each dataloader for progress
            reporting. Defaults to ``tqdm``; pass ``lambda it: it`` to
            disable progress bars.

    Returns:
        The same *model* object, with the state dict from the epoch whose
        mean Test loss was lowest loaded into it.

    Raises:
        ValueError: If a dataloader yields no batches.
    """
    if progress is None:
        progress = tqdm
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    # Use gpu if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Initialize the log file for training and testing loss and metrics
    log_path = os.path.join(bpath, 'log.csv')
    fieldnames = (['epoch', 'Train_loss', 'Test_loss']
                  + [f'Train_{m}' for m in metrics]
                  + [f'Test_{m}' for m in metrics])
    with open(log_path, 'w', newline='') as csvfile:
        csv.DictWriter(csvfile, fieldnames=fieldnames).writeheader()

    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        # Pre-seed the keys in column order so the printed summary matches
        # the CSV layout.
        summary = dict.fromkeys(fieldnames)
        summary['epoch'] = epoch
        for phase in ['Train', 'Test']:
            phase_loss, metric_means = _run_phase(
                model, criterion, dataloaders[phase], optimizer, metrics,
                phase, device, progress)
            summary[f'{phase}_loss'] = phase_loss
            for name, value in metric_means.items():
                summary[f'{phase}_{name}'] = value
            print('{} Loss: {:.4f}'.format(phase, phase_loss))
        print(summary)
        with open(log_path, 'a', newline='') as csvfile:
            csv.DictWriter(csvfile, fieldnames=fieldnames).writerow(summary)

        # Keep a deep copy of the weights with the lowest mean Test loss.
        if summary['Test_loss'] < best_loss:
            best_loss = summary['Test_loss']
            best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Lowest Loss: {:.4f}'.format(best_loss))
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model