File size: 3,550 Bytes
1cc0005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch as T
import yaml
import nibabel
import sys
from monai.losses import DiceLoss

from monai.networks.nets.unet import UNet

import numpy as np
import torch as T
from monai.data import (DataLoader, Dataset)
from monai.metrics import compute_meandice
from monai.transforms import (
    AddChanneld,
    Compose,
    LoadImaged,
    NormalizeIntensityd,
    ToTensord,
    Resized,
    RandFlipd,
    RandRotate90d,
    ThresholdIntensityd
)

from Engine.utils import (
    load_model,
    read_yaml
)

from Networks.UNet4 import UNet4

def init_plot_file(plot_path):
    """Create (or truncate) the metrics CSV and write its header row.

    Args:
        plot_path: Path of the CSV file that will receive the loss values.
    """
    header = "step,loss\n"
    # "w+" discards any previous content, so every run starts a fresh file.
    with open(plot_path, "w+") as csv_file:
        csv_file.write(header)

def append_metrics(plot_path, epoch, loss):
    """Append one ``epoch,loss`` row to the metrics CSV.

    Args:
        plot_path: Path of the CSV file to append to (created if missing).
        epoch: Value written to the first column.
        loss: Value written to the second column.
    """
    row = f"{epoch},{loss}\n"
    with open(plot_path, 'a') as csv_file:
        csv_file.write(row)

def main():
    """Train a ``UNet4`` model with Dice loss and log the mean loss per epoch.

    Reads the dataset description from a hard-coded YAML file, prepends the
    configured ``*_prefix`` strings to the ``image``/``boxes``/``label``
    entries of every training sample, builds the MONAI transform pipeline
    and trains for 500 epochs with Adam (learning rate 0.0005), batch size 1.

    Side effects:
        * (Re)creates the metrics CSV at ``plot_path`` and appends one
          ``epoch,mean_loss`` row per epoch.
        * Writes ``model.state_dict()`` to ``model_loc`` as ``m_<epoch>.pth``
          whenever ``epoch % 10 == 2``.

    When ``sys.platform == "win32"`` a local test configuration with a
    smaller resize target is used instead of the cluster paths.
    """
    import os  # function-local: only needed to create the output folders

    path = "/cluster/work/sosevle/LungTumorSegmentation/Resources/Idun_Train_1/full_size.yaml"
    plot_path = "/cluster/work/sosevle/metrics/t_test3.csv"
    model_loc = "/cluster/work/sosevle/models/t_test3/"
    dim = (128, 128, 128)
    if sys.platform == "win32":
        path = "D:\\Repos\\LungTumorSegmentation\\Resources\\Test\\data_test_train.yaml"
        plot_path = "D:\\Repos\\LungTumorSegmentation\\m_test\\test.csv"
        model_loc = "D:\\Repos\\LungTumorSegmentation\\m_test\\"
        dim = (32, 32, 32)

    data = read_yaml(path)

    keys = ["image", "boxes", "label"]

    # Turn the per-sample entries into full locations by prepending the
    # matching "<key>_prefix" value from the YAML file.
    for sample in data['train']:
        for key in keys:
            sample[key] = data[f"{key}_prefix"] + sample[key]

    train_transform = Compose(
        [
            LoadImaged(keys=keys),
            # Intensity thresholding and normalisation are applied to the
            # image only, never to the boxes/label volumes.
            ThresholdIntensityd(keys=["image"], above=False, threshold=1024, cval=1024),
            NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
            AddChanneld(keys=keys),
            Resized(keys=keys, spatial_size=dim),
            # Random augmentations use all three keys so image, boxes and
            # label receive the same flip/rotation.
            RandFlipd(keys=keys, prob=0.1, spatial_axis=0),
            RandFlipd(keys=keys, prob=0.1, spatial_axis=1),
            RandFlipd(keys=keys, prob=0.1, spatial_axis=2),
            RandRotate90d(keys=keys, prob=0.1, spatial_axes=(0, 1)),
            ToTensord(keys=keys),
        ]
    )

    train_dataset = Dataset(data['train'], train_transform)
    train_loader = T.utils.data.DataLoader(train_dataset, 1, shuffle=True)

    device = T.device("cuda")

    #model = T.nn.Sequential(UNet(3, 1, 1, (32, 64, 128, 256), (2, 2, 2)), T.nn.Sigmoid())
    model = UNet4(32, 2)

    loss_criterion = DiceLoss()
    optimizer = T.optim.Adam(model.parameters(), 0.0005)

    model.to(device)
    model.train()

    # Make sure the output folders exist up front; otherwise the run would
    # only fail at the first checkpoint, after several epochs of training.
    for out_dir in (os.path.dirname(plot_path), model_loc):
        os.makedirs(out_dir, exist_ok=True)

    init_plot_file(plot_path)

    for i in range(500):
        print(f"Epoch {i}")
        epoch_loss = 0

        for batch in train_loader:
            # Only image and label are needed for the loss; "boxes" is not
            # used by the training step, so it is not copied to the device.
            inputs = batch["image"].to(device)
            labels = batch["label"].to(device)

            # Reset gradients for every step. Without this, backward() keeps
            # adding to the gradients of all previous batches and each
            # optimizer step would use that running sum.
            optimizer.zero_grad()

            # Call the module itself rather than .forward() so registered
            # hooks are executed.
            forwarded = model(inputs)
            train_loss = loss_criterion(forwarded, labels)
            train_loss.backward()
            optimizer.step()

            batch_loss = train_loss.item()
            epoch_loss += batch_loss
            print(f"loss: {batch_loss}")

        # Mean loss over all batches of this epoch.
        append_metrics(plot_path, i, epoch_loss / len(train_loader))

        # NOTE(review): with this schedule checkpoints are written at epochs
        # 2, 12, ..., 492; the weights after the final epoch (499) are not
        # saved - confirm this is intended.
        if i % 10 == 2:
            T.save(model.state_dict(), model_loc + f"m_{i}.pth")


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()