In [2]:

import matplotlib.pyplot as plt
import numpy as np
from monai.config import print_config
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.transforms import MapTransform
from monai.data import DataLoader, Dataset
from monai.utils import set_determinism
from monai import transforms
import torch

print_config()

MONAI version: 1.4.dev2409
Numpy version: 1.26.2
Pytorch version: 1.13.0+cu116
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46c1b228091283fba829280a5d747f4237f76ed0
MONAI __file__: /usr/local/lib/python3.9/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: 1.11.4
Pillow version: 10.1.0
Tensorboard version: 2.16.2
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.14.0+cu116
tqdm version: 4.66.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.8
pandas version: 2.2.1
einops version: 0.7.0
transformers version: 4.35.2
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the option

In [3]:
set_determinism(seed=0)

In [9]:
import os

parent_folder_path = '/app/brats_2021_task1/BraTS2021_Training_Data'
subfolders = [f for f in os.listdir(parent_folder_path) if os.path.isdir(os.path.join(parent_folder_path, f))]
num_folders = len(subfolders)
print(f"Số lượng mẫu trong '{parent_folder_path}' là: {num_folders}")

Số lượng mẫu trong '/app/brats_2021_task1/BraTS2021_Training_Data' là: 1251


In [None]:
import os
import json

folder_data = []

for fold_number in os.listdir(parent_folder_path):
 fold_path = os.path.join(parent_folder_path, fold_number)

 if os.path.isdir(fold_path):
 entry = {"fold": 0, "image": [], "label": ""}

 for file_type in ['flair', 't1ce', 't1', 't2']:
 file_name = f"{fold_number}_{file_type}.nii.gz"
 file_path = os.path.join(fold_path, file_name)

 if os.path.exists(file_path):

 entry["image"].append(os.path.abspath(file_path))

 label_name = f"{fold_number}_seg.nii.gz"
 label_path = os.path.join(fold_path, label_name)
 if os.path.exists(label_path):
 entry["label"] = os.path.abspath(label_path)

 folder_data.append(entry)


json_data = {"training": folder_data}

json_file_path = '/app/info.json'
with open(json_file_path, 'w') as json_file:
 json.dump(json_data, json_file, indent=2)

print(f"Thông tin đã được ghi vào {json_file_path}")


In [5]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
 """
 Convert labels to multi channels based on brats classes:
 label 1 is the necrotic and non-enhancing tumor core
 label 2 is the peritumoral edema
 label 4 is the GD-enhancing tumor
 The possible classes are TC (Tumor core), WT (Whole tumor)
 and ET (Enhancing tumor).

 """

 def __call__(self, data):
 d = dict(data)
 for key in self.keys:
 result = []
 # merge label 1 and label 4 to construct TC
 result.append(np.logical_or(d[key] == 1, d[key] == 4))
 # merge labels 1, 2 and 4 to construct WT
 result.append(
 np.logical_or(
 np.logical_or(d[key] == 1, d[key] == 4), d[key] == 2
 )
 )
 # label 4 is ET
 result.append(d[key] == 4)
 d[key] = np.stack(result, axis=0).astype(np.float32)
 return d

In [6]:
def datafold_read(datalist, basedir, fold=0, key="training"):
 with open(datalist) as f:
 json_data = json.load(f)

 json_data = json_data[key]

 for d in json_data:
 for k in d:
 if isinstance(d[k], list):
 d[k] = [os.path.join(basedir, iv) for iv in d[k]]
 elif isinstance(d[k], str):
 d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]

 tr = []
 val = []
 for d in json_data:
 if "fold" in d and d["fold"] == fold:
 val.append(d)
 else:
 tr.append(d)

 return tr, val

In [7]:
def split_train_test(datalist, basedir, fold,test_size = 0.2, volume : float = None) :
 train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=fold)
 from sklearn.model_selection import train_test_split
 if volume != None :
 train_files, _ = train_test_split(train_files,test_size=volume,random_state=42)
 
 train_files,validation_files = train_test_split(train_files,test_size=test_size, random_state=42)
 
 validation_files,test_files = train_test_split(validation_files,test_size=test_size, random_state=42)
 return train_files, validation_files, test_files

In [8]:
def get_loader(batch_size, data_dir, json_list, fold, roi,volume :float = None,test_size = 0.2):
 train_files,validation_files,test_files = split_train_test(datalist = json_list,basedir = data_dir,test_size=test_size,fold = fold,volume= volume)
 
 train_transform = transforms.Compose(
 [
 transforms.LoadImaged(keys=["image", "label"]),
 transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
 transforms.CropForegroundd(
 keys=["image", "label"],
 source_key="image",
 k_divisible=[roi[0], roi[1], roi[2]],
 ),
 transforms.RandSpatialCropd(
 keys=["image", "label"],
 roi_size=[roi[0], roi[1], roi[2]],
 random_size=False,
 ),
 transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
 transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
 transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
 transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
 transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
 transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
 ]
 )
 val_transform = transforms.Compose(
 [
 transforms.LoadImaged(keys=["image", "label"]),
 transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
 transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
 ]
 )

 train_ds = Dataset(data=train_files, transform=train_transform)
 train_loader = DataLoader(
 train_ds,
 batch_size=batch_size,
 shuffle=True,
 num_workers=2,
 pin_memory=True,
 )
 val_ds = Dataset(data=validation_files, transform=val_transform)
 val_loader = DataLoader(
 val_ds,
 batch_size=1,
 shuffle=False,
 num_workers=2,
 pin_memory=True,
 )
 test_ds = Dataset(data=test_files, transform=val_transform)
 test_loader = DataLoader(
 test_ds,
 batch_size=1,
 shuffle=False,
 num_workers=2,
 pin_memory=True,
 )
 return train_loader, val_loader,test_loader

