import os
from glob import glob

import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import ArrayDataset, decollate_batch, DataLoader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    LoadImage,
    RandRotate90,
    ScaleIntensity,
    # AsChannelFirst
)
from monai.visualize import plot_2d_or_3d_image
import cv2
import tifffile
from sklearn.model_selection import train_test_split


# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------

# Lower-case extensions that convert_to_png rewrites as .png files.
_CONVERTIBLE_EXTENSIONS = (".jpg", ".jpeg", ".tif")


def convert_to_png(img_dir):
    """Write a ``.png`` copy of every JPEG/TIFF image found in ``img_dir``.

    The copy is saved next to the source file under the same base name; the
    source file itself is left in place. Files that are already PNG, and files
    with any other extension, are skipped. Extension matching is
    case-insensitive.

    Args:
        img_dir (str): Directory containing the images to convert.
    """
    for img_file in sorted(os.listdir(img_dir)):
        stem, ext = os.path.splitext(img_file)
        ext = ext.lower()
        if ext not in _CONVERTIBLE_EXTENSIONS:
            # Unrelated file, or already a PNG: re-encoding a PNG onto itself
            # would only cost time.
            continue

        img_path = os.path.join(img_dir, img_file)
        png_path = os.path.join(img_dir, stem + ".png")

        if ext == ".tif":
            # TIFF files are decoded with tifffile and wrapped as a PIL image.
            with tifffile.TiffFile(img_path) as tif:
                Image.fromarray(tif.asarray()).save(png_path)
        else:
            # The context manager closes the source file handle after saving.
            with Image.open(img_path) as img:
                img.save(png_path)


def resize_images_and_masks(image_paths, mask_paths, output_dir, target_width, target_height):
    """
    Resize images and corresponding segmentation masks to the specified dimensions.

    Images and masks are paired by position. Resized images are written to
    ``<output_dir>/resized_images`` and resized masks to
    ``<output_dir>/resized_masks``, each keeping its original file name.
    Masks are read as grayscale and resized with nearest-neighbour
    interpolation so that no new label values are introduced.

    Args:
    - image_paths (list): List of paths to the input images.
    - mask_paths (list): List of paths to the segmentation masks.
    - output_dir (str): Directory to save the resized images and masks.
    - target_width (int): Target width for resizing.
    - target_height (int): Target height for resizing.

    Returns:
    - resized_image_paths (list): List of paths to the resized images.
    - resized_mask_paths (list): List of paths to the resized segmentation masks.

    Raises:
    - ValueError: If the number of images and masks differ.
    - OSError: If an image or mask cannot be read, or an output cannot be written.
    """
    image_paths = list(image_paths)
    mask_paths = list(mask_paths)
    if len(image_paths) != len(mask_paths):
        # zip() would silently drop the surplus and could pair an image with
        # the wrong mask, so fail loudly instead.
        raise ValueError(
            f"Got {len(image_paths)} images but {len(mask_paths)} masks; "
            "every image needs exactly one mask."
        )

    resized_image_dir = os.path.join(output_dir, 'resized_images')
    resized_mask_dir = os.path.join(output_dir, 'resized_masks')
    # makedirs also creates output_dir itself when it is missing.
    os.makedirs(resized_image_dir, exist_ok=True)
    os.makedirs(resized_mask_dir, exist_ok=True)

    # cv2.resize takes the size as (width, height).
    target_size = (target_width, target_height)

    resized_image_paths = []
    resized_mask_paths = []

    for img_path, mask_path in zip(image_paths, mask_paths):
        # cv2.imread signals failure by returning None rather than raising.
        img = cv2.imread(img_path)
        if img is None:
            raise OSError(f"Could not read image: {img_path}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Could not read mask: {mask_path}")

        resized_img = cv2.resize(img, target_size)
        resized_mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)

        output_img_path = os.path.join(resized_image_dir, os.path.basename(img_path))
        if not cv2.imwrite(output_img_path, resized_img):
            raise OSError(f"Could not write resized image: {output_img_path}")
        resized_image_paths.append(output_img_path)

        output_mask_path = os.path.join(resized_mask_dir, os.path.basename(mask_path))
        if not cv2.imwrite(output_mask_path, resized_mask):
            raise OSError(f"Could not write resized mask: {output_mask_path}")
        resized_mask_paths.append(output_mask_path)

    return resized_image_paths, resized_mask_paths


