import torch as T
from monai.data import Dataset
from monai.metrics import compute_meandice
from monai.transforms import (
    AddChanneld,
    Compose,
    LoadImaged,
    RandFlipd,
    RandRotate90d,
    ToTensord,
)
from Engine.utils import (
    load_model,
    read_yaml
)
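
# Dataset YAML layout assumed by initiate() below, inferred from the keys it
# reads; treat this as a sketch, not the authoritative schema:
#
#   image_prefix: /data/images/
#   label_prefix: /data/labels/
#   boxes_prefix: /data/boxes/
#   train:
#     - label: case_001.nii.gz
#   val:
#     - label: case_101.nii.gz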


def init_plot_file(plot_path):
    # Create (or overwrite) the metrics CSV and write its header row.
    with open(plot_path, "w") as file:
        file.write("step,train_loss,val_loss,dice_score\n")


def append_metrics(plot_path, step, train_loss, val_loss, dice_score):
    # Append one row of running metrics to the CSV.
    with open(plot_path, 'a') as file:
        file.write(f"{step},{train_loss},{val_loss},{dice_score}\n")


def train(model, loss_function, optimizer, train_loader, val_loader, device, train_config):
    best_val_loss = float('inf')
    val_loss = float('inf')
    total_steps = 0
    dcs_metric = 0
    step_loss = 0
    optimizer.zero_grad()
    for epoch in range(train_config['max_epochs']):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{train_config['max_epochs']}")
        model.train()
        train_loss = 0
        step = 0
        epoch_len = len(train_loader.dataset) // train_loader.batch_size
        # Periodically checkpoint the model weights.
        if (epoch + 1) % train_config['save_frequency'] == 0:
            T.save(model.state_dict(), train_config['model_directory'] + f"model_{epoch + 1}.pth")
        for batch_data in train_loader:
            total_steps += 1
            step += 1
            inputs, boxes, labels = batch_data["image"].to(device), batch_data["boxes"].to(device), batch_data["label"].to(device)
            loss = model.train_step(inputs, labels, loss_function, boxes)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()  # clear gradients so they do not accumulate across steps
            train_loss += loss.detach().item()
            step_loss += loss.detach().item()
            print(
                f"step {step}/{epoch_len}, "
                f"train_loss: {loss.item():.6f}")
            # Log averaged metrics every train_config['batch_size'] steps.
            if total_steps % train_config['batch_size'] == 0:
                step_loss /= train_config['batch_size']
                append_metrics(train_config['metric_path'], total_steps, step_loss, val_loss, dcs_metric)
                step_loss = 0
        train_loss /= len(train_loader.dataset)
        print(f"epoch {epoch + 1} average loss: {train_loss:.6f}")
        if (epoch + 1) % train_config['val_frequency'] == 0:
            model.eval()
            print("Start eval")
            step = 0
            val_loss = 0
            dcs_metric = 0
            with T.no_grad():
                for batch_data in val_loader:
                    step += 1
                    inputs, boxes, labels = batch_data["image"].to(device), batch_data["boxes"].to(device), batch_data["label"].to(device)
                    outputs = model(inputs)
                    loss = loss_function(outputs, labels)
                    # Binarize predictions on the CPU before scoring.
                    outputs = outputs.cpu()
                    outputs[outputs >= train_config['output_threshold']] = 1
                    outputs[outputs < train_config['output_threshold']] = 0
                    outputs = outputs.to(device)
                    dcs_metric += T.mean(compute_meandice(outputs.unsqueeze(0), labels.unsqueeze(0))).item()
                    val_loss += loss.item()
            dcs_metric /= len(val_loader.dataset)
            val_loss /= step
            # Track the best validation loss and checkpoint that model.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                T.save(model.state_dict(), train_config['model_directory'] + "model_best.pth")
            print(f"epoch {epoch + 1} validation loss: {val_loss:.6f}")
            print(f"epoch {epoch + 1} validation dice score: {dcs_metric:.6f}")
    T.save(model.state_dict(), train_config['model_directory'] + "model_last.pth")
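
# train() assumes the model exposes a train_step hook roughly shaped like the
# hypothetical sketch below; the real implementation ships with the model code.
#
#   class BoxGuidedSegmenter(T.nn.Module):
#       def train_step(self, inputs, labels, loss_function, boxes):
#           outputs = self(inputs)                  # boxes may condition the forward pass
#           return loss_function(outputs, labels)   # scalar loss, ready for backward()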


def initiate(config_path):
    config = read_yaml(config_path)
    init_plot_file(config['train']['metric_path'])
    data_paths = config["data"]["train_dataset"]
    combined_train = []
    combined_val = []
    # Each dataset YAML provides directory prefixes plus per-case filenames;
    # image, label and boxes share the filename stored under the 'label' key.
    for data in data_paths:
        prefixes = read_yaml(data)
        for d in prefixes['train']:
            combined_train.append({
                'image': prefixes['image_prefix'] + d['label'],
                'label': prefixes['label_prefix'] + d['label'],
                'boxes': prefixes['boxes_prefix'] + d['label']
            })
        # The validation split is assumed to live under a 'val' key in the
        # same dataset YAML.
        for d in prefixes['val']:
            combined_val.append({
                'image': prefixes['image_prefix'] + d['label'],
                'label': prefixes['label_prefix'] + d['label'],
                'boxes': prefixes['boxes_prefix'] + d['label']
            })
    # Flips and 90-degree rotations are applied identically to image, boxes
    # and label so the three volumes stay spatially aligned.
    train_transform = Compose(
[
LoadImaged(keys=["image", "boxes", "label"]),
AddChanneld(keys=["image", "boxes", "label"]),
RandFlipd(keys=["image", "boxes", "label"], prob=config["data"]["aug_prob"], spatial_axis=0),
RandFlipd(keys=["image", "boxes", "label"], prob=config["data"]["aug_prob"], spatial_axis=1),
RandFlipd(keys=["image", "boxes", "label"], prob=config["data"]["aug_prob"], spatial_axis=2),
RandRotate90d(keys=["image", "boxes", "label"], prob=config["data"]["aug_prob"], spatial_axes=(0, 1)),
RandRotate90d(keys=["image", "boxes", "label"], prob=config["data"]["aug_prob"], spatial_axes=(0, 2)),
RandRotate90d(keys=["image", "boxes", "label"], prob=config["data"]["aug_prob"], spatial_axes=(1, 2)),
ToTensord(keys=["image", "boxes", "label"]),
]
)
    # Validation data is loaded without augmentation.
    val_transform = Compose(
[
LoadImaged(keys=["image", "boxes", "label"]),
AddChanneld(keys=["image", "boxes", "label"]),
ToTensord(keys=["image", "boxes", "label"]),
]
)
    train_dataset = Dataset(combined_train, train_transform)
    train_loader = T.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_dataset = Dataset(combined_val, val_transform)
    val_loader = T.utils.data.DataLoader(val_dataset, batch_size=1)
    device = T.device(config["device"])
    model, loss, optimizer = load_model(config['model'])
    if config["model"]["weights"]:
        model.load_state_dict(T.load(config["model"]["weights"]))
    model.to(device)
    print("Initiating training!")
    train(model, loss, optimizer, train_loader, val_loader, device, config['train'])
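

# Minimal entry-point sketch: assumes the config path arrives as the first
# command-line argument; the surrounding project may invoke initiate() differently.
if __name__ == "__main__":
    import sys

    initiate(sys.argv[1])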