Spaces:
Sleeping
Sleeping
import torch as T | |
import yaml | |
import nibabel | |
import sys | |
from monai.losses import DiceLoss | |
from monai.networks.nets.unet import UNet | |
import numpy as np | |
import torch as T | |
from monai.data import (DataLoader, Dataset) | |
from monai.metrics import compute_meandice | |
from monai.transforms import ( | |
AddChanneld, | |
Compose, | |
LoadImaged, | |
NormalizeIntensityd, | |
ToTensord, | |
Resized, | |
RandFlipd, | |
RandRotate90d, | |
ThresholdIntensityd | |
) | |
from Engine.utils import ( | |
load_model, | |
read_yaml | |
) | |
from Networks.UNet4 import UNet4 | |
def init_plot_file(plot_path):
    """Create or truncate the metrics CSV at ``plot_path`` and write its header.

    Any existing content is discarded. Afterwards the file holds only the
    ``step,loss`` header line, ready for rows added by ``append_metrics``.
    """
    header_line = "step,loss\n"
    with open(plot_path, "w+") as csv_handle:
        csv_handle.write(header_line)
def append_metrics(plot_path, epoch, loss):
    """Append one ``epoch,loss`` row to the metrics CSV at ``plot_path``.

    Args:
        plot_path: path of the CSV file; it is opened in append mode.
        epoch: value written in the first column.
        loss: value written in the second column.
    """
    row = "{},{}\n".format(epoch, loss)
    with open(plot_path, "a") as csv_handle:
        csv_handle.write(row)
def _resolve_paths():
    """Return ``(config_path, plot_path, model_dir, spatial_dim)`` for this platform.

    The cluster paths and a 128^3 resize are the default. When
    ``sys.platform == "win32"``, local Windows paths and a 32^3 resize are used
    instead (presumably for quick local test runs).
    """
    if sys.platform == "win32":
        return (
            "D:\\Repos\\LungTumorSegmentation\\Resources\\Test\\data_test_train.yaml",
            "D:\\Repos\\LungTumorSegmentation\\m_test\\test.csv",
            "D:\\Repos\\LungTumorSegmentation\\m_test\\",
            (32, 32, 32),
        )
    return (
        "/cluster/work/sosevle/LungTumorSegmentation/Resources/Idun_Train_1/full_size.yaml",
        "/cluster/work/sosevle/metrics/t_test3.csv",
        "/cluster/work/sosevle/models/t_test3/",
        (128, 128, 128),
    )


def _prefix_train_paths(data):
    """Prepend the configured directory prefixes to every training entry, in place.

    Each entry of ``data['train']`` gets ``data['image_prefix']``,
    ``data['boxes_prefix']`` and ``data['label_prefix']`` prepended to its
    ``image``, ``boxes`` and ``label`` values respectively.
    """
    for entry in data['train']:
        entry['image'] = data['image_prefix'] + entry['image']
        entry['boxes'] = data['boxes_prefix'] + entry['boxes']
        entry['label'] = data['label_prefix'] + entry['label']


def _build_train_transform(dim):
    """Return the MONAI load/preprocess/augment pipeline for training samples.

    Args:
        dim: spatial size that image, boxes and label are resized to.
    """
    keys = ("image", "boxes", "label")
    return Compose(
        [
            LoadImaged(keys=keys),
            # NOTE(review): with above=False this should replace intensities
            # >= 1024 by 1024 (an upper clip); confirm against the installed
            # MONAI version's ThresholdIntensityd semantics.
            ThresholdIntensityd(keys=["image"], above=False, threshold=1024, cval=1024),
            NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
            AddChanneld(keys=keys),
            Resized(keys=keys, spatial_size=dim),
            RandFlipd(keys=keys, prob=0.1, spatial_axis=0),
            RandFlipd(keys=keys, prob=0.1, spatial_axis=1),
            RandFlipd(keys=keys, prob=0.1, spatial_axis=2),
            RandRotate90d(keys=keys, prob=0.1, spatial_axes=(0, 1)),
            ToTensord(keys=keys),
        ]
    )


def main():
    """Train ``UNet4`` with Dice loss and log/checkpoint progress.

    Reads the dataset YAML, builds the MONAI training pipeline and trains for
    500 epochs with Adam (lr 0.0005).

    Each epoch it:
    - appends the mean batch loss to the metrics CSV;
    - saves a ``state_dict`` checkpoint on epochs 2, 12, 22, ... and on the
      final epoch.
    """
    import os

    num_epochs = 500
    path, plot_path, model_loc, dim = _resolve_paths()

    data = read_yaml(path)
    _prefix_train_paths(data)

    train_dataset = Dataset(data['train'], _build_train_transform(dim))
    train_loader = T.utils.data.DataLoader(train_dataset, 1, shuffle=True)

    # Fall back to CPU so the script can still run on machines without CUDA.
    device = T.device("cuda" if T.cuda.is_available() else "cpu")
    model = UNet4(32, 2)
    # Move the model before building the optimizer, as PyTorch recommends.
    model.to(device)
    # NOTE(review): DiceLoss() is used without sigmoid/softmax, so this
    # presumably relies on UNet4 already emitting probabilities (the earlier
    # UNet variant was wrapped in a Sigmoid) — confirm.
    loss_criterion = DiceLoss()
    optimizer = T.optim.Adam(model.parameters(), 0.0005)
    model.train()

    # Create output directories up front so saving a checkpoint does not crash mid-training.
    os.makedirs(model_loc, exist_ok=True)
    plot_dir = os.path.dirname(plot_path)
    if plot_dir:
        os.makedirs(plot_dir, exist_ok=True)
    init_plot_file(plot_path)

    # Avoid division by zero if the training split is empty.
    num_batches = max(len(train_loader), 1)
    for i in range(num_epochs):
        print(f"Epoch {i}")
        epoch_loss = 0
        for batch in train_loader:
            # "boxes" is loaded by the transform but not used by this loop,
            # so it is not copied to the device.
            inputs = batch["image"].to(device)
            labels = batch["label"].to(device)
            # Reset gradients every step; otherwise they accumulate across all
            # batches and epochs and each update uses the running sum.
            optimizer.zero_grad()
            forwarded = model(inputs)
            train_loss = loss_criterion(forwarded, labels)
            train_loss.backward()
            optimizer.step()
            batch_loss = train_loss.item()
            epoch_loss += batch_loss
            print(f"loss: {batch_loss}")
        append_metrics(plot_path, i, epoch_loss / num_batches)
        # Keep the original 2, 12, 22, ... schedule, and also save the final model.
        if i % 10 == 2 or i == num_epochs - 1:
            T.save(model.state_dict(), os.path.join(model_loc, f"m_{i}.pth"))
if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()