andreped commited on
Commit
5c51ac6
1 Parent(s): 2a35518

Added support to state whether to filter lungs or not - default no

Browse files
lungtumormask/__main__.py CHANGED
@@ -13,8 +13,10 @@ def main():
13
  parser = argparse.ArgumentParser()
14
  parser.add_argument('input', metavar='input', type=path, help='Path to the input image, should be .nifti')
15
  parser.add_argument('output', metavar='output', type=str, help='Filepath for output tumormask')
 
 
16
 
17
  argsin = sys.argv[1:]
18
  args = parser.parse_args(argsin)
19
 
20
- mask.mask(args.input, args.output)
 
13
  parser = argparse.ArgumentParser()
14
  parser.add_argument('input', metavar='input', type=path, help='Path to the input image, should be .nifti')
15
  parser.add_argument('output', metavar='output', type=str, help='Filepath for output tumormask')
16
+ parser.add_argment('lung-filter', metavar='post-process', action='store_true',
17
+ help="whether to apply lungmask postprocessing.")
18
 
19
  argsin = sys.argv[1:]
20
  args = parser.parse_args(argsin)
21
 
22
+ mask.mask(args.input, args.output, args.lung_filter)
lungtumormask/dataprocessing.py CHANGED
@@ -217,17 +217,15 @@ def voxel_space(image, target):
217
  image = Resize((target[0][1]-target[0][0], target[1][1]-target[1][0], target[2][1]-target[2][0]), mode='trilinear')(np.expand_dims(image, 0))[0]
218
  image = ThresholdIntensity(above = False, threshold = 0.5, cval = 1)(image)
219
  image = ThresholdIntensity(above = True, threshold = 0.5, cval = 0)(image)
220
-
221
  return image
222
 
223
  def stitch(org_shape, cropped, roi):
224
  holder = np.zeros(org_shape)
225
-
226
  holder[roi[0][0]:roi[0][1], roi[1][0]:roi[1][1], roi[2][0]:roi[2][1]] = cropped
227
 
228
  return holder
229
 
230
- def post_process(left_mask, right_mask, preprocess_dump):
231
  left_mask = (left_mask >= 0.5).astype(int)
232
  right_mask = (right_mask >= 0.5).astype(int)
233
 
@@ -243,6 +241,7 @@ def post_process(left_mask, right_mask, preprocess_dump):
243
  stitched = np.logical_or(left, right).astype(int)
244
 
245
  # filter tumor predictions outside the predicted lung area
246
- stitched[preprocess_dump['lungmask'] == 0] = 0
 
247
 
248
  return stitched
 
217
  image = Resize((target[0][1]-target[0][0], target[1][1]-target[1][0], target[2][1]-target[2][0]), mode='trilinear')(np.expand_dims(image, 0))[0]
218
  image = ThresholdIntensity(above = False, threshold = 0.5, cval = 1)(image)
219
  image = ThresholdIntensity(above = True, threshold = 0.5, cval = 0)(image)
 
220
  return image
221
 
222
  def stitch(org_shape, cropped, roi):
223
  holder = np.zeros(org_shape)
 
224
  holder[roi[0][0]:roi[0][1], roi[1][0]:roi[1][1], roi[2][0]:roi[2][1]] = cropped
225
 
226
  return holder
227
 
228
+ def post_process(left_mask, right_mask, preprocess_dump, lung_filter):
229
  left_mask = (left_mask >= 0.5).astype(int)
230
  right_mask = (right_mask >= 0.5).astype(int)
231
 
 
241
  stitched = np.logical_or(left, right).astype(int)
242
 
243
  # filter tumor predictions outside the predicted lung area
244
+ if lung_filter:
245
+ stitched[preprocess_dump['lungmask'] == 0] = 0
246
 
247
  return stitched
lungtumormask/mask.py CHANGED
@@ -15,7 +15,7 @@ def load_model():
15
  model.eval()
16
  return model
17
 
18
- def mask(image_path, save_path):
19
  print("Loading model...")
20
  model = load_model()
21
 
@@ -27,7 +27,7 @@ def mask(image_path, save_path):
27
  right = model(preprocess_dump['right_lung']).squeeze(0).squeeze(0).detach().numpy()
28
 
29
  print("Post-processing image...")
30
- inferred = post_process(left, right, preprocess_dump).astype("uint8")
31
 
32
  print(f"Storing segmentation at {save_path}")
33
  nimage = nibabel.Nifti1Image(inferred, preprocess_dump['org_affine'])
 
15
  model.eval()
16
  return model
17
 
18
+ def mask(image_path, save_path, lung_filter):
19
  print("Loading model...")
20
  model = load_model()
21
 
 
27
  right = model(preprocess_dump['right_lung']).squeeze(0).squeeze(0).detach().numpy()
28
 
29
  print("Post-processing image...")
30
+ inferred = post_process(left, right, preprocess_dump, lung_filter).astype("uint8")
31
 
32
  print(f"Storing segmentation at {save_path}")
33
  nimage = nibabel.Nifti1Image(inferred, preprocess_dump['org_affine'])