import torch as T
import yaml
import nibabel
from monai.losses import DiceLoss, GeneralizedDiceLoss, FocalLoss, TverskyLoss
from monai.networks.nets.unet import UNet

from Networks.UNet4 import UNet4
from Networks.UNet5_filter_teacher import UNet5_filter_teacher
from Networks.UNet5_decoder_teacher import UNet5_decoder_teacher
from Networks.UNet_test import UNet_student
from Networks.UNet_concat_double_student import UNet_concat_student
from Networks.monai_student import UNet_double
from Networks.monai_unet import UNet_single


def load_model(model_config, infer=False, eval=False):
    """Build the model, and optionally the loss and optimizer, from a config.

    Args:
        model_config: Mapping with the keys ``"architecture"``,
            ``"filter_base"``, ``"filter_expansion"`` and ``"filters"``
            (all four are always read), plus ``"loss"`` and
            ``"optimizer"`` (a mapping with ``"name"`` and ``"lr"``) when
            the corresponding object is requested.
        infer: When true, return only the model.
        eval: When true (and ``infer`` is false), return ``(model, loss)``.
            The name shadows the builtin but is kept for caller
            compatibility.

    Returns:
        ``model`` if ``infer``; ``(model, loss)`` if ``eval``; otherwise
        ``(model, loss, optim)``.

    Raises:
        SystemExit: With code 1, after printing a message, when the
            architecture, loss or optimizer name is not recognised.
        KeyError: When a required config key is missing.
    """
    model_name = model_config["architecture"]
    base = model_config["filter_base"]
    expansion = model_config["filter_expansion"]
    filter_layers = model_config["filters"]

    if model_name == 'UNet_decoder':
        model = UNet5_decoder_teacher(base, expansion)
    elif model_name == 'UNet_filter':
        model = UNet5_filter_teacher(base, expansion)
    elif model_name == 'UNet_con_double':
        model = UNet_concat_student(base, expansion)
    elif model_name in ('UNet_monai', 'UNet_monai_double'):
        # Both MONAI-based variants take the same arguments: one stride of 2
        # between each pair of consecutive filter levels.
        # NOTE(review): (3, 1, 1, ...) presumably mirrors monai's UNet
        # (spatial_dims, in_channels, out_channels, channels, strides) --
        # confirm against the project wrappers.
        channels = tuple(filter_layers)
        strides = (2,) * (len(filter_layers) - 1)
        net_cls = UNet_single if model_name == 'UNet_monai' else UNet_double
        model = net_cls(3, 1, 1, channels, strides)
    elif model_name == 'UNet4':
        model = UNet4(base, expansion)
    else:
        print("Architecture not found...")
        # Raise SystemExit directly: the bare ``exit`` helper is injected by
        # the ``site`` module and is not guaranteed to exist in every runtime.
        raise SystemExit(1)

    if infer:
        return model

    loss_name = model_config["loss"]
    if loss_name == 'BCE':
        loss = T.nn.BCELoss()
    elif loss_name == 'DiceLoss':
        loss = DiceLoss()
    elif loss_name == "GenDice":
        loss = GeneralizedDiceLoss()
    elif loss_name == "Tversky":
        loss = TverskyLoss(alpha=2.0, beta=10.0)
    else:
        print("Loss not found...")
        raise SystemExit(1)

    if eval:
        return model, loss

    optim_config = model_config['optimizer']
    if optim_config['name'] == 'Adam':
        optim = T.optim.Adam(model.parameters(), optim_config['lr'])
    else:
        print("Optimizer not found...")
        raise SystemExit(1)

    return model, loss, optim


def read_yaml(yaml_path):
    """Parse the YAML file at ``yaml_path`` and return its content.

    Args:
        yaml_path: Path of the YAML file to open for reading.

    Returns:
        The object produced by the YAML loader for the file's content.
    """
    # NOTE(review): yaml.load with FullLoader can construct more than plain
    # data types; if the config file may come from an untrusted source,
    # switch to yaml.safe_load.
    with open(yaml_path, 'r') as file:
        config = yaml.load(file, Loader=yaml.FullLoader)
    return config


def store_output(output, original_image, directory, affine):
    """Save a tensor as a NIfTI image, reusing another image's header.

    Args:
        output: Tensor to save. Dimension 0 is squeezed twice, so up to two
            leading size-1 dimensions are dropped (presumably batch and
            channel -- confirm with the caller).
        original_image: Object whose ``header`` attribute is copied into the
            saved image.
        directory: Destination handed to ``nibabel.save``; despite the name
            this is presumably the output file path -- confirm.
        affine: Affine passed to ``nibabel.Nifti1Image``.
    """
    headers = original_image.header
    volume = output.squeeze(0).squeeze(0)
    # detach() first: numpy() raises on tensors that still require grad.
    array = volume.detach().cpu().numpy()
    save = nibabel.Nifti1Image(array, affine, headers)
    nibabel.save(save, directory)