In [9]:
import json
data_dir = "/app/brats_2021_task1"
json_list = "/app/info.json"
roi = (128, 128, 128)
batch_size = 1
sw_batch_size = 2
fold = 1
infer_overlap = 0.5
max_epochs = 100
val_every = 10
train_loader, val_loader,test_loader = get_loader(batch_size, data_dir, json_list, fold, roi, volume=0.5, test_size=0.2)



In [45]:
len(val_loader)

100

In [10]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### Model design, base on SegResNet, VAE and TransBTS

In [11]:
import torch
import torch.nn as nn

#Re-use from encoder block
def normalization(planes, norm = 'instance'):
 if norm == 'bn':
 m = nn.BatchNorm3d(planes)
 elif norm == 'gn':
 m = nn.GroupNorm(8, planes)
 elif norm == 'instance':
 m = nn.InstanceNorm3d(planes)
 else:
 raise ValueError("Does not support this kind of norm.")
 return m
class ResNetBlock(nn.Module):
 def __init__(self, in_channels, norm = 'instance'):
 super().__init__()
 self.resnetblock = nn.Sequential(
 normalization(in_channels, norm = norm),
 nn.LeakyReLU(0.2, inplace=True),
 nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1),
 normalization(in_channels, norm = norm),
 nn.LeakyReLU(0.2, inplace=True),
 nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1)
 )

 def forward(self, x):
 y = self.resnetblock(x)
 return y + x

In [12]:


from torch.nn import functional as F

def calculate_total_dimension(a):
 res = 1
 for x in a:
 res *= x
 return res

