from pathlib import Path
from typing import Any, Callable, Optional

import numpy as np
from PIL import Image
from torchvision.datasets.vision import VisionDataset


class SegmentationDataset(VisionDataset):
    """A PyTorch dataset for image segmentation task.

    The dataset is compatible with torchvision transforms.
    The transforms passed would be applied to both the Images and Masks.
    """

    def __init__(self,
                 root: str,
                 image_folder: str,
                 mask_folder: str,
                 transforms: Optional[Callable] = None,
                 seed: Optional[int] = None,
                 fraction: Optional[float] = None,
                 subset: Optional[str] = None,
                 image_color_mode: str = "rgb",
                 mask_color_mode: str = "grayscale") -> None:
        """
        Args:
            root (str): Root directory path.
            image_folder (str): Name of the folder that contains the images
                in the root directory.
            mask_folder (str): Name of the folder that contains the masks
                in the root directory.
            transforms (Optional[Callable], optional): A function/transform that takes
                in a sample and returns a transformed version.
                E.g, ``transforms.ToTensor`` for images. Defaults to None.
            seed (Optional[int], optional): Specify a seed for the train and test
                split for reproducible results. Defaults to None.
            fraction (Optional[float], optional): A float value from 0 to 1 which
                specifies the validation split fraction. Defaults to None.
            subset (Optional[str], optional): 'Train' or 'Test' to select the
                appropriate set. Defaults to None.
            image_color_mode (str, optional): 'rgb' or 'grayscale'.
                Defaults to 'rgb'.
            mask_color_mode (str, optional): 'rgb' or 'grayscale'.
                Defaults to 'grayscale'.

        Raises:
            OSError: If image folder doesn't exist in root.
            OSError: If mask folder doesn't exist in root.
            ValueError: If subset is not either 'Train' or 'Test'
            ValueError: If image_color_mode or mask_color_mode is not
                either 'rgb' or 'grayscale'
        """
        super().__init__(root, transforms)
        image_folder_path = Path(self.root) / image_folder
        mask_folder_path = Path(self.root) / mask_folder
        if not image_folder_path.exists():
            raise OSError(f"{image_folder_path} does not exist.")
        if not mask_folder_path.exists():
            raise OSError(f"{mask_folder_path} does not exist.")

        if image_color_mode not in ["rgb", "grayscale"]:
            raise ValueError(
                f"{image_color_mode} is an invalid choice. Please enter from rgb grayscale."
            )
        if mask_color_mode not in ["rgb", "grayscale"]:
            raise ValueError(
                f"{mask_color_mode} is an invalid choice. Please enter from rgb grayscale."
            )

        self.image_color_mode = image_color_mode
        self.mask_color_mode = mask_color_mode

        if not fraction:
            # No split requested: use every image/mask pair, sorted so that
            # images and masks line up by filename order.
            self.image_names = sorted(image_folder_path.glob("*"))
            self.mask_names = sorted(mask_folder_path.glob("*"))
        else:
            if subset not in ["Train", "Test"]:
                raise ValueError(
                    f"{subset} is not a valid input. Acceptable values are Train and Test."
                )
            self.fraction = fraction
            self.image_list = np.array(sorted(image_folder_path.glob("*")))
            self.mask_list = np.array(sorted(mask_folder_path.glob("*")))
            # BUG FIX: the original `if seed:` treated seed == 0 the same as
            # seed=None, silently skipping the seeded shuffle. 0 is a valid
            # NumPy seed, so test for None explicitly.
            if seed is not None:
                np.random.seed(seed)
                # Shuffle images and masks with the SAME permutation so each
                # image stays paired with its mask.
                indices = np.arange(len(self.image_list))
                np.random.shuffle(indices)
                self.image_list = self.image_list[indices]
                self.mask_list = self.mask_list[indices]
            # Compute the train/validation boundary once instead of repeating
            # the ceil expression for every slice (assumes the image and mask
            # folders contain the same number of files, which the pairing
            # above already requires).
            split_index = int(
                np.ceil(len(self.image_list) * (1 - self.fraction)))
            if subset == "Train":
                self.image_names = self.image_list[:split_index]
                self.mask_names = self.mask_list[:split_index]
            else:
                self.image_names = self.image_list[split_index:]
                self.mask_names = self.mask_list[split_index:]

    def __len__(self) -> int:
        """Return the number of image/mask pairs in the selected subset."""
        return len(self.image_names)

    def __getitem__(self, index: int) -> Any:
        """Load and return the sample at ``index``.

        Returns:
            dict: ``{"image": image, "mask": mask}`` where both entries are
            PIL Images converted to the configured color modes, or the
            result of ``self.transforms`` applied to each when transforms
            were provided.
        """
        image_path = self.image_names[index]
        mask_path = self.mask_names[index]
        with open(image_path, "rb") as image_file, open(mask_path,
                                                        "rb") as mask_file:
            image = Image.open(image_file)
            if self.image_color_mode == "rgb":
                image = image.convert("RGB")
            elif self.image_color_mode == "grayscale":
                image = image.convert("L")
            mask = Image.open(mask_file)
            if self.mask_color_mode == "rgb":
                mask = mask.convert("RGB")
            elif self.mask_color_mode == "grayscale":
                mask = mask.convert("L")
            sample = {"image": image, "mask": mask}
            if self.transforms:
                # NOTE(review): the same transform callable is invoked
                # independently on image and mask, so random transforms
                # (e.g. RandomCrop) will NOT apply identical parameters to
                # both — confirm callers only pass deterministic transforms.
                sample["image"] = self.transforms(sample["image"])
                sample["mask"] = self.transforms(sample["mask"])
            return sample