"""Build MONAI data loaders for 3D segmentation from a JSON datalist."""

import json
import os
from typing import Optional

from sklearn.model_selection import train_test_split
from monai.data import DataLoader, Dataset
from monai import transforms


def datafold_read(datalist, basedir, fold=0, key="training"):
    """Read a JSON datalist and split its entries by cross-validation fold.

    Each entry (a dict) is rewritten in place so its paths point inside
    ``basedir``:

    * every item of a list value is joined onto ``basedir``;
    * a non-empty string value is joined onto ``basedir``;
    * empty strings and other value types are left untouched.

    Args:
        datalist: Path to the JSON datalist file.
        basedir: Directory that the relative paths in the datalist are
            joined onto.
        fold: Entries whose ``"fold"`` field equals this value go to the
            validation list. Every other entry, including entries without a
            ``"fold"`` field, goes to the training list.
        key: Top-level key of the JSON document that holds the entry list.

    Returns:
        A ``(train, val)`` tuple of entry lists.

    Raises:
        KeyError: If ``key`` is not present in the JSON document.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
    with open(datalist, encoding="utf-8") as f:
        json_data = json.load(f)[key]

    for entry in json_data:
        # Only existing keys are reassigned, so mutating while iterating is safe.
        for k, value in entry.items():
            if isinstance(value, list):
                entry[k] = [os.path.join(basedir, item) for item in value]
            elif isinstance(value, str):
                entry[k] = os.path.join(basedir, value) if len(value) > 0 else value
            # NOTE(review): a string "fold" field would be joined onto basedir
            # here and then never match an int ``fold`` — confirm the datalist
            # stores folds as integers.

    train, val = [], []
    for entry in json_data:
        if "fold" in entry and entry["fold"] == fold:
            val.append(entry)
        else:
            train.append(entry)
    return train, val


def split_train_test(datalist, basedir, fold, test_size=0.2, volume: Optional[float] = None):
    """Split the entries outside ``fold`` into train/validation/test lists.

    Entries that belong to ``fold`` (see :func:`datafold_read`) are discarded.
    The remaining entries go through up to three seeded splits, each using
    ``random_state=42``:

    1. If ``volume`` is given, ``train_test_split(..., test_size=volume)``
       drops that share of the entries. Only the remaining share is kept.
    2. ``test_size`` of what is left becomes a hold-out set. The rest is the
       training set.
    3. The hold-out set is split again with the same ``test_size``. That
       share becomes the test set and the remainder becomes validation.

    NOTE(review): because of step 3, the test set is only ``test_size ** 2``
    of the data. With the default 0.2 that is about 80% train, 16% validation
    and 4% test. Confirm this is intended.

    Args:
        datalist: Path to the JSON datalist file.
        basedir: Directory the datalist's relative paths are joined onto.
        fold: Fold whose entries are excluded.
        test_size: ``test_size`` passed to both hold-out splits.
        volume: Optional ``test_size`` for the initial subsampling split.
            ``None`` keeps all entries.

    Returns:
        A ``(train_files, validation_files, test_files)`` tuple.
    """
    train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=fold)

    if volume is not None:
        # Keep only the "train" side; the ``volume`` share is discarded.
        train_files, _ = train_test_split(train_files, test_size=volume, random_state=42)

    train_files, validation_files = train_test_split(
        train_files, test_size=test_size, random_state=42
    )
    validation_files, test_files = train_test_split(
        validation_files, test_size=test_size, random_state=42
    )
    return train_files, validation_files, test_files


def _build_train_transform(roi):
    """Return the training pipeline: load, foreground crop, random ROI crop, flips, intensity jitter."""
    return transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label"]),
            transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            # Crop around the image foreground. k_divisible keeps the crop a
            # multiple of the ROI so the random ROI crop below always fits.
            transforms.CropForegroundd(
                keys=["image", "label"],
                source_key="image",
                k_divisible=[roi[0], roi[1], roi[2]],
            ),
            # Fixed-size random patch, so training samples can be batched.
            transforms.RandSpatialCropd(
                keys=["image", "label"],
                roi_size=[roi[0], roi[1], roi[2]],
                random_size=False,
            ),
            transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
            transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
            transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
            transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        ]
    )


def _build_eval_transform():
    """Return the deterministic validation/test pipeline: load, relabel, normalize (no crop)."""
    return transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label"]),
            transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ]
    )


def _make_loader(files, transform, batch_size, shuffle, num_workers):
    """Wrap ``files`` in a MONAI ``Dataset`` and a pinned-memory ``DataLoader``."""
    return DataLoader(
        Dataset(data=files, transform=transform),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )


def get_loader(
    batch_size,
    data_dir,
    json_list,
    fold,
    roi,
    volume: Optional[float] = None,
    test_size=0.2,
    num_workers=2,
):
    """Create train, validation and test data loaders.

    Args:
        batch_size: Batch size of the training loader. The validation and
            test loaders always use batch size 1.
        data_dir: Directory that the datalist's relative paths are joined onto.
        json_list: Path to the JSON datalist.
        fold: Fold whose entries are excluded (see :func:`split_train_test`).
        roi: Spatial size of the random training crop. Only ``roi[0:3]`` is
            used.
        volume: Optional subsampling ``test_size``
            (see :func:`split_train_test`).
        test_size: Hold-out ``test_size`` (see :func:`split_train_test`).
        num_workers: Worker processes per loader. The default of 2 matches
            the previously hard-coded value.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple.
    """
    train_files, validation_files, test_files = split_train_test(
        datalist=json_list,
        basedir=data_dir,
        test_size=test_size,
        fold=fold,
        volume=volume,
    )

    eval_transform = _build_eval_transform()

    train_loader = _make_loader(
        train_files, _build_train_transform(roi), batch_size, True, num_workers
    )
    # Evaluation volumes are not cropped to a fixed size, which is presumably
    # why they are loaded one at a time.
    val_loader = _make_loader(validation_files, eval_transform, 1, False, num_workers)
    test_loader = _make_loader(test_files, eval_transform, 1, False, num_workers)
    return train_loader, val_loader, test_loader