class VAE(nn.Module):
 def __init__(self, input_shape, latent_dim, num_channels):
 super().__init__()
 self.input_shape = input_shape
 self.in_channels = input_shape[1] #input_shape[0] is batch size
 self.latent_dim = latent_dim
 self.encoder_channels = self.in_channels // 16

 #Encoder
 self.VAE_reshape = nn.Conv3d(self.in_channels, self.encoder_channels,
 kernel_size = 3, stride = 2, padding=1)
 # self.VAE_reshape = nn.Sequential(
 # nn.GroupNorm(8, self.in_channels),
 # nn.ReLU(),
 # nn.Conv3d(self.in_channels, self.encoder_channels,
 # kernel_size = 3, stride = 2, padding=1),
 # )

 flatten_input_shape = calculate_total_dimension(input_shape)
 flatten_input_shape_after_vae_reshape = \
 flatten_input_shape * self.encoder_channels // (8 * self.in_channels)

 #Convert from total dimension to latent space
 self.to_latent_space = nn.Linear(
 flatten_input_shape_after_vae_reshape // self.in_channels, 1)

 self.mean = nn.Linear(self.in_channels, self.latent_dim)
 self.logvar = nn.Linear(self.in_channels, self.latent_dim)
# self.epsilon = nn.Parameter(torch.randn(1, latent_dim))

 #Decoder
 self.to_original_dimension = nn.Linear(self.latent_dim, flatten_input_shape_after_vae_reshape)
 self.Reconstruct = nn.Sequential(
 nn.LeakyReLU(0.2, inplace=True),
 nn.Conv3d(
 self.encoder_channels, self.in_channels,
 stride = 1, kernel_size = 1),
 nn.Upsample(scale_factor=2, mode = 'nearest'),

 nn.Conv3d(
 self.in_channels, self.in_channels // 2,
 stride = 1, kernel_size = 1),
 nn.Upsample(scale_factor=2, mode = 'nearest'),
 ResNetBlock(self.in_channels // 2),

 nn.Conv3d(
 self.in_channels // 2, self.in_channels // 4,
 stride = 1, kernel_size = 1),
 nn.Upsample(scale_factor=2, mode = 'nearest'),
 ResNetBlock(self.in_channels // 4),

 nn.Conv3d(
 self.in_channels // 4, self.in_channels // 8,
 stride = 1, kernel_size = 1),
 nn.Upsample(scale_factor=2, mode = 'nearest'),
 ResNetBlock(self.in_channels // 8),

 nn.InstanceNorm3d(self.in_channels // 8),
 nn.LeakyReLU(0.2, inplace=True),
 nn.Conv3d(
 self.in_channels // 8, num_channels,
 kernel_size = 3, padding = 1),
# nn.Sigmoid()
 )


 def forward(self, x): #x has shape = input_shape
 #Encoder
 # print(x.shape)
 x = self.VAE_reshape(x)
 shape = x.shape

 x = x.view(self.in_channels, -1)
 x = self.to_latent_space(x)
 x = x.view(1, self.in_channels)

 mean = self.mean(x)
 logvar = self.logvar(x)
# sigma = torch.exp(0.5 * logvar)
 # Reparameter
 epsilon = torch.randn_like(logvar)
 sample = mean + epsilon * torch.exp(0.5*logvar)

 #Decoder
 y = self.to_original_dimension(sample)
 y = y.view(*shape)
 return self.Reconstruct(y), mean, logvar
 def total_params(self):
 total = sum(p.numel() for p in self.parameters())
 return format(total, ',')

 def total_trainable_params(self):
 total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
 return format(total_trainable, ',')


# x = torch.rand((1, 256, 16, 16, 16))
# vae = VAE(input_shape = x.shape, latent_dim = 256, num_channels = 4)
# y = vae(x)
# print(y[0].shape, y[1].shape, y[2].shape)
# print(vae.total_trainable_params())


In [13]:
import torch
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange

def pair(t):
 return t if isinstance(t, tuple) else (t, t)


class PreNorm(nn.Module):
 def __init__(self, dim, function):
 super().__init__()
 self.norm = nn.LayerNorm(dim)
 self.function = function

 def forward(self, x):
 return self.function(self.norm(x))


class FeedForward(nn.Module):
 def __init__(self, dim, hidden_dim, dropout = 0.0):
 super().__init__()
 self.net = nn.Sequential(
 nn.Linear(dim, hidden_dim),
 nn.GELU(),
 nn.Dropout(dropout),
 nn.Linear(hidden_dim, dim),
 nn.Dropout(dropout)
 )

 def forward(self, x):
 return self.net(x)

class Attention(nn.Module):
 def __init__(self, dim, heads, dim_head, dropout = 0.0):
 super().__init__()
 all_head_size = heads * dim_head
 project_out = not (heads == 1 and dim_head == dim)

 self.heads = heads
 self.scale = dim_head ** -0.5

 self.softmax = nn.Softmax(dim = -1)
 self.to_qkv = nn.Linear(dim, all_head_size * 3, bias = False)

 self.to_out = nn.Sequential(
 nn.Linear(all_head_size, dim),
 nn.Dropout(dropout)
 ) if project_out else nn.Identity()

 def forward(self, x):
 qkv = self.to_qkv(x).chunk(3, dim = -1)
 #(batch, heads * dim_head) -> (batch, all_head_size)
 q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

 dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

 atten = self.softmax(dots)

 out = torch.matmul(atten, v)
 out = rearrange(out, 'b h n d -> b n (h d)')
 return self.to_out(out)

class Transformer(nn.Module):
 def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.0):
 super().__init__()
 self.layers = nn.ModuleList([])
 for _ in range(depth):
 self.layers.append(nn.ModuleList([
 PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
 PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
 ]))
 def forward(self, x):
 for attention, feedforward in self.layers:
 x = attention(x) + x
 x = feedforward(x) + x
 return x

class FixedPositionalEncoding(nn.Module):
 def __init__(self, embedding_dim, max_length=768):
 super(FixedPositionalEncoding, self).__init__()

 pe = torch.zeros(max_length, embedding_dim)
 position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
 div_term = torch.exp(
 torch.arange(0, embedding_dim, 2).float()
 * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
 )
 pe[:, 0::2] = torch.sin(position * div_term)
 pe[:, 1::2] = torch.cos(position * div_term)
 pe = pe.unsqueeze(0).transpose(0, 1)
 self.register_buffer('pe', pe)

 def forward(self, x):
 x = x + self.pe[: x.size(0), :]
 return x


class LearnedPositionalEncoding(nn.Module):
 def __init__(self, embedding_dim, seq_length):
 super(LearnedPositionalEncoding, self).__init__()
 self.seq_length = seq_length
 self.position_embeddings = nn.Parameter(torch.zeros(1, seq_length, embedding_dim)) #8x

 def forward(self, x, position_ids=None):
 position_embeddings = self.position_embeddings
# print(x.shape, self.position_embeddings.shape)
 return x + position_embeddings

In [14]:
### Encoder ####
import torch.nn as nn
import torch.nn.functional as F

class InitConv(nn.Module):
 def __init__(self, in_channels = 4, out_channels = 16, dropout = 0.2):
 super().__init__()
 self.layer = nn.Sequential(
 nn.Conv3d(in_channels, out_channels, kernel_size = 3, padding = 1),
 nn.Dropout3d(dropout)
 )
 def forward(self, x):
 y = self.layer(x)
 return y


class DownSample(nn.Module):
 def __init__(self, in_channels, out_channels):
 super().__init__()
 self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1)
 def forward(self, x):
 return self.conv(x)

class Encoder(nn.Module):
 def __init__(self, in_channels, base_channels, dropout = 0.2):
 super().__init__()

 self.init_conv = InitConv(in_channels, base_channels, dropout = dropout)
 self.encoder_block1 = ResNetBlock(in_channels = base_channels)
 self.encoder_down1 = DownSample(base_channels, base_channels * 2)

 self.encoder_block2_1 = ResNetBlock(base_channels * 2)
 self.encoder_block2_2 = ResNetBlock(base_channels * 2)
 self.encoder_down2 = DownSample(base_channels * 2, base_channels * 4)

 self.encoder_block3_1 = ResNetBlock(base_channels * 4)
 self.encoder_block3_2 = ResNetBlock(base_channels * 4)
 self.encoder_down3 = DownSample(base_channels * 4, base_channels * 8)

 self.encoder_block4_1 = ResNetBlock(base_channels * 8)
 self.encoder_block4_2 = ResNetBlock(base_channels * 8)
 self.encoder_block4_3 = ResNetBlock(base_channels * 8)
 self.encoder_block4_4 = ResNetBlock(base_channels * 8)
 # self.encoder_down3 = EncoderDown(base_channels * 8, base_channels * 16)
 def forward(self, x):
 x = self.init_conv(x) #(1, 16, 128, 128, 128)

 x1 = self.encoder_block1(x)
 x1_down = self.encoder_down1(x1) #(1, 32, 64, 64, 64)

 x2 = self.encoder_block2_2(self.encoder_block2_1(x1_down))
 x2_down = self.encoder_down2(x2) #(1, 64, 32, 32, 32)

 x3 = self.encoder_block3_2(self.encoder_block3_1(x2_down))
 x3_down = self.encoder_down3(x3) #(1, 128, 16, 16, 16)

 output = self.encoder_block4_4(
 self.encoder_block4_3(
 self.encoder_block4_2(
 self.encoder_block4_1(x3_down)))) #(1, 256, 16, 16, 16)
 return x1, x2, x3, output

# x = torch.rand((1, 4, 128, 128, 128))
# Enc = Encoder(4, 32)
# _, _, _, y = Enc(x)
# print(y.shape) (1,256,16,16)

In [15]:
### Decoder ####

import torch
import torch.nn as nn


class Upsample(nn.Module):
 def __init__(self, in_channel, out_channel):
 super().__init__()
 self.conv1 = nn.Conv3d(in_channel, out_channel, kernel_size = 1)
 self.deconv = nn.ConvTranspose3d(out_channel, out_channel, kernel_size = 2, stride = 2)
 self.conv2 = nn.Conv3d(out_channel * 2, out_channel, kernel_size = 1)

 def forward(self, prev, x):
 x = self.deconv(self.conv1(x))
 y = torch.cat((prev, x), dim = 1)
 return self.conv2(y)

class FinalConv(nn.Module): # Input channels are equal to output channels
 def __init__(self, in_channels, out_channels=32, norm="instance"):
 super(FinalConv, self).__init__()
 if norm == "batch":
 norm_layer = nn.BatchNorm3d(num_features=in_channels)
 elif norm == "group":
 norm_layer = nn.GroupNorm(num_groups=8, num_channels=in_channels)
 elif norm == 'instance':
 norm_layer = nn.InstanceNorm3d(in_channels)

 self.layer = nn.Sequential(
 norm_layer,
 nn.LeakyReLU(0.2, inplace=True),
 nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
 )
 def forward(self, x):
 return self.layer(x)

class Decoder(nn.Module):
 def __init__(self, img_dim, patch_dim, embedding_dim, num_classes = 3):
 super().__init__()
 self.img_dim = img_dim
 self.patch_dim = patch_dim
 self.embedding_dim = embedding_dim

 self.decoder_upsample_1 = Upsample(128, 64)
 self.decoder_block_1 = ResNetBlock(64)

 self.decoder_upsample_2 = Upsample(64, 32)
 self.decoder_block_2 = ResNetBlock(32)

 self.decoder_upsample_3 = Upsample(32, 16)
 self.decoder_block_3 = ResNetBlock(16)

 self.endconv = FinalConv(16, num_classes)
 # self.normalize = nn.Sigmoid()

 def forward(self, x1, x2, x3, x):
 x = self.decoder_upsample_1(x3, x)
 x = self.decoder_block_1(x)

 x = self.decoder_upsample_2(x2, x)
 x = self.decoder_block_2(x)

 x = self.decoder_upsample_3(x1, x)
 x = self.decoder_block_3(x)

 y = self.endconv(x)
 return y

In [16]:
class FeatureMapping(nn.Module):
 def __init__(self, in_channel, out_channel, norm = 'instance'):
 super().__init__()
 if norm == 'bn':
 norm_layer_1 = nn.BatchNorm3d(out_channel)
 norm_layer_2 = nn.BatchNorm3d(out_channel)
 elif norm == 'gn':
 norm_layer_1 = nn.GroupNorm(8, out_channel)
 norm_layer_2 = nn.GroupNorm(8, out_channel)
 elif norm == 'instance':
 norm_layer_1 = nn.InstanceNorm3d(out_channel)
 norm_layer_2 = nn.InstanceNorm3d(out_channel)
 self.feature_mapping = nn.Sequential(
 nn.Conv3d(in_channel, out_channel, kernel_size = 3, padding = 1),
 norm_layer_1,
 nn.LeakyReLU(0.2, inplace=True),
 nn.Conv3d(out_channel, out_channel, kernel_size = 3, padding = 1),
 norm_layer_2,
 nn.LeakyReLU(0.2, inplace=True)
 )

 def forward(self, x):
 return self.feature_mapping(x)


class FeatureMapping1(nn.Module):
 def __init__(self, in_channel, norm = 'instance'):
 super().__init__()
 if norm == 'bn':
 norm_layer_1 = nn.BatchNorm3d(in_channel)
 norm_layer_2 = nn.BatchNorm3d(in_channel)
 elif norm == 'gn':
 norm_layer_1 = nn.GroupNorm(8, in_channel)
 norm_layer_2 = nn.GroupNorm(8, in_channel)
 elif norm == 'instance':
 norm_layer_1 = nn.InstanceNorm3d(in_channel)
 norm_layer_2 = nn.InstanceNorm3d(in_channel)
 self.feature_mapping1 = nn.Sequential(
 nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),
 norm_layer_1,
 nn.LeakyReLU(0.2, inplace=True),
 nn.Conv3d(in_channel, in_channel, kernel_size = 3, padding = 1),
 norm_layer_2,
 nn.LeakyReLU(0.2, inplace=True)
 )
 def forward(self, x):
 y = self.feature_mapping1(x)
 return x + y #Resnet Like

In [17]:

class SegTransVAE(nn.Module):
 def __init__(self, img_dim, patch_dim, num_channels, num_classes,
 embedding_dim, num_heads, num_layers, hidden_dim, in_channels_vae,
 dropout = 0.0, attention_dropout = 0.0,
 conv_patch_representation = True, positional_encoding = 'learned',
 use_VAE = False):

 super().__init__()
 assert embedding_dim % num_heads == 0
 assert img_dim[0] % patch_dim == 0 and img_dim[1] % patch_dim == 0 and img_dim[2] % patch_dim == 0

 self.img_dim = img_dim
 self.embedding_dim = embedding_dim
 self.num_heads = num_heads
 self.num_classes = num_classes
 self.patch_dim = patch_dim
 self.num_channels = num_channels
 self.in_channels_vae = in_channels_vae
 self.dropout = dropout
 self.attention_dropout = attention_dropout
 self.conv_patch_representation = conv_patch_representation
 self.use_VAE = use_VAE

 self.num_patches = int((img_dim[0] // patch_dim) * (img_dim[1] // patch_dim) * (img_dim[2] // patch_dim))
 self.seq_length = self.num_patches
 self.flatten_dim = 128 * num_channels

 self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)
 if positional_encoding == "learned":
 self.position_encoding = LearnedPositionalEncoding(
 self.embedding_dim, self.seq_length
 )
 elif positional_encoding == "fixed":
 self.position_encoding = FixedPositionalEncoding(
 self.embedding_dim,
 )
 self.pe_dropout = nn.Dropout(self.dropout)

 self.transformer = Transformer(
 embedding_dim, num_layers, num_heads, embedding_dim // num_heads, hidden_dim, dropout
 )
 self.pre_head_ln = nn.LayerNorm(embedding_dim)

 if self.conv_patch_representation:
 self.conv_x = nn.Conv3d(128, self.embedding_dim, kernel_size=3, stride=1, padding=1)
 self.encoder = Encoder(self.num_channels, 16)
 self.bn = nn.InstanceNorm3d(128)
 self.relu = nn.LeakyReLU(0.2, inplace=True)
 self.FeatureMapping = FeatureMapping(in_channel = self.embedding_dim, out_channel= self.in_channels_vae)
 self.FeatureMapping1 = FeatureMapping1(in_channel = self.in_channels_vae)
 self.decoder = Decoder(self.img_dim, self.patch_dim, self.embedding_dim, num_classes)

 self.vae_input = (1, self.in_channels_vae, img_dim[0] // 8, img_dim[1] // 8, img_dim[2] // 8)
 if use_VAE:
 self.vae = VAE(input_shape = self.vae_input , latent_dim= 256, num_channels= self.num_channels)
 def encode(self, x):
 if self.conv_patch_representation:
 x1, x2, x3, x = self.encoder(x)
 x = self.bn(x)
 x = self.relu(x)
 x = self.conv_x(x)
 x = x.permute(0, 2, 3, 4, 1).contiguous()
 x = x.view(x.size(0), -1, self.embedding_dim)
 x = self.position_encoding(x)
 x = self.pe_dropout(x)
 x = self.transformer(x)
 x = self.pre_head_ln(x)

 return x1, x2, x3, x

 def decode(self, x1, x2, x3, x):
 #x: (1, 4096, 512) -> (1, 16, 16, 16, 512)
# print("In decode...")
# print(" x1: {} \n x2: {} \n x3: {} \n x: {}".format( x1.shape, x2.shape, x3.shape, x.shape))
# break
 return self.decoder(x1, x2, x3, x)

 def forward(self, x, is_validation = True):
 x1, x2, x3, x = self.encode(x)
 x = x.view( x.size(0),
 self.img_dim[0] // self.patch_dim,
 self.img_dim[1] // self.patch_dim,
 self.img_dim[2] // self.patch_dim,
 self.embedding_dim)
 x = x.permute(0, 4, 1, 2, 3).contiguous()
 x = self.FeatureMapping(x)
 x = self.FeatureMapping1(x)
 if self.use_VAE and not is_validation:
 vae_out, mu, sigma = self.vae(x)
 y = self.decode(x1, x2, x3, x)
 if self.use_VAE and not is_validation:
 return y, vae_out, mu, sigma
 else:
 return y




In [18]:
import torch

# Check if CUDA (GPU support) is available
if torch.cuda.is_available():
 device = torch.device("cuda:0")
 print("CUDA (GPU) is available. Using GPU.")
else:
 device = torch.device("cpu")
 print("CUDA (GPU) is not available. Using CPU.")

CUDA (GPU) is available. Using GPU.


In [18]:
model = SegTransVAE(img_dim = (128, 128, 128),patch_dim= 8,num_channels =4,num_classes= 3,embedding_dim= 768,num_heads= 8,num_layers= 4, hidden_dim= 3072,in_channels_vae=128 , use_VAE = True)

In [28]:
total_params = sum(p.numel() for p in model.parameters())
print(f'Tổng số tham số của mô hình là: {total_params}')

total_params_requires_grad = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Tổng số tham số cần tính gradient của mô hình là: {total_params_requires_grad}')


Tổng số tham số của mô hình là: 44727120
Tổng số tham số cần tính gradient của mô hình là: 44727120


In [19]:
class Loss_VAE(nn.Module):
 def __init__(self):
 super().__init__()
 self.mse = nn.MSELoss(reduction='sum')

 def forward(self, recon_x, x, mu, log_var):
 mse = self.mse(recon_x, x)
 kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
 loss = mse + kld
 return loss

In [20]:
def DiceScore(
 y_pred: torch.Tensor,
 y: torch.Tensor,
 include_background: bool = True,
) -> torch.Tensor:
 """Computes Dice score metric from full size Tensor and collects average.
 Args:
 y_pred: input data to compute, typical segmentation model output.
 It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
 should be binarized.
 y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch.
 The values should be binarized.
 include_background: whether to skip Dice computation on the first channel of
 the predicted output. Defaults to True.
 Returns:
 Dice scores per batch and per class, (shape [batch_size, num_classes]).
 Raises:
 ValueError: when `y_pred` and `y` have different shapes.
 """

 y = y.float()
 y_pred = y_pred.float()

 if y.shape != y_pred.shape:
 raise ValueError("y_pred and y should have same shapes.")

 # reducing only spatial dimensions (not batch nor channels)
 n_len = len(y_pred.shape)
 reduce_axis = list(range(2, n_len))
 intersection = torch.sum(y * y_pred, dim=reduce_axis)

 y_o = torch.sum(y, reduce_axis)
 y_pred_o = torch.sum(y_pred, dim=reduce_axis)
 denominator = y_o + y_pred_o

 return torch.where(
 denominator > 0,
 (2.0 * intersection) / denominator,
 torch.tensor(float("1"), device=y_o.device),
 )


In [21]:
# Pytorch Lightning
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import csv
from monai.transforms import AsDiscrete, Activations, Compose, EnsureType

In [24]:
class BRATS(pl.LightningModule):
 def __init__(self, use_VAE = True, lr = 1e-4, ):
 super().__init__()
 
 self.use_vae = use_VAE
 self.lr = lr
 self.model = SegTransVAE((128, 128, 128), 8, 4, 3, 768, 8, 4, 3072, in_channels_vae=128, use_VAE = use_VAE)

 self.loss_vae = Loss_VAE()
 self.dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True)
 self.post_trans_images = Compose(
 [EnsureType(),
 Activations(sigmoid=True), 
 AsDiscrete(threshold_values=True), 
 ]
 )

 self.best_val_dice = 0
 
 self.training_step_outputs = [] 
 self.val_step_loss = [] 
 self.val_step_dice = []
 self.val_step_dice_tc = [] 
 self.val_step_dice_wt = []
 self.val_step_dice_et = [] 
 self.test_step_loss = [] 
 self.test_step_dice = []
 self.test_step_dice_tc = [] 
 self.test_step_dice_wt = []
 self.test_step_dice_et = [] 

 def forward(self, x, is_validation = True):
 return self.model(x, is_validation) 
 def training_step(self, batch, batch_index):
 inputs, labels = (batch['image'], batch['label'])
 
 if not self.use_vae:
 outputs = self.forward(inputs, is_validation=False)
 loss = self.dice_loss(outputs, labels)
 else:
 outputs, recon_batch, mu, sigma = self.forward(inputs, is_validation=False)
 
 vae_loss = self.loss_vae(recon_batch, inputs, mu, sigma)
 dice_loss = self.dice_loss(outputs, labels)
 loss = dice_loss + 1/(4 * 128 * 128 * 128) * vae_loss
 self.training_step_outputs.append(loss)
 self.log('train/vae_loss', vae_loss)
 self.log('train/dice_loss', dice_loss)
 if batch_index == 10:

 tensorboard = self.logger.experiment 
 fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(10, 5))
 

 ax[0].imshow(inputs.detach().cpu()[0][0][:, :, 80], cmap='gray')
 ax[0].set_title("Input")

 ax[1].imshow(recon_batch.detach().cpu().float()[0][0][:,:, 80], cmap='gray')
 ax[1].set_title("Reconstruction")
 
 ax[2].imshow(labels.detach().cpu().float()[0][0][:,:, 80], cmap='gray')
 ax[2].set_title("Labels TC")
 
 ax[3].imshow(outputs.sigmoid().detach().cpu().float()[0][0][:,:, 80], cmap='gray')
 ax[3].set_title("TC")
 
 ax[4].imshow(labels.detach().cpu().float()[0][2][:,:, 80], cmap='gray')
 ax[4].set_title("Labels ET")
 
 ax[5].imshow(outputs.sigmoid().detach().cpu().float()[0][2][:,:, 80], cmap='gray')
 ax[5].set_title("ET")

 
 tensorboard.add_figure('train_visualize', fig, self.current_epoch)

 self.log('train/loss', loss)
 
 return loss
 
 def on_train_epoch_end(self):
 ## F1 Macro all epoch saving outputs and target per batch

 # free up the memory
 # --> HERE STEP 3 <--
 epoch_average = torch.stack(self.training_step_outputs).mean()
 self.log("training_epoch_average", epoch_average)
 self.training_step_outputs.clear() # free memory

 def validation_step(self, batch, batch_index):
 inputs, labels = (batch['image'], batch['label'])
 roi_size = (128, 128, 128)
 sw_batch_size = 1
 outputs = sliding_window_inference(
 inputs, roi_size, sw_batch_size, self.model, overlap = 0.5)
 loss = self.dice_loss(outputs, labels)
 
 
 val_outputs = self.post_trans_images(outputs)
 
 
 metric_tc = DiceScore(y_pred=val_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)
 metric_wt = DiceScore(y_pred=val_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)
 metric_et = DiceScore(y_pred=val_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)
 mean_val_dice = (metric_tc + metric_wt + metric_et)/3
 self.val_step_loss.append(loss) 
 self.val_step_dice.append(mean_val_dice)
 self.val_step_dice_tc.append(metric_tc) 
 self.val_step_dice_wt.append(metric_wt)
 self.val_step_dice_et.append(metric_et) 
 return {'val_loss': loss, 'val_mean_dice': mean_val_dice, 'val_dice_tc': metric_tc,
 'val_dice_wt': metric_wt, 'val_dice_et': metric_et}
 
 def on_validation_epoch_end(self):

 loss = torch.stack(self.val_step_loss).mean()
 mean_val_dice = torch.stack(self.val_step_dice).mean()
 metric_tc = torch.stack(self.val_step_dice_tc).mean()
 metric_wt = torch.stack(self.val_step_dice_wt).mean()
 metric_et = torch.stack(self.val_step_dice_et).mean()
 self.log('val/Loss', loss)
 self.log('val/MeanDiceScore', mean_val_dice)
 self.log('val/DiceTC', metric_tc)
 self.log('val/DiceWT', metric_wt)
 self.log('val/DiceET', metric_et)
 os.makedirs(self.logger.log_dir, exist_ok=True)
 if self.current_epoch == 0:
 with open('{}/metric_log.csv'.format(self.logger.log_dir), 'w') as f:
 writer = csv.writer(f)
 writer.writerow(['Epoch', 'Mean Dice Score', 'Dice TC', 'Dice WT', 'Dice ET'])
 with open('{}/metric_log.csv'.format(self.logger.log_dir), 'a') as f:
 writer = csv.writer(f)
 writer.writerow([self.current_epoch, mean_val_dice.item(), metric_tc.item(), metric_wt.item(), metric_et.item()])

 if mean_val_dice > self.best_val_dice:
 self.best_val_dice = mean_val_dice
 self.best_val_epoch = self.current_epoch
 print(
 f"\n Current epoch: {self.current_epoch} Current mean dice: {mean_val_dice:.4f}"
 f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
 f"\n Best mean dice: {self.best_val_dice}"
 f" at epoch: {self.best_val_epoch}"
 )
 
 self.val_step_loss.clear() 
 self.val_step_dice.clear()
 self.val_step_dice_tc.clear() 
 self.val_step_dice_wt.clear()
 self.val_step_dice_et.clear()
 return {'val_MeanDiceScore': mean_val_dice}
 def test_step(self, batch, batch_index):
 inputs, labels = (batch['image'], batch['label'])
 
 roi_size = (128, 128, 128)
 sw_batch_size = 1
 test_outputs = sliding_window_inference(
 inputs, roi_size, sw_batch_size, self.forward, overlap = 0.5)
 loss = self.dice_loss(test_outputs, labels)
 test_outputs = self.post_trans_images(test_outputs)
 metric_tc = DiceScore(y_pred=test_outputs[:, 0:1], y=labels[:, 0:1], include_background = True)
 metric_wt = DiceScore(y_pred=test_outputs[:, 1:2], y=labels[:, 1:2], include_background = True)
 metric_et = DiceScore(y_pred=test_outputs[:, 2:3], y=labels[:, 2:3], include_background = True)
 mean_test_dice = (metric_tc + metric_wt + metric_et)/3
 
 self.test_step_loss.append(loss) 
 self.test_step_dice.append(mean_test_dice)
 self.test_step_dice_tc.append(metric_tc) 
 self.test_step_dice_wt.append(metric_wt)
 self.test_step_dice_et.append(metric_et) 
 
 return {'test_loss': loss, 'test_mean_dice': mean_test_dice, 'test_dice_tc': metric_tc,
 'test_dice_wt': metric_wt, 'test_dice_et': metric_et}
 
 def test_epoch_end(self):
 loss = torch.stack(self.test_step_loss).mean()
 mean_test_dice = torch.stack(self.test_step_dice).mean()
 metric_tc = torch.stack(self.test_step_dice_tc).mean()
 metric_wt = torch.stack(self.test_step_dice_wt).mean()
 metric_et = torch.stack(self.test_step_dice_et).mean()
 self.log('test/Loss', loss)
 self.log('test/MeanDiceScore', mean_test_dice)
 self.log('test/DiceTC', metric_tc)
 self.log('test/DiceWT', metric_wt)
 self.log('test/DiceET', metric_et)

 with open('{}/test_log.csv'.format(self.logger.log_dir), 'w') as f:
 writer = csv.writer(f)
 writer.writerow(["Mean Test Dice", "Dice TC", "Dice WT", "Dice ET"])
 writer.writerow([mean_test_dice, metric_tc, metric_wt, metric_et])

 self.test_step_loss.clear() 
 self.test_step_dice.clear()
 self.test_step_dice_tc.clear() 
 self.test_step_dice_wt.clear()
 self.test_step_dice_et.clear()
 return {'test_MeanDiceScore': mean_test_dice}
 
 
 def configure_optimizers(self):
 optimizer = torch.optim.Adam(
 self.model.parameters(), self.lr, weight_decay=1e-5, amsgrad=True
 )
# optimizer = AdaBelief(self.model.parameters(), 
# lr=self.lr, eps=1e-16, 
# betas=(0.9,0.999), weight_decouple = True, 
# rectify = False)
 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200)
 return [optimizer], [scheduler]
 
 def train_dataloader(self):
 return train_loader
 def val_dataloader(self):
 return val_loader
 
 def test_dataloader(self):
 return test_loader

In [1]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import os 
from pytorch_lightning.loggers import TensorBoardLogger

 from .autonotebook import tqdm as notebook_tqdm


In [25]:
os.system('cls||clear')
print("Training ...")
model = BRATS(use_VAE = True)
checkpoint_callback = ModelCheckpoint(
 monitor='val/MeanDiceScore',
 dirpath='./app/checkpoints/{}'.format(1),
 filename='Epoch{epoch:3d}-MeanDiceScore{val/MeanDiceScore:.4f}',
 save_top_k=3,
 mode='max',
 save_last= True,
 auto_insert_metric_name=False
)
early_stop_callback = EarlyStopping(
 monitor='val/MeanDiceScore',
 min_delta=0.0001,
 patience=15,
 verbose=False,
 mode='max'
)
tensorboardlogger = TensorBoardLogger(
 'logs', 
 name = "1", 
 default_hp_metric = None 
)
trainer = pl.Trainer(#fast_dev_run = 10, 
# accelerator='ddp',
 #overfit_batches=5,
 devices = [0], 
 precision=16,
 max_epochs = 200, 
 enable_progress_bar=True, 
 callbacks=[checkpoint_callback, early_stop_callback], 
# auto_lr_find=True,
 num_sanity_val_steps=2,
 logger = tensorboardlogger,
# limit_train_batches=0.01, 
# limit_val_batches=0.01
 )
# trainer.tune(model)
trainer.fit(model)





sh: 1: cls: not found


[H[2JTraining ...


/usr/local/lib/python3.9/site-packages/lightning_fabric/connector.py:563: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

 | Name | Type | Params
------------------------------------------
0 | model | SegTransVAE | 44.7 M
1 | loss_vae | Loss_VAE | 0 
2 | dice_loss | DiceLoss | 0 
------------------------------------------
44.7 M Trainable params
0 Non-trainable params
44.7 M Total params
178.908 Total estimated model params size (MB)


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:05<00:00, 0.37it/s]
 Current epoch: 0 Current mean dice: 0.0097 tc: 0.0029 wt: 0.0234 et: 0.0028
 Best mean dice: 0.009687595069408417 at epoch: 0
Epoch 0: 100%|██████████| 500/500 [05:38<00:00, 1.48it/s, v_num=6] 
 Current epoch: 0 Current mean dice: 0.1927 tc: 0.1647 wt: 0.2843 et: 0.1290
 Best mean dice: 0.1926589012145996 at epoch: 0
Epoch 1: 100%|██████████| 500/500 [07:35<00:00, 1.10it/s, v_num=6]
 Current epoch: 1 Current mean dice: 0.3212 tc: 0.2691 wt: 0.4253 et: 0.2692
 Best mean dice: 0.32120221853256226 at epoch: 1
Epoch 2: 100%|██████████| 500/500 [08:11<00:00, 1.02it/s, v_num=6]
 Current epoch: 2 Current mean dice: 0.3912 tc: 0.3510 wt: 0.5087 et: 0.3137
 Best mean dice: 0.39115065336227417 at epoch: 2
Epoch 3: 100%|██████████| 500/500 [08:58<00:00, 0.93it/s, v_num=6]
 Current epoch: 3 Current mean dice: 0.4268 tc: 0.3828 wt: 0.5424 et: 0.3553
 Best mean dice: 0.42682838439941406 at epoch: 3
Epoch 4: 41%|████▏ | 207/5

: 

In [None]:
import pytorch_lightning as pl
from trainer import BRATS
import os 
import torch
os.system('cls||clear')
print("Testing ...")

CKPT = ''
model = BRATS(use_VAE=True).load_from_checkpoint(CKPT).eval()
val_dataloader = get_val_dataloader()
test_dataloader = get_test_dataloader()
trainer = pl.Trainer(gpus = [0], precision=32, progress_bar_refresh_rate=10)

trainer.test(model, dataloaders = val_dataloader)
trainer.test(model, dataloaders = test_dataloader)

