{ "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MONAI version: 1.4.dev2409\n", "Numpy version: 1.26.2\n", "Pytorch version: 1.13.0+cu116\n", "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n", "MONAI rev id: 46c1b228091283fba829280a5d747f4237f76ed0\n", "MONAI __file__: /usr/local/lib/python3.9/site-packages/monai/__init__.py\n", "\n", "Optional dependencies:\n", "Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.\n", "ITK version: NOT INSTALLED or UNKNOWN VERSION.\n", "Nibabel version: 5.2.1\n", "scikit-image version: NOT INSTALLED or UNKNOWN VERSION.\n", "scipy version: 1.11.4\n", "Pillow version: 10.1.0\n", "Tensorboard version: 2.16.2\n", "gdown version: NOT INSTALLED or UNKNOWN VERSION.\n", "TorchVision version: 0.14.0+cu116\n", "tqdm version: 4.66.1\n", "lmdb version: NOT INSTALLED or UNKNOWN VERSION.\n", "psutil version: 5.9.8\n", "pandas version: 2.2.1\n", "einops version: 0.7.0\n", "transformers version: 4.35.2\n", "mlflow version: NOT INSTALLED or UNKNOWN VERSION.\n", "pynrrd version: NOT INSTALLED or UNKNOWN VERSION.\n", "clearml version: NOT INSTALLED or UNKNOWN VERSION.\n", "\n", "For details about installing the optional dependencies, please visit:\n", " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n", "\n" ] } ], "source": [ "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from monai.config import print_config\n", "from monai.losses import DiceLoss\n", "from monai.inferers import sliding_window_inference\n", "from monai.transforms import MapTransform\n", "from monai.data import DataLoader, Dataset\n", "from monai.utils import set_determinism\n", "from monai import transforms\n", "import torch\n", "\n", "print_config()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "set_determinism(seed=0)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Số lượng mẫu trong '/app/brats_2021_task1/BraTS2021_Training_Data' là: 1251\n" ] } ], "source": [ "import os\n", "\n", "parent_folder_path = '/app/brats_2021_task1/BraTS2021_Training_Data'\n", "subfolders = [f for f in os.listdir(parent_folder_path) if os.path.isdir(os.path.join(parent_folder_path, f))]\n", "num_folders = len(subfolders)\n", "print(f\"Số lượng mẫu trong '{parent_folder_path}' là: {num_folders}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "\n", "folder_data = []\n", "\n", "for fold_number in os.listdir(parent_folder_path):\n", " fold_path = os.path.join(parent_folder_path, fold_number)\n", "\n", " if os.path.isdir(fold_path):\n", " entry = {\"fold\": 0, \"image\": [], \"label\": \"\"}\n", "\n", " for file_type in ['flair', 't1ce', 't1', 't2']:\n", " file_name = f\"{fold_number}_{file_type}.nii.gz\"\n", " file_path = os.path.join(fold_path, file_name)\n", "\n", " if os.path.exists(file_path):\n", "\n", " entry[\"image\"].append(os.path.abspath(file_path))\n", "\n", " label_name = f\"{fold_number}_seg.nii.gz\"\n", " label_path = os.path.join(fold_path, label_name)\n", " if os.path.exists(label_path):\n", " entry[\"label\"] = os.path.abspath(label_path)\n", "\n", " folder_data.append(entry)\n", "\n", "\n", "json_data = {\"training\": folder_data}\n", "\n", "json_file_path = '/app/info.json'\n", "with open(json_file_path, 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "\n", "folder_data = []\n", "\n", "for fold_number in os.listdir(parent_folder_path):\n", "    fold_path = os.path.join(parent_folder_path, fold_number)\n", "\n", "    if os.path.isdir(fold_path):\n", "        entry = {\"fold\": 0, \"image\": [], \"label\": \"\"}\n", "\n", "        for file_type in ['flair', 't1ce', 't1', 't2']:\n", "            file_name = f\"{fold_number}_{file_type}.nii.gz\"\n", "            file_path = os.path.join(fold_path, file_name)\n", "\n", "            if os.path.exists(file_path):\n", "                entry[\"image\"].append(os.path.abspath(file_path))\n", "\n", "        label_name = f\"{fold_number}_seg.nii.gz\"\n", "        label_path = os.path.join(fold_path, label_name)\n", "        if os.path.exists(label_path):\n", "            entry[\"label\"] = os.path.abspath(label_path)\n", "\n", "        folder_data.append(entry)\n", "\n", "\n", "json_data = {\"training\": folder_data}\n", "\n", "json_file_path = '/app/info.json'\n", "with open(json_file_path, 'w') as json_file:\n", "    json.dump(json_data, json_file, indent=2)\n", "\n", "print(f\"Information has been written to {json_file_path}\")\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):\n", "    \"\"\"\n", "    Convert labels to multi channels based on BraTS classes:\n", "    label 1 is the necrotic and non-enhancing tumor core,\n", "    label 2 is the peritumoral edema,\n", "    label 4 is the GD-enhancing tumor.\n", "    The resulting channels are TC (tumor core), WT (whole tumor)\n", "    and ET (enhancing tumor).\n", "\n", "    \"\"\"\n", "\n", "    def __call__(self, data):\n", "        d = dict(data)\n", "        for key in self.keys:\n", "            result = []\n", "            # merge label 1 and label 4 to construct TC\n", "            result.append(np.logical_or(d[key] == 1, d[key] == 4))\n", "            # merge labels 1, 2 and 4 to construct WT\n", "            result.append(\n", "                np.logical_or(\n", "                    np.logical_or(d[key] == 1, d[key] == 4), d[key] == 2\n", "                )\n", "            )\n", "            # label 4 is ET\n", "            result.append(d[key] == 4)\n", "            d[key] = np.stack(result, axis=0).astype(np.float32)\n", "        return d" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def datafold_read(datalist, basedir, fold=0, key=\"training\"):\n", "    with open(datalist) as f:\n", "        json_data = json.load(f)\n", "\n", "    json_data = json_data[key]\n", "\n", "    for d in json_data:\n", "        for k in d:\n", "            if isinstance(d[k], list):\n", "                d[k] = [os.path.join(basedir, iv) for iv in d[k]]\n", "            elif isinstance(d[k], str):\n", "                d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]\n", "\n", "    tr = []\n", "    val = []\n", "    for d in json_data:\n", "        if \"fold\" in d and d[\"fold\"] == fold:\n", "            val.append(d)\n", "        else:\n", "            tr.append(d)\n", "\n", "    return tr, val" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def split_train_test(datalist, basedir, fold, test_size=0.2, volume: float = None):\n", "    train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=fold)\n", "    from sklearn.model_selection import train_test_split\n", "    if volume is not None:\n", "        # Optionally keep only a fraction of the data to shrink the experiment.\n", "        train_files, _ = train_test_split(train_files, test_size=volume, random_state=42)\n", "\n", "    train_files, validation_files = train_test_split(train_files, test_size=test_size, random_state=42)\n", "\n", "    validation_files, test_files = train_test_split(validation_files, test_size=test_size, random_state=42)\n", "    return train_files, validation_files, test_files" ] },
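{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative arithmetic for the nested split above, assuming 1251 subjects,\n", "# volume=0.5 and test_size=0.2: half the data is kept, 20% of the kept set\n", "# becomes validation+test, and 20% of that slice becomes the final test set.\n", "n_kept = 1251 // 2                 # volume=0.5  -> 625 subjects kept\n", "n_val_test = round(n_kept * 0.2)   # 125 subjects for validation + test\n", "n_test = round(n_val_test * 0.2)   # 25 subjects for test\n", "print(n_kept - n_val_test, n_val_test - n_test, n_test)  # 500 100 25" ] },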
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def get_loader(batch_size, data_dir, json_list, fold, roi, volume: float = None, test_size=0.2):\n", "    train_files, validation_files, test_files = split_train_test(datalist=json_list, basedir=data_dir, test_size=test_size, fold=fold, volume=volume)\n", "\n", "    train_transform = transforms.Compose(\n", "        [\n", "            transforms.LoadImaged(keys=[\"image\", \"label\"]),\n", "            transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n", "            transforms.CropForegroundd(\n", "                keys=[\"image\", \"label\"],\n", "                source_key=\"image\",\n", "                k_divisible=[roi[0], roi[1], roi[2]],\n", "            ),\n", "            transforms.RandSpatialCropd(\n", "                keys=[\"image\", \"label\"],\n", "                roi_size=[roi[0], roi[1], roi[2]],\n", "                random_size=False,\n", "            ),\n", "            transforms.RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=0),\n", "            transforms.RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=1),\n", "            transforms.RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=2),\n", "            transforms.NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n", "            transforms.RandScaleIntensityd(keys=\"image\", factors=0.1, prob=1.0),\n", "            transforms.RandShiftIntensityd(keys=\"image\", offsets=0.1, prob=1.0),\n", "        ]\n", "    )\n", "    val_transform = transforms.Compose(\n", "        [\n", "            transforms.LoadImaged(keys=[\"image\", \"label\"]),\n", "            transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n", "            transforms.NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n", "        ]\n", "    )\n", "\n", "    train_ds = Dataset(data=train_files, transform=train_transform)\n", "    train_loader = DataLoader(\n", "        train_ds,\n", "        batch_size=batch_size,\n", "        shuffle=True,\n", "        num_workers=2,\n", "        pin_memory=True,\n", "    )\n", "    val_ds = Dataset(data=validation_files, transform=val_transform)\n", "    val_loader = DataLoader(\n", "        val_ds,\n", "        batch_size=1,\n", "        shuffle=False,\n", "        num_workers=2,\n", "        pin_memory=True,\n", "    )\n", "    test_ds = Dataset(data=test_files, transform=val_transform)\n", "    test_loader = DataLoader(\n", "        test_ds,\n", "        batch_size=1,\n", "        shuffle=False,\n", "        num_workers=2,\n", "        pin_memory=True,\n", "    )\n", "    return train_loader, val_loader, test_loader" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.9/site-packages/monai/utils/deprecate_utils.py:321: FutureWarning: monai.transforms.croppad.dictionary CropForegroundd.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.\n", "  warn_deprecated(argname, msg, warning_category)\n" ] } ], "source": [ "import json\n", "data_dir = \"/app/brats_2021_task1\"\n", "json_list = \"/app/info.json\"\n", "roi = (128, 128, 128)\n", "batch_size = 1\n", "sw_batch_size = 2\n", "fold = 1\n", "infer_overlap = 0.5\n", "max_epochs = 100\n", "val_every = 10\n", "train_loader, val_loader, test_loader = get_loader(batch_size, data_dir, json_list, fold, roi, volume=0.5, test_size=0.2)" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "100" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(val_loader)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Model design, based on SegResNet, VAE and TransBTS" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "# Re-used by the encoder and decoder blocks\n", "def normalization(planes, norm='instance'):\n", "    if norm == 'bn':\n", "        m = nn.BatchNorm3d(planes)\n", "    elif norm == 'gn':\n", "        m = nn.GroupNorm(8, planes)\n", "    elif norm == 'instance':\n", "        m = nn.InstanceNorm3d(planes)\n", "    else:\n", "        raise ValueError(\"Unsupported normalization type.\")\n", "    return m\n", "\n", "class ResNetBlock(nn.Module):\n", "    def __init__(self, in_channels, norm='instance'):\n", "        super().__init__()\n", "        self.resnetblock = nn.Sequential(\n", "            normalization(in_channels, norm=norm),\n", "            nn.LeakyReLU(0.2, inplace=True),\n", "            nn.Conv3d(in_channels, 
in_channels, kernel_size = 3, padding = 1),\n", " normalization(in_channels, norm = norm),\n", " nn.LeakyReLU(0.2, inplace=True),\n", " nn.Conv3d(in_channels, in_channels, kernel_size = 3, padding = 1)\n", " )\n", "\n", " def forward(self, x):\n", " y = self.resnetblock(x)\n", " return y + x" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "\n", "\n", "from torch.nn import functional as F\n", "\n", "def calculate_total_dimension(a):\n", " res = 1\n", " for x in a:\n", " res *= x\n", " return res\n", "\n", "class VAE(nn.Module):\n", " def __init__(self, input_shape, latent_dim, num_channels):\n", " super().__init__()\n", " self.input_shape = input_shape\n", " self.in_channels = input_shape[1] #input_shape[0] is batch size\n", " self.latent_dim = latent_dim\n", " self.encoder_channels = self.in_channels // 16\n", "\n", " #Encoder\n", " self.VAE_reshape = nn.Conv3d(self.in_channels, self.encoder_channels,\n", " kernel_size = 3, stride = 2, padding=1)\n", " # self.VAE_reshape = nn.Sequential(\n", " # nn.GroupNorm(8, self.in_channels),\n", " # nn.ReLU(),\n", " # nn.Conv3d(self.in_channels, self.encoder_channels,\n", " # kernel_size = 3, stride = 2, padding=1),\n", " # )\n", "\n", " flatten_input_shape = calculate_total_dimension(input_shape)\n", " flatten_input_shape_after_vae_reshape = \\\n", " flatten_input_shape * self.encoder_channels // (8 * self.in_channels)\n", "\n", " #Convert from total dimension to latent space\n", " self.to_latent_space = nn.Linear(\n", " flatten_input_shape_after_vae_reshape // self.in_channels, 1)\n", "\n", " self.mean = nn.Linear(self.in_channels, self.latent_dim)\n", " self.logvar = nn.Linear(self.in_channels, self.latent_dim)\n", "# self.epsilon = nn.Parameter(torch.randn(1, latent_dim))\n", "\n", " #Decoder\n", " self.to_original_dimension = nn.Linear(self.latent_dim, flatten_input_shape_after_vae_reshape)\n", " self.Reconstruct = nn.Sequential(\n", " nn.LeakyReLU(0.2, inplace=True),\n", " nn.Conv3d(\n", " self.encoder_channels, self.in_channels,\n", " stride = 1, kernel_size = 1),\n", " nn.Upsample(scale_factor=2, mode = 'nearest'),\n", "\n", " nn.Conv3d(\n", " self.in_channels, self.in_channels // 2,\n", " stride = 1, kernel_size = 1),\n", " nn.Upsample(scale_factor=2, mode = 'nearest'),\n", " ResNetBlock(self.in_channels // 2),\n", "\n", " nn.Conv3d(\n", " self.in_channels // 2, self.in_channels // 4,\n", " stride = 1, kernel_size = 1),\n", " nn.Upsample(scale_factor=2, mode = 'nearest'),\n", " ResNetBlock(self.in_channels // 4),\n", "\n", " nn.Conv3d(\n", " self.in_channels // 4, self.in_channels // 8,\n", " stride = 1, kernel_size = 1),\n", " nn.Upsample(scale_factor=2, mode = 'nearest'),\n", " ResNetBlock(self.in_channels // 8),\n", "\n", " nn.InstanceNorm3d(self.in_channels // 8),\n", " nn.LeakyReLU(0.2, inplace=True),\n", " nn.Conv3d(\n", " self.in_channels // 8, num_channels,\n", " kernel_size = 3, padding = 1),\n", "# nn.Sigmoid()\n", " )\n", "\n", "\n", " def forward(self, x): #x has shape = input_shape\n", " #Encoder\n", " # print(x.shape)\n", " x = self.VAE_reshape(x)\n", " shape = x.shape\n", "\n", " x = x.view(self.in_channels, -1)\n", " x = self.to_latent_space(x)\n", " x = x.view(1, self.in_channels)\n", "\n", " mean = self.mean(x)\n", " logvar = self.logvar(x)\n", "# sigma = torch.exp(0.5 * logvar)\n", " # Reparameter\n", " epsilon = torch.randn_like(logvar)\n", " sample = mean + epsilon * torch.exp(0.5*logvar)\n", "\n", " #Decoder\n", " y = self.to_original_dimension(sample)\n", " y = 
y.view(*shape)\n", " return self.Reconstruct(y), mean, logvar\n", " def total_params(self):\n", " total = sum(p.numel() for p in self.parameters())\n", " return format(total, ',')\n", "\n", " def total_trainable_params(self):\n", " total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)\n", " return format(total_trainable, ',')\n", "\n", "\n", "# x = torch.rand((1, 256, 16, 16, 16))\n", "# vae = VAE(input_shape = x.shape, latent_dim = 256, num_channels = 4)\n", "# y = vae(x)\n", "# print(y[0].shape, y[1].shape, y[2].shape)\n", "# print(vae.total_trainable_params())\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "\n", "from einops import rearrange\n", "from einops.layers.torch import Rearrange\n", "\n", "def pair(t):\n", " return t if isinstance(t, tuple) else (t, t)\n", "\n", "\n", "class PreNorm(nn.Module):\n", " def __init__(self, dim, function):\n", " super().__init__()\n", " self.norm = nn.LayerNorm(dim)\n", " self.function = function\n", "\n", " def forward(self, x):\n", " return self.function(self.norm(x))\n", "\n", "\n", "class FeedForward(nn.Module):\n", " def __init__(self, dim, hidden_dim, dropout = 0.0):\n", " super().__init__()\n", " self.net = nn.Sequential(\n", " nn.Linear(dim, hidden_dim),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(hidden_dim, dim),\n", " nn.Dropout(dropout)\n", " )\n", "\n", " def forward(self, x):\n", " return self.net(x)\n", "\n", "class Attention(nn.Module):\n", " def __init__(self, dim, heads, dim_head, dropout = 0.0):\n", " super().__init__()\n", " all_head_size = heads * dim_head\n", " project_out = not (heads == 1 and dim_head == dim)\n", "\n", " self.heads = heads\n", " self.scale = dim_head ** -0.5\n", "\n", " self.softmax = nn.Softmax(dim = -1)\n", " self.to_qkv = nn.Linear(dim, all_head_size * 3, bias = False)\n", "\n", " self.to_out = nn.Sequential(\n", " nn.Linear(all_head_size, dim),\n", " nn.Dropout(dropout)\n", " ) if project_out else nn.Identity()\n", "\n", " def forward(self, x):\n", " qkv = self.to_qkv(x).chunk(3, dim = -1)\n", " #(batch, heads * dim_head) -> (batch, all_head_size)\n", " q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)\n", "\n", " dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale\n", "\n", " atten = self.softmax(dots)\n", "\n", " out = torch.matmul(atten, v)\n", " out = rearrange(out, 'b h n d -> b n (h d)')\n", " return self.to_out(out)\n", "\n", "class Transformer(nn.Module):\n", " def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.0):\n", " super().__init__()\n", " self.layers = nn.ModuleList([])\n", " for _ in range(depth):\n", " self.layers.append(nn.ModuleList([\n", " PreNorm(dim, Attention(dim, heads, dim_head, dropout)),\n", " PreNorm(dim, FeedForward(dim, mlp_dim, dropout))\n", " ]))\n", " def forward(self, x):\n", " for attention, feedforward in self.layers:\n", " x = attention(x) + x\n", " x = feedforward(x) + x\n", " return x\n", "\n", "class FixedPositionalEncoding(nn.Module):\n", " def __init__(self, embedding_dim, max_length=768):\n", " super(FixedPositionalEncoding, self).__init__()\n", "\n", " pe = torch.zeros(max_length, embedding_dim)\n", " position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)\n", " div_term = torch.exp(\n", " torch.arange(0, embedding_dim, 2).float()\n", " * (-torch.log(torch.tensor(10000.0)) / embedding_dim)\n", " )\n", " pe[:, 0::2] = torch.sin(position * div_term)\n", " pe[:, 
1::2] = torch.cos(position * div_term)\n", "        pe = pe.unsqueeze(0).transpose(0, 1)\n", "        self.register_buffer('pe', pe)\n", "\n", "    def forward(self, x):\n", "        x = x + self.pe[: x.size(0), :]\n", "        return x\n", "\n", "\n", "class LearnedPositionalEncoding(nn.Module):\n", "    def __init__(self, embedding_dim, seq_length):\n", "        super(LearnedPositionalEncoding, self).__init__()\n", "        self.seq_length = seq_length\n", "        self.position_embeddings = nn.Parameter(torch.zeros(1, seq_length, embedding_dim))\n", "\n", "    def forward(self, x, position_ids=None):\n", "        position_embeddings = self.position_embeddings\n", "        return x + position_embeddings" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "### Encoder ####\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "class InitConv(nn.Module):\n", "    def __init__(self, in_channels=4, out_channels=16, dropout=0.2):\n", "        super().__init__()\n", "        self.layer = nn.Sequential(\n", "            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),\n", "            nn.Dropout3d(dropout)\n", "        )\n", "    def forward(self, x):\n", "        y = self.layer(x)\n", "        return y\n", "\n", "\n", "class DownSample(nn.Module):\n", "    def __init__(self, in_channels, out_channels):\n", "        super().__init__()\n", "        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)\n", "    def forward(self, x):\n", "        return self.conv(x)\n", "\n", "class Encoder(nn.Module):\n", "    def __init__(self, in_channels, base_channels, dropout=0.2):\n", "        super().__init__()\n", "\n", "        self.init_conv = InitConv(in_channels, base_channels, dropout=dropout)\n", "        self.encoder_block1 = ResNetBlock(in_channels=base_channels)\n", "        self.encoder_down1 = DownSample(base_channels, base_channels * 2)\n", "\n", "        self.encoder_block2_1 = ResNetBlock(base_channels * 2)\n", "        self.encoder_block2_2 = ResNetBlock(base_channels * 2)\n", "        self.encoder_down2 = DownSample(base_channels * 2, base_channels * 4)\n", "\n", "        self.encoder_block3_1 = ResNetBlock(base_channels * 4)\n", "        self.encoder_block3_2 = ResNetBlock(base_channels * 4)\n", "        self.encoder_down3 = DownSample(base_channels * 4, base_channels * 8)\n", "\n", "        self.encoder_block4_1 = ResNetBlock(base_channels * 8)\n", "        self.encoder_block4_2 = ResNetBlock(base_channels * 8)\n", "        self.encoder_block4_3 = ResNetBlock(base_channels * 8)\n", "        self.encoder_block4_4 = ResNetBlock(base_channels * 8)\n", "\n", "    def forward(self, x):\n", "        # Shape comments assume input (1, 4, 128, 128, 128) and base_channels = 16.\n", "        x = self.init_conv(x)  # (1, 16, 128, 128, 128)\n", "\n", "        x1 = self.encoder_block1(x)\n", "        x1_down = self.encoder_down1(x1)  # (1, 32, 64, 64, 64)\n", "\n", "        x2 = self.encoder_block2_2(self.encoder_block2_1(x1_down))\n", "        x2_down = self.encoder_down2(x2)  # (1, 64, 32, 32, 32)\n", "\n", "        x3 = self.encoder_block3_2(self.encoder_block3_1(x2_down))\n", "        x3_down = self.encoder_down3(x3)  # (1, 128, 16, 16, 16)\n", "\n", "        output = self.encoder_block4_4(\n", "            self.encoder_block4_3(\n", "                self.encoder_block4_2(\n", "                    self.encoder_block4_1(x3_down))))  # (1, 128, 16, 16, 16)\n", "        return x1, x2, x3, output\n", "\n", "# x = torch.rand((1, 4, 128, 128, 128))\n", "# Enc = Encoder(4, 32)\n", "# _, _, _, y = Enc(x)\n", "# print(y.shape)  # (1, 256, 16, 16, 16) with base_channels = 32" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "### Decoder ####\n", "\n", "import torch\n", "import torch.nn as nn\n",
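"\n", "# Decoder design: three upsampling stages mirror the encoder. Each Upsample\n", "# block is a 1x1 conv (halving channels), a stride-2 ConvTranspose3d (doubling\n", "# spatial size), concatenation with the matching encoder skip feature, and a\n", "# final 1x1 conv that fuses the concatenated channels back down.\n",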
"\n", "class Upsample(nn.Module):\n", " def __init__(self, in_channel, out_channel):\n", " super().__init__()\n", " self.conv1 = nn.Conv3d(in_channel, out_channel, kernel_size = 1)\n", " self.deconv = nn.ConvTranspose3d(out_channel, out_channel, kernel_size = 2, stride = 2)\n", " self.conv2 = nn.Conv3d(out_channel * 2, out_channel, kernel_size = 1)\n", "\n", " def forward(self, prev, x):\n", " x = self.deconv(self.conv1(x))\n", " y = torch.cat((prev, x), dim = 1)\n", " return self.conv2(y)\n", "\n", "class FinalConv(nn.Module): # Input channels are equal to output channels\n", " def __init__(self, in_channels, out_channels=32, norm=\"instance\"):\n", " super(FinalConv, self).__init__()\n", " if norm == \"batch\":\n", " norm_layer = nn.BatchNorm3d(num_features=in_channels)\n", " elif norm == \"group\":\n", " norm_layer = nn.GroupNorm(num_groups=8, num_channels=in_channels)\n", " elif norm == 'instance':\n", " norm_layer = nn.InstanceNorm3d(in_channels)\n", "\n", " self.layer = nn.Sequential(\n", " norm_layer,\n", " nn.LeakyReLU(0.2, inplace=True),\n", " nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)\n", " )\n", " def forward(self, x):\n", " return self.layer(x)\n", "\n", "class Decoder(nn.Module):\n", " def __init__(self, img_dim, patch_dim, embedding_dim, num_classes = 3):\n", " super().__init__()\n", " self.img_dim = img_dim\n", " self.patch_dim = patch_dim\n", " self.embedding_dim = embedding_dim\n", "\n", " self.decoder_upsample_1 = Upsample(128, 64)\n", " self.decoder_block_1 = ResNetBlock(64)\n", "\n", " self.decoder_upsample_2 = Upsample(64, 32)\n", " self.decoder_block_2 = ResNetBlock(32)\n", "\n", " self.decoder_upsample_3 = Upsample(32, 16)\n", " self.decoder_block_3 = ResNetBlock(16)\n", "\n", " self.endconv = FinalConv(16, num_classes)\n", " # self.normalize = nn.Sigmoid()\n", "\n", " def forward(self, x1, x2, x3, x):\n", " x = self.decoder_upsample_1(x3, x)\n", " x = self.decoder_block_1(x)\n", "\n", " x = self.decoder_upsample_2(x2, x)\n", " x = self.decoder_block_2(x)\n", "\n", " x = self.decoder_upsample_3(x1, x)\n", " x = self.decoder_block_3(x)\n", "\n", " y = self.endconv(x)\n", " return y" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "class FeatureMapping(nn.Module):\n", " def __init__(self, in_channel, out_channel, norm = 'instance'):\n", " super().__init__()\n", " if norm == 'bn':\n", " norm_layer_1 = nn.BatchNorm3d(out_channel)\n", " norm_layer_2 = nn.BatchNorm3d(out_channel)\n", " elif norm == 'gn':\n", " norm_layer_1 = nn.GroupNorm(8, out_channel)\n", " norm_layer_2 = nn.GroupNorm(8, out_channel)\n", " elif norm == 'instance':\n", " norm_layer_1 = nn.InstanceNorm3d(out_channel)\n", " norm_layer_2 = nn.InstanceNorm3d(out_channel)\n", " self.feature_mapping = nn.Sequential(\n", " nn.Conv3d(in_channel, out_channel, kernel_size = 3, padding = 1),\n", " norm_layer_1,\n", " nn.LeakyReLU(0.2, inplace=True),\n", " nn.Conv3d(out_channel, out_channel, kernel_size = 3, padding = 1),\n", " norm_layer_2,\n", " nn.LeakyReLU(0.2, inplace=True)\n", " )\n", "\n", " def forward(self, x):\n", " return self.feature_mapping(x)\n", "\n", "\n", "class FeatureMapping1(nn.Module):\n", " def __init__(self, in_channel, norm = 'instance'):\n", " super().__init__()\n", " if norm == 'bn':\n", " norm_layer_1 = nn.BatchNorm3d(in_channel)\n", " norm_layer_2 = nn.BatchNorm3d(in_channel)\n", " elif norm == 'gn':\n", " norm_layer_1 = nn.GroupNorm(8, in_channel)\n", " norm_layer_2 = nn.GroupNorm(8, 
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "class FeatureMapping(nn.Module):\n", "    def __init__(self, in_channel, out_channel, norm='instance'):\n", "        super().__init__()\n", "        if norm == 'bn':\n", "            norm_layer_1 = nn.BatchNorm3d(out_channel)\n", "            norm_layer_2 = nn.BatchNorm3d(out_channel)\n", "        elif norm == 'gn':\n", "            norm_layer_1 = nn.GroupNorm(8, out_channel)\n", "            norm_layer_2 = nn.GroupNorm(8, out_channel)\n", "        elif norm == 'instance':\n", "            norm_layer_1 = nn.InstanceNorm3d(out_channel)\n", "            norm_layer_2 = nn.InstanceNorm3d(out_channel)\n", "        else:\n", "            raise ValueError(\"Unsupported normalization type.\")\n", "        self.feature_mapping = nn.Sequential(\n", "            nn.Conv3d(in_channel, out_channel, kernel_size=3, padding=1),\n", "            norm_layer_1,\n", "            nn.LeakyReLU(0.2, inplace=True),\n", "            nn.Conv3d(out_channel, out_channel, kernel_size=3, padding=1),\n", "            norm_layer_2,\n", "            nn.LeakyReLU(0.2, inplace=True)\n", "        )\n", "\n", "    def forward(self, x):\n", "        return self.feature_mapping(x)\n", "\n", "\n", "class FeatureMapping1(nn.Module):\n", "    def __init__(self, in_channel, norm='instance'):\n", "        super().__init__()\n", "        if norm == 'bn':\n", "            norm_layer_1 = nn.BatchNorm3d(in_channel)\n", "            norm_layer_2 = nn.BatchNorm3d(in_channel)\n", "        elif norm == 'gn':\n", "            norm_layer_1 = nn.GroupNorm(8, in_channel)\n", "            norm_layer_2 = nn.GroupNorm(8, in_channel)\n", "        elif norm == 'instance':\n", "            norm_layer_1 = nn.InstanceNorm3d(in_channel)\n", "            norm_layer_2 = nn.InstanceNorm3d(in_channel)\n", "        else:\n", "            raise ValueError(\"Unsupported normalization type.\")\n", "        self.feature_mapping1 = nn.Sequential(\n", "            nn.Conv3d(in_channel, in_channel, kernel_size=3, padding=1),\n", "            norm_layer_1,\n", "            nn.LeakyReLU(0.2, inplace=True),\n", "            nn.Conv3d(in_channel, in_channel, kernel_size=3, padding=1),\n", "            norm_layer_2,\n", "            nn.LeakyReLU(0.2, inplace=True)\n", "        )\n", "    def forward(self, x):\n", "        y = self.feature_mapping1(x)\n", "        return x + y  # ResNet-like residual connection" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "\n", "class SegTransVAE(nn.Module):\n", "    def __init__(self, img_dim, patch_dim, num_channels, num_classes,\n", "                 embedding_dim, num_heads, num_layers, hidden_dim, in_channels_vae,\n", "                 dropout=0.0, attention_dropout=0.0,\n", "                 conv_patch_representation=True, positional_encoding='learned',\n", "                 use_VAE=False):\n", "\n", "        super().__init__()\n", "        assert embedding_dim % num_heads == 0\n", "        assert img_dim[0] % patch_dim == 0 and img_dim[1] % patch_dim == 0 and img_dim[2] % patch_dim == 0\n", "\n", "        self.img_dim = img_dim\n", "        self.embedding_dim = embedding_dim\n", "        self.num_heads = num_heads\n", "        self.num_classes = num_classes\n", "        self.patch_dim = patch_dim\n", "        self.num_channels = num_channels\n", "        self.in_channels_vae = in_channels_vae\n", "        self.dropout = dropout\n", "        self.attention_dropout = attention_dropout\n", "        self.conv_patch_representation = conv_patch_representation\n", "        self.use_VAE = use_VAE\n", "\n", "        self.num_patches = int((img_dim[0] // patch_dim) * (img_dim[1] // patch_dim) * (img_dim[2] // patch_dim))\n", "        self.seq_length = self.num_patches\n", "        self.flatten_dim = 128 * num_channels\n", "\n", "        self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)\n", "        if positional_encoding == \"learned\":\n", "            self.position_encoding = LearnedPositionalEncoding(\n", "                self.embedding_dim, self.seq_length\n", "            )\n", "        elif positional_encoding == \"fixed\":\n", "            self.position_encoding = FixedPositionalEncoding(\n", "                self.embedding_dim,\n", "            )\n", "        self.pe_dropout = nn.Dropout(self.dropout)\n", "\n", "        self.transformer = Transformer(\n", "            embedding_dim, num_layers, num_heads, embedding_dim // num_heads, hidden_dim, dropout\n", "        )\n", "        self.pre_head_ln = nn.LayerNorm(embedding_dim)\n", "\n", "        if self.conv_patch_representation:\n", "            self.conv_x = nn.Conv3d(128, self.embedding_dim, kernel_size=3, stride=1, padding=1)\n", "        self.encoder = Encoder(self.num_channels, 16)\n", "        self.bn = nn.InstanceNorm3d(128)\n", "        self.relu = nn.LeakyReLU(0.2, inplace=True)\n", "        self.FeatureMapping = FeatureMapping(in_channel=self.embedding_dim, out_channel=self.in_channels_vae)\n", "        self.FeatureMapping1 = FeatureMapping1(in_channel=self.in_channels_vae)\n", "        self.decoder = Decoder(self.img_dim, self.patch_dim, self.embedding_dim, num_classes)\n", "\n", "        self.vae_input = (1, self.in_channels_vae, img_dim[0] // 8, img_dim[1] // 8, img_dim[2] // 8)\n", "        if use_VAE:\n", "            self.vae = VAE(input_shape=self.vae_input, latent_dim=256, num_channels=self.num_channels)\n", "\n", "    def encode(self, x):\n", "        if self.conv_patch_representation:\n", "            x1, x2, x3, x = self.encoder(x)\n", "            x = self.bn(x)\n", "            x = self.relu(x)\n", "            x = self.conv_x(x)\n", "            x = x.permute(0, 2, 3, 4, 1).contiguous()\n", "            x = x.view(x.size(0), -1, self.embedding_dim)\n", "        x = self.position_encoding(x)\n", "        x = self.pe_dropout(x)\n", "        x = 
self.transformer(x)\n", "        x = self.pre_head_ln(x)\n", "\n", "        return x1, x2, x3, x\n", "\n", "    def decode(self, x1, x2, x3, x):\n", "        # x: (1, 4096, 512) -> reshaped to a 3D volume by forward() before decoding\n", "        return self.decoder(x1, x2, x3, x)\n", "\n", "    def forward(self, x, is_validation=True):\n", "        x1, x2, x3, x = self.encode(x)\n", "        x = x.view(x.size(0),\n", "                   self.img_dim[0] // self.patch_dim,\n", "                   self.img_dim[1] // self.patch_dim,\n", "                   self.img_dim[2] // self.patch_dim,\n", "                   self.embedding_dim)\n", "        x = x.permute(0, 4, 1, 2, 3).contiguous()\n", "        x = self.FeatureMapping(x)\n", "        x = self.FeatureMapping1(x)\n", "        if self.use_VAE and not is_validation:\n", "            vae_out, mu, sigma = self.vae(x)\n", "        y = self.decode(x1, x2, x3, x)\n", "        if self.use_VAE and not is_validation:\n", "            return y, vae_out, mu, sigma\n", "        else:\n", "            return y" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CUDA (GPU) is available. Using GPU.\n" ] } ], "source": [ "import torch\n", "\n", "# Check if CUDA (GPU support) is available\n", "if torch.cuda.is_available():\n", "    device = torch.device(\"cuda:0\")\n", "    print(\"CUDA (GPU) is available. Using GPU.\")\n", "else:\n", "    device = torch.device(\"cpu\")\n", "    print(\"CUDA (GPU) is not available. Using CPU.\")" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "model = SegTransVAE(img_dim=(128, 128, 128), patch_dim=8, num_channels=4, num_classes=3, embedding_dim=768, num_heads=8, num_layers=4, hidden_dim=3072, in_channels_vae=128, use_VAE=True)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of model parameters: 44727120\n", "Total number of trainable parameters: 44727120\n" ] } ], "source": [ "total_params = sum(p.numel() for p in model.parameters())\n", "print(f'Total number of model parameters: {total_params}')\n", "\n", "total_params_requires_grad = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "print(f'Total number of trainable parameters: {total_params_requires_grad}')\n" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "class Loss_VAE(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.mse = nn.MSELoss(reduction='sum')\n", "\n", "    def forward(self, recon_x, x, mu, log_var):\n", "        mse = self.mse(recon_x, x)\n", "        kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())\n", "        loss = mse + kld\n", "        return loss" ] },
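{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative check, not used in training: the KL term in Loss_VAE is the\n", "# closed-form KL divergence between N(mu, sigma^2) and N(0, 1); it vanishes\n", "# when mu = 0 and log_var = 0.\n", "mu = torch.zeros(1, 4)\n", "log_var = torch.zeros(1, 4)\n", "print(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()))  # tensor(0.)" ] },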
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "def DiceScore(\n", "    y_pred: torch.Tensor,\n", "    y: torch.Tensor,\n", "    include_background: bool = True,\n", ") -> torch.Tensor:\n", "    \"\"\"Computes the Dice score from full-size tensors and collects the average.\n", "    Args:\n", "        y_pred: input data to compute, typically segmentation model output.\n", "            It must be in one-hot format with batch first, example shape: [16, 3, 32, 32].\n", "            The values should be binarized.\n", "        y: ground truth to compute mean Dice metric. It must be in one-hot format\n", "            with batch first. The values should be binarized.\n", "        include_background: whether to include the first channel of the predicted\n", "            output in the Dice computation. Defaults to True.\n", "    Returns:\n", "        Dice scores per batch and per class (shape [batch_size, num_classes]).\n", "    Raises:\n", "        ValueError: when `y_pred` and `y` have different shapes.\n", "    \"\"\"\n", "\n", "    y = y.float()\n", "    y_pred = y_pred.float()\n", "\n", "    if y.shape != y_pred.shape:\n", "        raise ValueError(\"y_pred and y should have same shapes.\")\n", "\n", "    # reducing only spatial dimensions (not batch nor channels)\n", "    n_len = len(y_pred.shape)\n", "    reduce_axis = list(range(2, n_len))\n", "    intersection = torch.sum(y * y_pred, dim=reduce_axis)\n", "\n", "    y_o = torch.sum(y, reduce_axis)\n", "    y_pred_o = torch.sum(y_pred, dim=reduce_axis)\n", "    denominator = y_o + y_pred_o\n", "\n", "    # An empty prediction matching an empty ground truth counts as a perfect 1.\n", "    return torch.where(\n", "        denominator > 0,\n", "        (2.0 * intersection) / denominator,\n", "        torch.tensor(1.0, device=y_o.device),\n", "    )\n" ] },
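{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative sanity check, not part of the training pipeline: identical\n", "# binary masks give Dice 1, disjoint masks give Dice 0.\n", "a = torch.zeros(1, 1, 4, 4, 4)\n", "a[..., :2] = 1\n", "print(DiceScore(a, a))      # tensor([[1.]])\n", "print(DiceScore(a, 1 - a))  # tensor([[0.]])" ] },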
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "# PyTorch Lightning\n", "import pytorch_lightning as pl\n", "import matplotlib.pyplot as plt\n", "import csv\n", "from monai.transforms import AsDiscrete, Activations, Compose, EnsureType" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "class BRATS(pl.LightningModule):\n", "    def __init__(self, use_VAE=True, lr=1e-4):\n", "        super().__init__()\n", "\n", "        self.use_vae = use_VAE\n", "        self.lr = lr\n", "        self.model = SegTransVAE((128, 128, 128), 8, 4, 3, 768, 8, 4, 3072, in_channels_vae=128, use_VAE=use_VAE)\n", "\n", "        self.loss_vae = Loss_VAE()\n", "        self.dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True)\n", "        self.post_trans_images = Compose(\n", "            [EnsureType(),\n", "             Activations(sigmoid=True),\n", "             AsDiscrete(threshold=0.5),\n", "            ]\n", "        )\n", "\n", "        self.best_val_dice = 0\n", "\n", "        self.training_step_outputs = []\n", "        self.val_step_loss = []\n", "        self.val_step_dice = []\n", "        self.val_step_dice_tc = []\n", "        self.val_step_dice_wt = []\n", "        self.val_step_dice_et = []\n", "        self.test_step_loss = []\n", "        self.test_step_dice = []\n", "        self.test_step_dice_tc = []\n", "        self.test_step_dice_wt = []\n", "        self.test_step_dice_et = []\n", "\n", "    def forward(self, x, is_validation=True):\n", "        return self.model(x, is_validation)\n", "\n", "    def training_step(self, batch, batch_index):\n", "        inputs, labels = (batch['image'], batch['label'])\n", "\n", "        if not self.use_vae:\n", "            outputs = self.forward(inputs, is_validation=False)\n", "            loss = self.dice_loss(outputs, labels)\n", "        else:\n", "            outputs, recon_batch, mu, sigma = self.forward(inputs, is_validation=False)\n", "\n", "            vae_loss = self.loss_vae(recon_batch, inputs, mu, sigma)\n", "            dice_loss = self.dice_loss(outputs, labels)\n", "            # Down-weight the VAE loss by the number of reconstructed voxels.\n", "            loss = dice_loss + 1 / (4 * 128 * 128 * 128) * vae_loss\n", "            self.log('train/vae_loss', vae_loss)\n", "            self.log('train/dice_loss', dice_loss)\n", "        self.training_step_outputs.append(loss)\n", "        # The reconstruction only exists when the VAE branch is active.\n", "        if self.use_vae and batch_index == 10:\n", "            tensorboard = self.logger.experiment\n", "            fig, ax = plt.subplots(nrows=1, ncols=6, figsize=(10, 5))\n", "\n", "            ax[0].imshow(inputs.detach().cpu()[0][0][:, :, 80], cmap='gray')\n", "            ax[0].set_title(\"Input\")\n", "\n", "            ax[1].imshow(recon_batch.detach().cpu().float()[0][0][:, :, 80], cmap='gray')\n", "            ax[1].set_title(\"Reconstruction\")\n", "\n", "            ax[2].imshow(labels.detach().cpu().float()[0][0][:, :, 80], cmap='gray')\n", "            ax[2].set_title(\"Labels TC\")\n", "\n", "            ax[3].imshow(outputs.sigmoid().detach().cpu().float()[0][0][:, :, 80], cmap='gray')\n", "            ax[3].set_title(\"TC\")\n", "\n", "            ax[4].imshow(labels.detach().cpu().float()[0][2][:, :, 80], cmap='gray')\n", "            ax[4].set_title(\"Labels ET\")\n", "\n", "            ax[5].imshow(outputs.sigmoid().detach().cpu().float()[0][2][:, :, 80], cmap='gray')\n", "            ax[5].set_title(\"ET\")\n", "\n", "            tensorboard.add_figure('train_visualize', fig, self.current_epoch)\n", "\n", "        self.log('train/loss', loss)\n", "\n", "        return loss\n", "\n", "    def on_train_epoch_end(self):\n", "        # Average the per-step losses accumulated over the epoch, then free them.\n", "        epoch_average = torch.stack(self.training_step_outputs).mean()\n", "        self.log(\"training_epoch_average\", epoch_average)\n", "        self.training_step_outputs.clear()  # free memory\n", "\n", "    def validation_step(self, batch, batch_index):\n", "        inputs, labels = (batch['image'], batch['label'])\n", "        roi_size = (128, 128, 128)\n", "        sw_batch_size = 1\n", "        outputs = sliding_window_inference(\n", "            inputs, roi_size, sw_batch_size, self.model, overlap=0.5)\n", "        loss = self.dice_loss(outputs, labels)\n", "\n", "        val_outputs = self.post_trans_images(outputs)\n", "\n", "        metric_tc = DiceScore(y_pred=val_outputs[:, 0:1], y=labels[:, 0:1], include_background=True)\n", "        metric_wt = DiceScore(y_pred=val_outputs[:, 1:2], y=labels[:, 1:2], include_background=True)\n", "        metric_et = DiceScore(y_pred=val_outputs[:, 2:3], y=labels[:, 2:3], include_background=True)\n", "        mean_val_dice = (metric_tc + metric_wt + metric_et) / 3\n", "        self.val_step_loss.append(loss)\n", "        self.val_step_dice.append(mean_val_dice)\n", "        self.val_step_dice_tc.append(metric_tc)\n", "        self.val_step_dice_wt.append(metric_wt)\n", "        self.val_step_dice_et.append(metric_et)\n", "        return {'val_loss': loss, 'val_mean_dice': mean_val_dice, 'val_dice_tc': metric_tc,\n", "                'val_dice_wt': metric_wt, 'val_dice_et': metric_et}\n", "\n", "    def on_validation_epoch_end(self):\n", "        loss = torch.stack(self.val_step_loss).mean()\n", "        mean_val_dice = torch.stack(self.val_step_dice).mean()\n", "        metric_tc = torch.stack(self.val_step_dice_tc).mean()\n", "        metric_wt = torch.stack(self.val_step_dice_wt).mean()\n", "        metric_et = torch.stack(self.val_step_dice_et).mean()\n", "        self.log('val/Loss', loss)\n", "        self.log('val/MeanDiceScore', mean_val_dice)\n", "        self.log('val/DiceTC', metric_tc)\n", "        self.log('val/DiceWT', metric_wt)\n", "        self.log('val/DiceET', metric_et)\n", "        os.makedirs(self.logger.log_dir, exist_ok=True)\n", "        if self.current_epoch == 0:\n", "            with open('{}/metric_log.csv'.format(self.logger.log_dir), 'w') as f:\n", "                writer = csv.writer(f)\n", "                writer.writerow(['Epoch', 'Mean Dice Score', 'Dice TC', 'Dice WT', 'Dice ET'])\n", "        with open('{}/metric_log.csv'.format(self.logger.log_dir), 'a') as f:\n", "            writer = csv.writer(f)\n", "            writer.writerow([self.current_epoch, mean_val_dice.item(), metric_tc.item(), metric_wt.item(), metric_et.item()])\n", "\n", "        if mean_val_dice > self.best_val_dice:\n", "            self.best_val_dice = mean_val_dice\n", "            self.best_val_epoch = self.current_epoch\n", "            print(\n", "                f\"\\n Current epoch: {self.current_epoch} Current mean dice: {mean_val_dice:.4f}\"\n", "                f\" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}\"\n", "                f\"\\n Best mean dice: {self.best_val_dice}\"\n", "                f\" at epoch: {self.best_val_epoch}\"\n", "            )\n", "\n", "        self.val_step_loss.clear()\n", "        self.val_step_dice.clear()\n", "        self.val_step_dice_tc.clear()\n",
 "        self.val_step_dice_wt.clear()\n", "        self.val_step_dice_et.clear()\n", "        return {'val_MeanDiceScore': mean_val_dice}\n", "\n", "    def test_step(self, batch, batch_index):\n", "        inputs, labels = (batch['image'], batch['label'])\n", "\n", "        roi_size = (128, 128, 128)\n", "        sw_batch_size = 1\n", "        test_outputs = sliding_window_inference(\n", "            inputs, roi_size, sw_batch_size, self.forward, overlap=0.5)\n", "        loss = self.dice_loss(test_outputs, labels)\n", "        test_outputs = self.post_trans_images(test_outputs)\n", "        metric_tc = DiceScore(y_pred=test_outputs[:, 0:1], y=labels[:, 0:1], include_background=True)\n", "        metric_wt = DiceScore(y_pred=test_outputs[:, 1:2], y=labels[:, 1:2], include_background=True)\n", "        metric_et = DiceScore(y_pred=test_outputs[:, 2:3], y=labels[:, 2:3], include_background=True)\n", "        mean_test_dice = (metric_tc + metric_wt + metric_et) / 3\n", "\n", "        self.test_step_loss.append(loss)\n", "        self.test_step_dice.append(mean_test_dice)\n", "        self.test_step_dice_tc.append(metric_tc)\n", "        self.test_step_dice_wt.append(metric_wt)\n", "        self.test_step_dice_et.append(metric_et)\n", "\n", "        return {'test_loss': loss, 'test_mean_dice': mean_test_dice, 'test_dice_tc': metric_tc,\n", "                'test_dice_wt': metric_wt, 'test_dice_et': metric_et}\n", "\n", "    def on_test_epoch_end(self):\n", "        loss = torch.stack(self.test_step_loss).mean()\n", "        mean_test_dice = torch.stack(self.test_step_dice).mean()\n", "        metric_tc = torch.stack(self.test_step_dice_tc).mean()\n", "        metric_wt = torch.stack(self.test_step_dice_wt).mean()\n", "        metric_et = torch.stack(self.test_step_dice_et).mean()\n", "        self.log('test/Loss', loss)\n", "        self.log('test/MeanDiceScore', mean_test_dice)\n", "        self.log('test/DiceTC', metric_tc)\n", "        self.log('test/DiceWT', metric_wt)\n", "        self.log('test/DiceET', metric_et)\n", "\n", "        with open('{}/test_log.csv'.format(self.logger.log_dir), 'w') as f:\n", "            writer = csv.writer(f)\n", "            writer.writerow([\"Mean Test Dice\", \"Dice TC\", \"Dice WT\", \"Dice ET\"])\n", "            writer.writerow([mean_test_dice.item(), metric_tc.item(), metric_wt.item(), metric_et.item()])\n", "\n", "        self.test_step_loss.clear()\n", "        self.test_step_dice.clear()\n", "        self.test_step_dice_tc.clear()\n", "        self.test_step_dice_wt.clear()\n", "        self.test_step_dice_et.clear()\n", "        return {'test_MeanDiceScore': mean_test_dice}\n", "\n", "    def configure_optimizers(self):\n", "        optimizer = torch.optim.Adam(\n", "            self.model.parameters(), self.lr, weight_decay=1e-5, amsgrad=True\n", "        )\n", "        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200)\n", "        return [optimizer], [scheduler]\n", "\n", "    def train_dataloader(self):\n", "        return train_loader\n", "\n", "    def val_dataloader(self):\n", "        return val_loader\n", "\n", "    def test_dataloader(self):\n", "        return test_loader" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n", "import os \n", "from pytorch_lightning.loggers import TensorBoardLogger" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "sh: 1: cls: not found\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[H\u001b[2JTraining ...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.9/site-packages/lightning_fabric/connector.py:563: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!\n", "Using 16bit Automatic Mixed Precision (AMP)\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", "------------------------------------------\n", "0 | model | SegTransVAE | 44.7 M\n", "1 | loss_vae | Loss_VAE | 0 \n", "2 | dice_loss | DiceLoss | 0 \n", "------------------------------------------\n", "44.7 M Trainable params\n", "0 Non-trainable params\n", "44.7 M Total params\n", "178.908 Total estimated model params size (MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:05<00:00, 0.37it/s]\n", " Current epoch: 0 Current mean dice: 0.0097 tc: 0.0029 wt: 0.0234 et: 0.0028\n", " Best mean dice: 0.009687595069408417 at epoch: 0\n", "Epoch 0: 100%|██████████| 500/500 [05:38<00:00, 1.48it/s, v_num=6] \n", " Current epoch: 0 Current mean dice: 0.1927 tc: 0.1647 wt: 0.2843 et: 0.1290\n", " Best mean dice: 0.1926589012145996 at epoch: 0\n", "Epoch 1: 100%|██████████| 500/500 [07:35<00:00, 1.10it/s, v_num=6]\n", " Current epoch: 1 Current mean dice: 0.3212 tc: 0.2691 wt: 0.4253 et: 0.2692\n", " Best mean dice: 0.32120221853256226 at epoch: 1\n", "Epoch 2: 100%|██████████| 500/500 [08:11<00:00, 1.02it/s, v_num=6]\n", " Current epoch: 2 Current mean dice: 0.3912 tc: 0.3510 wt: 0.5087 et: 0.3137\n", " Best mean dice: 0.39115065336227417 at epoch: 2\n", "Epoch 3: 100%|██████████| 500/500 [08:58<00:00, 0.93it/s, v_num=6]\n", " Current epoch: 3 Current mean dice: 0.4268 tc: 0.3828 wt: 0.5424 et: 0.3553\n", " Best mean dice: 0.42682838439941406 at epoch: 3\n", "Epoch 4: 41%|████▏ | 207/500 [02:51<04:03, 1.21it/s, v_num=6]" ] }, { "ename": "", "evalue": "", "output_type": "error", "traceback": [ "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n", "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n", "\u001b[1;31mClick here for more info. \n", "\u001b[1;31mView Jupyter log for further details." 
] } ], "source": [ "os.system('cls||clear')\n", "print(\"Training ...\")\n", "model = BRATS(use_VAE = True)\n", "checkpoint_callback = ModelCheckpoint(\n", " monitor='val/MeanDiceScore',\n", " dirpath='./app/checkpoints/{}'.format(1),\n", " filename='Epoch{epoch:3d}-MeanDiceScore{val/MeanDiceScore:.4f}',\n", " save_top_k=3,\n", " mode='max',\n", " save_last= True,\n", " auto_insert_metric_name=False\n", ")\n", "early_stop_callback = EarlyStopping(\n", " monitor='val/MeanDiceScore',\n", " min_delta=0.0001,\n", " patience=15,\n", " verbose=False,\n", " mode='max'\n", ")\n", "tensorboardlogger = TensorBoardLogger(\n", " 'logs', \n", " name = \"1\", \n", " default_hp_metric = None \n", ")\n", "trainer = pl.Trainer(#fast_dev_run = 10, \n", "# accelerator='ddp',\n", " #overfit_batches=5,\n", " devices = [0], \n", " precision=16,\n", " max_epochs = 200, \n", " enable_progress_bar=True, \n", " callbacks=[checkpoint_callback, early_stop_callback], \n", "# auto_lr_find=True,\n", " num_sanity_val_steps=2,\n", " logger = tensorboardlogger,\n", "# limit_train_batches=0.01, \n", "# limit_val_batches=0.01\n", " )\n", "# trainer.tune(model)\n", "trainer.fit(model)\n", "\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "from trainer import BRATS\n", "import os \n", "import torch\n", "os.system('cls||clear')\n", "print(\"Testing ...\")\n", "\n", "CKPT = ''\n", "model = BRATS(use_VAE=True).load_from_checkpoint(CKPT).eval()\n", "val_dataloader = get_val_dataloader()\n", "test_dataloader = get_test_dataloader()\n", "trainer = pl.Trainer(gpus = [0], precision=32, progress_bar_refresh_rate=10)\n", "\n", "trainer.test(model, dataloaders = val_dataloader)\n", "trainer.test(model, dataloaders = test_dataloader)\n", "\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 2 }