Vemund Fredriksen
Add training pipeline (#21)
1cc0005 unverified
raw
history blame
No virus
2.65 kB
import torch as T
import yaml
import nibabel
from monai.losses import DiceLoss, GeneralizedDiceLoss, FocalLoss, TverskyLoss
from Networks.UNet4 import UNet4
from Networks.UNet5_filter_teacher import UNet5_filter_teacher
from Networks.UNet5_decoder_teacher import UNet5_decoder_teacher
from Networks.UNet_test import UNet_student
from Networks.UNet_concat_double_student import UNet_concat_student
from monai.networks.nets.unet import UNet
from Networks.monai_student import UNet_double
from Networks.monai_unet import UNet_single
def _abort(message):
    """Print ``message`` and terminate with exit status 1."""
    print(message)
    # Same effect as the builtin exit(1), but does not depend on the `site` module.
    raise SystemExit(1)


def _build_architecture(model_config):
    """Instantiate the network named by ``model_config["architecture"]``.

    Only the config keys the chosen architecture actually uses are read:
    ``"filters"`` for the MONAI-based nets, ``"filter_base"`` and
    ``"filter_expansion"`` for the others.
    """
    name = model_config["architecture"]
    if name in ('UNet_monai', 'UNet_monai_double'):
        channels = tuple(model_config["filters"])
        # One stride-2 step between each pair of consecutive filter levels.
        strides = (2,) * (len(channels) - 1)
        net = UNet_single if name == 'UNet_monai' else UNet_double
        # NOTE(review): the leading (3, 1, 1) presumably mean spatial dims,
        # in-channels, out-channels as in MONAI's UNet -- confirm in Networks/.
        return net(3, 1, 1, channels, strides)
    if name == 'UNet_decoder':
        net = UNet5_decoder_teacher
    elif name == 'UNet_filter':
        net = UNet5_filter_teacher
    elif name == 'UNet_con_double':
        net = UNet_concat_student
    elif name == 'UNet4':
        net = UNet4
    else:
        _abort("Architecture not found...")
    return net(model_config["filter_base"], model_config["filter_expansion"])


def _build_loss(loss_name):
    """Return the loss object registered under ``loss_name``."""
    if loss_name == 'BCE':
        return T.nn.BCELoss()
    if loss_name == 'DiceLoss':
        return DiceLoss()
    if loss_name == "GenDice":
        return GeneralizedDiceLoss()
    if loss_name == "Tversky":
        # Weights kept exactly as in the original pipeline.
        return TverskyLoss(alpha=2.0, beta=10.0)
    _abort("Loss not found...")


def _build_optimizer(model, optim_config):
    """Return the optimizer described by ``optim_config`` (keys ``name``, ``lr``)."""
    if optim_config['name'] == 'Adam':
        return T.optim.Adam(model.parameters(), optim_config['lr'])
    _abort("Optimizer not found...")


def load_model(model_config, infer = False, eval = False):
    """Build the model, and optionally its loss and optimizer, from a config dict.

    Args:
        model_config: mapping with key ``"architecture"`` plus the
            architecture's own keys (``"filters"`` for ``UNet_monai`` /
            ``UNet_monai_double``; ``"filter_base"`` and ``"filter_expansion"``
            otherwise), ``"loss"`` (unless ``infer``), and ``"optimizer"``
            with ``"name"`` and ``"lr"`` (unless ``infer`` or ``eval``).
        infer: if true, return only the model.
        eval: if true (and ``infer`` is false), return ``(model, loss)``.
            The name shadows the builtin but is kept for caller compatibility.

    Returns:
        ``model``, ``(model, loss)`` or ``(model, loss, optim)`` depending on
        the flags above.

    Raises:
        SystemExit: with code 1 after printing a message when the
            architecture, loss or optimizer name is not recognised.
    """
    model = _build_architecture(model_config)
    if infer:
        return model
    loss = _build_loss(model_config["loss"])
    if eval:
        return model, loss
    optim = _build_optimizer(model, model_config['optimizer'])
    return model, loss, optim
def read_yaml(yaml_path):
    """Load and return the parsed contents of the YAML file at ``yaml_path``.

    The file is opened in binary mode so PyYAML detects the encoding itself
    (UTF-8 or UTF-16 via BOM). A text-mode open would use the platform's
    locale-dependent default encoding instead.
    """
    # NOTE(review): FullLoader can construct some Python-specific types; if a
    # config could ever come from an untrusted source, prefer yaml.safe_load.
    with open(yaml_path, 'rb') as file:
        config = yaml.load(file, Loader=yaml.FullLoader)
    return config
def store_output(output, original_image, directory, affine):
    """Save ``output`` as a NIfTI image at the file path ``directory``.

    Args:
        output: tensor; up to two leading size-1 dimensions are removed before
            saving (presumably batch and channel -- confirm with callers).
        original_image: object exposing ``.header`` (presumably a nibabel
            image); its header is reused for the saved image.
        directory: destination file path passed to ``nibabel.save``.
        affine: affine matrix for the saved image.
    """
    header = original_image.header
    # squeeze(0) only removes dim 0 when it has size 1; other dims are kept.
    volume = output.squeeze(0).squeeze(0)
    # detach() so tensors that still require grad can be converted to NumPy.
    image = nibabel.Nifti1Image(volume.detach().cpu().numpy(), affine, header)
    nibabel.save(image, directory)