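"""Evaluation entry point.

Runs a trained segmentation model over the configured dataset, post-processes the
predictions, optionally saves the resulting segmentations, and writes per-case and
mean loss/Dice scores to evaluate.csv.
"""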
import os
import platform
import torch as T
import nibabel

from monai.data import DataLoader, Dataset
from monai.metrics import compute_meandice
from monai.transforms import (
    AddChanneld,
    Compose,
    LoadImaged,
    ToTensord,
    KeepLargestConnectedComponent
)
from Engine.utils import (
    read_yaml,
    load_model,
    store_output
)

def init_eval_file(directory):
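    """Create a fresh evaluate.csv in `directory` containing only the header row."""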
    with open(os.path.join(directory, "evaluate.csv"), "w") as file:
        file.write("input_file,segmentation_file,val_loss,dice_score\n")

def append_eval(eval_path, input_filename, segment_filename, val_loss, dice_score):
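    """Append one result row (input file, segmentation file, loss, Dice) to the CSV at `eval_path`."""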
    with open(eval_path, 'a') as file:
        file.write(f"{input_filename},{segment_filename},{val_loss},{dice_score}\n")
    
def evaluate(model, loss_function, loader, device, evaluate_config):
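    """Run `model` over `loader`, binarize and post-process each prediction,
    optionally save the segmentation, and append per-case and mean loss/Dice
    rows to evaluate.csv in the configured save directory."""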

    model.eval()
    losses = []
    dices = []

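    # post-processing: keep only the largest connected component of the foreground (label 1)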
    transform = KeepLargestConnectedComponent(applied_labels=[1], connectivity=3)

    directory_split = '\\' if platform.system() == 'Windows' else '/'

    with T.no_grad():
        for image in loader:
            inp, boxes, label = image["image"].to(device), image["boxes"].to(device), image["label"].to(device)
            output = model(inp)
        
            original_filename = image['image_meta_dict']['filename_or_obj'][0]
            segmentation_filename = f"{evaluate_config['save_directory']}{directory_split}seg_{original_filename.split(directory_split)[-1]}"

            original_image = nibabel.load(original_filename)
            
            loss = loss_function(output, label).item()
            
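            # binarize the network output at the configured threshold before post-processing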
            output = output.cpu()
            output[output >= evaluate_config['output_threshold']] = 1
            output[output < evaluate_config['output_threshold']] = 0
            output = output.to(device)

            output = transform(output)

            if evaluate_config['save_segmentations']:
                store_output(output, original_image, segmentation_filename, image['image_meta_dict']['affine'].squeeze(0).numpy())

            dice = compute_meandice(output, label).item()

            append_eval(os.path.join(evaluate_config['save_directory'], "evaluate.csv"), original_filename, segmentation_filename, loss, dice)
            losses.append(loss)
            dices.append(dice)

    append_eval(os.path.join(evaluate_config['save_directory'], "evaluate.csv"), 'Total', '', sum(losses)/len(losses), sum(dices)/len(dices))

def initiate(config_file):
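    """Read the YAML config, build the evaluation dataset, loader and model,
    then run the evaluation loop."""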
    config = read_yaml(config_file)
    device = T.device(config["device"])

    data = read_yaml(config["data"]["train_dataset"])

    image_shape = (config["data"]["scale_dim"]["d_0"], config["data"]["scale_dim"]["d_1"], config["data"]["scale_dim"]["d_2"])

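    # each dataset entry stores a single filename under 'label'; prefix it to build the image/box/label paths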
    for i, d in enumerate(data['data']):
        data['data'][i]['image'] = data['image_prefix'] + d['label']
        data['data'][i]['boxes'] = data['boxes_prefix'] + d['label']
        data['data'][i]['label'] = data['label_prefix'] + d['label']

    transform = Compose(
        [
            LoadImaged(keys=["image", "boxes", "label"]),
            AddChanneld(keys=["image", "boxes", "label"]),
            ToTensord(keys=["image", "boxes", "label"]),
        ]
    )

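    # load one case at a time so per-case metrics map directly to files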
    dataset = Dataset(data=data['data'], transform=transform)
    loader = DataLoader(dataset, batch_size=1)

    model, loss = load_model(config['model'], eval=True)
    model.load_state_dict(T.load(config["model"]["weights"], map_location=device))
    model.to(device)

    init_eval_file(config['evaluate']['save_directory'])

    evaluate(model, loss, loader, device, config['evaluate'])