Vemund Fredriksen
Add training pipeline (#21)
1cc0005 unverified
raw
history blame
No virus
3.62 kB
import platform
import torch as T
import nibabel
import numpy as np
from monai.networks.nets.unet import UNet
from monai.data import DataLoader, Dataset
from monai.metrics import compute_meandice
from monai.transforms import (
AddChanneld,
Compose,
LoadImaged,
NormalizeIntensityd,
ToTensord,
Resized,
AsDiscrete,
ThresholdIntensityd,
KeepLargestConnectedComponent
)
from Engine.utils import (
read_yaml,
load_model,
store_output
)
def init_eval_file(directory):
with open(directory + "/evaluate.csv", "w+") as file:
file.write("epoch,train_loss,val_loss,dice_score\n")
def append_eval(eval_path, input_filename, segment_filename, val_loss, dice_score):
with open(eval_path, 'a') as file:
file.write(f"{input_filename},{segment_filename},{val_loss},{dice_score}\n")
def evaluate(model, loss_function, loader, device, evaluate_config):
model.eval()
losses = []
dices = []
transform = KeepLargestConnectedComponent([1], connectivity = 3)
with T.no_grad():
for image in loader:
inp, boxes, label = image["image"].to(device), image["boxes"].to(device), image["label"].to(device)
directory_split = '\\' if platform.system() == 'Windows' else '/'
output = model(inp)
original_filename = image['image_meta_dict']['filename_or_obj'][0]
segmentation_filename = f"{evaluate_config['save_directory']}{directory_split}seg_{original_filename.split(directory_split)[-1]}"
original_image = nibabel.load(original_filename)
loss = loss_function(output, label).item()
output = output.cpu()
output[output >= evaluate_config['output_threshold']] = 1
output[output < evaluate_config['output_threshold']] = 0
output = output.to(device)
output = transform(output)
if(evaluate_config['save_segmentations']):
store_output(output, original_image, segmentation_filename, image['image_meta_dict']['affine'].squeeze(0).numpy())
dice = compute_meandice(output, label).item()
append_eval(evaluate_config['save_directory']+"evaluate.csv", original_filename, segmentation_filename, loss, dice)
losses.append(loss)
dices.append(dice)
append_eval(evaluate_config['save_directory']+"evaluate.csv", 'Total', '', sum(losses)/len(losses), sum(dices)/len(dices))
def initiate(config_file):
config = read_yaml(config_file)
device = T.device(config["device"])
data = read_yaml(config["data"]["train_dataset"])
image_shape = (config["data"]["scale_dim"]["d_0"], config["data"]["scale_dim"]["d_1"], config["data"]["scale_dim"]["d_2"])
for i, d in enumerate(data['data']):
data['data'][i]['image'] = data['image_prefix'] + d['label']
data['data'][i]['boxes'] = data['boxes_prefix'] + d['label']
data['data'][i]['label'] = data['label_prefix'] + d['label']
transform = Compose(
[
LoadImaged(keys=["image", "boxes", "label"]),
AddChanneld(keys=["image", "boxes", "label"]),
ToTensord(keys=["image", "boxes", "label"]),
]
)
dataset = Dataset(data['data'], transform)
loader = T.utils.data.DataLoader(dataset, 1)
model, loss = load_model(config['model'], eval = True)
model.load_state_dict(T.load(config["model"]["weights"]))
model.to(device)
init_eval_file(config['evaluate']['save_directory'])
evaluate(model, loss, loader, device, config['evaluate'])