convert_to_png("./tamp500/imgs")

# Sorted globbing pairs images with masks by position, which assumes both
# directories use matching file names.
images = sorted(glob(os.path.join("./tamp500/imgs", "*.png")))
masks = sorted(glob(os.path.join("./tamp500/masks", "*.png")))

output_directory = 'resized_'
target_width = 448
target_height = 448

resized_image_paths, resized_mask_paths = resize_images_and_masks(
    images, masks, output_directory, target_width, target_height
)

images = sorted(glob(os.path.join("./resized_/resized_images", "*.png")))
segs = sorted(glob(os.path.join("./resized_/resized_masks", "*.png")))

# Fixed random_state keeps the train/validation split reproducible across runs.
train_images, test_images, train_segs, test_segs = train_test_split(
    images, segs, test_size=0.2, random_state=42
)

# define transforms for image and segmentation
# NOTE(review): image and mask use separate RandRotate90 instances; this
# relies on ArrayDataset keeping their random state in sync - confirm.
train_imtrans = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
        ScaleIntensity(),
        # RandSpatialCrop((224, 224), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
    ]
)
train_segtrans = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
        ScaleIntensity(),
        # RandSpatialCrop((224, 224), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
    ]
)
val_imtrans = Compose([LoadImage(image_only=True, ensure_channel_first=True), ScaleIntensity()])
val_segtrans = Compose([LoadImage(image_only=True, ensure_channel_first=True), ScaleIntensity()])

# create a training data loader
train_ds = ArrayDataset(train_images, train_imtrans, train_segs, train_segtrans)
train_loader = DataLoader(
    train_ds, batch_size=4, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available()
)
# create a validation data loader
val_ds = ArrayDataset(test_images, val_imtrans, test_segs, val_segtrans)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=2, pin_memory=torch.cuda.is_available())

dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
# Turn raw network outputs into a binary mask: sigmoid, then threshold at 0.5.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# create the network, DiceLoss and Adam optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Alternative model kept for reference:
# model = monai.networks.nets.UNet(
#     spatial_dims=2,
#     in_channels=3,
#     out_channels=1,
#     channels=(16, 32, 64, 128, 256),
#     strides=(2, 2, 2, 2),
#     num_res_units=2,
# ).to(device)

# img_size is written as a literal; it has to match the resize target above.
model = monai.networks.nets.UNETR(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    img_size=(448, 448),
).to(device)

loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

# start a typical PyTorch training
val_interval = 2  # run validation every `val_interval` epochs
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
writer = SummaryWriter()

max_epochs = 500
# len(train_loader) is the true number of batches per epoch. Floor-dividing
# the dataset size by the batch size under-counts whenever a final partial
# batch exists, which made the printed progress wrong and let TensorBoard
# global steps from consecutive epochs collide.
epoch_len = len(train_loader)
# Window size and batch size handed to sliding_window_inference.
roi_size = (448, 448)
sw_batch_size = 4

try:
    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")

        # ---- training ----
        model.train()
        epoch_loss = 0
        for step, batch_data in enumerate(train_loader, start=1):
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        # Average over the number of batches processed this epoch.
        epoch_loss /= epoch_len
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        # ---- validation (every `val_interval` epochs) ----
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                # Kept after the loop so the last batch can be plotted below.
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                    # compute metric for current iteration
                    dice_metric(y_pred=val_outputs, y=val_labels)
                # aggregate the final mean dice result
                metric = dice_metric.aggregate().item()
                # reset the status for next validation round
                dice_metric.reset()
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
finally:
    # Close the writer even when training is interrupted or raises, so that
    # already-logged events are flushed to disk.
    writer.close()