File size: 2,649 Bytes
1cc0005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch as T
import yaml
import nibabel

from monai.losses import DiceLoss, GeneralizedDiceLoss, FocalLoss, TverskyLoss
from Networks.UNet4 import UNet4
from Networks.UNet5_filter_teacher import UNet5_filter_teacher
from Networks.UNet5_decoder_teacher import UNet5_decoder_teacher
from Networks.UNet_test import UNet_student
from Networks.UNet_concat_double_student import UNet_concat_student
from monai.networks.nets.unet import UNet
from Networks.monai_student import UNet_double
from Networks.monai_unet import UNet_single


def load_model(model_config, infer = False, eval = False):
    model, loss, optim = None, None, None
    model_name = model_config["architecture"]
    base = model_config["filter_base"]
    expansion = model_config["filter_expansion"]
    
    filter_layers = model_config["filters"]

    if(model_name == 'UNet_decoder'):
        model = UNet5_decoder_teacher(base, expansion)
    elif(model_name == 'UNet_filter'):
        model = UNet5_filter_teacher(base, expansion)
    elif(model_name == 'UNet_con_double'):
        model = UNet_concat_student(base, expansion)
    elif(model_name == 'UNet_monai'):
        model = UNet_single(3, 1, 1, tuple(filter_layers), tuple([2 for i in range(len(filter_layers) - 1)]))
    elif(model_name == 'UNet_monai_double'):
        model = UNet_double(3, 1, 1, tuple(filter_layers), tuple([2 for i in range(len(filter_layers) - 1)]))
    elif(model_name == 'UNet4'):
        model = UNet4(base, expansion)
    else:
        print("Architecture not found...")
        exit(1)

    if(infer):
        return model

    loss_name = model_config["loss"]
    if(loss_name == 'BCE'):
        loss = T.nn.BCELoss()
    elif(loss_name == 'DiceLoss'):
        loss = DiceLoss()
    elif(loss_name == "GenDice"):
        loss = GeneralizedDiceLoss()
    elif(loss_name == "Tversky"):
        loss = TverskyLoss(alpha=2.0, beta=10.0)
    else:
        print("Loss not found...")
        exit(1)

    if(eval):
        return model, loss

    optim_name = model_config['optimizer']['name']
    if(optim_name == 'Adam'):
        optim = T.optim.Adam(model.parameters(), model_config['optimizer']['lr'])
    else:
        print("Optimizer not found...")
        exit(1)

    return model, loss, optim

def read_yaml(yaml_path):
    with open(yaml_path, 'r') as file:
        config = yaml.load(file, Loader = yaml.FullLoader)
    return config

def store_output(output, original_image, directory, affine):
    
    headers = original_image.header
    output = output.squeeze(0).squeeze(0)
    save = nibabel.Nifti1Image(output.cpu().numpy(), affine, headers)

    nibabel.save(save, directory)