import numpy as np
from monai.transforms import MapTransform


class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert labels to multi channels based on brats classes:
    label 1 is the necrotic and non-enhancing tumor core
    label 2 is the peritumoral edema
    label 4 is the GD-enhancing tumor
    The possible classes are TC (Tumor core), WT (Whole tumor)
    and ET (Enhancing tumor).

    For every key in ``self.keys`` the label array is replaced by a
    ``float32`` array with a new leading axis of size 3, ordered
    ``[TC, WT, ET]``:

    - TC = label 1 or label 4
    - WT = label 1, label 2 or label 4
    - ET = label 4
    """

    def __call__(self, data):
        """Return a shallow copy of ``data`` with each keyed label converted.

        Args:
            data: mapping that holds label arrays under ``self.keys``.

        Returns:
            A new dict. Entries for ``self.keys`` become stacked ``float32``
            masks of shape ``(3, *label.shape)`` in ``[TC, WT, ET]`` order;
            all other entries are passed through unchanged.
        """
        d = dict(data)
        for key in self.keys:
            label = d[key]
            # Evaluate each per-label comparison only once. WT is a superset
            # of TC, so it is derived from the TC mask instead of re-running
            # the label-1 / label-4 comparisons over the whole array.
            enhancing = label == 4
            tumor_core = np.logical_or(label == 1, enhancing)
            whole_tumor = np.logical_or(tumor_core, label == 2)
            d[key] = np.stack(
                [tumor_core, whole_tumor, enhancing], axis=0
            ).astype(np.float32)
        return d