import platform import nibabel import torch as T from monai.data import DataLoader, Dataset from monai.transforms import ( AddChanneld, Compose, LoadImaged, NormalizeIntensityd, Resized, ThresholdIntensityd, ToTensord ) from Engine.utils import ( load_model, read_yaml, store_output ) def infer(model, loader, device, infer_config): model.eval() with T.no_grad(): for image in loader: inp, boxes = image["image"].to(device), image["boxes"].to(device) directory_split = '\\' if platform.system() == 'Windows' else '/' output = model(inp) original_filename = image['image_meta_dict']['filename_or_obj'][0] segmentation_filename = f"{infer_config['save_directory']}{directory_split}seg_{original_filename.split(directory_split)[-1]}" original_image = nibabel.load(original_filename) upsample = T.nn.Upsample(original_image.shape, mode='trilinear', align_corners = False) output = upsample(output) print(image['image_meta_dict']['affine']) store_output(output, original_image, segmentation_filename, image['image_meta_dict']['affine'].squeeze(0).numpy()) def initiate(config_file): config = read_yaml(config_file) device = T.device(config["device"]) data = read_yaml(config["data"]["dataset"]) image_shape = (config["data"]["scale_dim"]["d_0"], config["data"]["scale_dim"]["d_1"], config["data"]["scale_dim"]["d_2"]) for i, d in enumerate(data['test']): data['test'][i]['image'] = data['image_prefix'] + d['image'] data['test'][i]['boxes'] = data['boxes_prefix'] + d['boxes'] transform = Compose( [ LoadImaged(keys=["image", "boxes"]), AddChanneld(keys=["image", "boxes"]), ToTensord(keys=["image", "boxes"]), ] ) dataset = Dataset(data['data'], transform) loader = T.utils.data.DataLoader(dataset, 1) model = load_model(config['model'], infer = True) model.load_state_dict(T.load(config["model"]["weights"])) model.to(device) logger.LogInfo("Starting inference!", [str(data)]) infer(model, loader, device, config['inference']) logger.LogMilestone("Inference finished!", [])