andreped commited on
Commit
9f284dc
1 Parent(s): 1155abd

Support for setting threshold through argparse

Browse files
.github/workflows/build.yml CHANGED
@@ -67,7 +67,4 @@ jobs:
67
  run: lungtumormask --help
68
 
69
  - name: Test inference
70
- run: lungtumormask samples/lung_001.nii.gz mask_001.nii.gz
71
-
72
- #- name: Test lungmask postprocessing
73
- # run: lungtumormask samples/lung_001.nii.gz mask_001.nii.gz --lung-filter
 
67
  run: lungtumormask --help
68
 
69
  - name: Test inference
70
+ run: lungtumormask samples/lung_001.nii.gz mask_001.nii.gz --threshold 0.3 --lung-filter
 
 
 
lungtumormask/__main__.py CHANGED
@@ -13,10 +13,11 @@ def main():
13
  parser = argparse.ArgumentParser()
14
  parser.add_argument('input', metavar='input', type=path, help='Path to the input image, should be .nifti')
15
  parser.add_argument('output', metavar='output', type=str, help='Filepath for output tumormask')
16
- parser.add_argument('--lung-filter', action='store_true',
17
- help="whether to apply lungmask postprocessing.")
 
18
 
19
  argsin = sys.argv[1:]
20
  args = parser.parse_args(argsin)
21
 
22
- mask.mask(args.input, args.output, args.lung_filter)
 
13
  parser = argparse.ArgumentParser()
14
  parser.add_argument('input', metavar='input', type=path, help='Path to the input image, should be .nifti')
15
  parser.add_argument('output', metavar='output', type=str, help='Filepath for output tumormask')
16
+ parser.add_argument('--lung-filter', action='store_true', help='whether to apply lungmask postprocessing.')
17
+ parser.add_argument('--threshold', metavar='threshold', type=int, default=0.4,
18
+ help='which threshold to use for assigning voxel-wise classes.')
19
 
20
  argsin = sys.argv[1:]
21
  args = parser.parse_args(argsin)
22
 
23
+ mask.mask(args.input, args.output, args.lung_filter, args.threshold)
lungtumormask/dataprocessing.py CHANGED
@@ -211,12 +211,14 @@ def find_pad_edge(original):
211
 
212
  def remove_pad(mask, original):
213
  a_min, a_max, b_min, b_max, c_min, c_max = find_pad_edge(original)
 
214
  return mask[a_min:a_max, b_min:b_max, c_min: c_max]
215
 
216
  def voxel_space(image, target):
217
  image = Resize((target[0][1]-target[0][0], target[1][1]-target[1][0], target[2][1]-target[2][0]), mode='trilinear')(np.expand_dims(image, 0))[0]
218
  image = ThresholdIntensity(above = False, threshold = 0.5, cval = 1)(image)
219
  image = ThresholdIntensity(above = True, threshold = 0.5, cval = 0)(image)
 
220
  return image
221
 
222
  def stitch(org_shape, cropped, roi):
@@ -225,9 +227,9 @@ def stitch(org_shape, cropped, roi):
225
 
226
  return holder
227
 
228
- def post_process(left_mask, right_mask, preprocess_dump, lung_filter):
229
- left_mask = (left_mask >= 0.5).astype(int)
230
- right_mask = (right_mask >= 0.5).astype(int)
231
 
232
  left = remove_pad(left_mask, preprocess_dump['left_lung'].squeeze(0).squeeze(0).numpy())
233
  right = remove_pad(right_mask, preprocess_dump['right_lung'].squeeze(0).squeeze(0).numpy())
 
211
 
212
  def remove_pad(mask, original):
213
  a_min, a_max, b_min, b_max, c_min, c_max = find_pad_edge(original)
214
+
215
  return mask[a_min:a_max, b_min:b_max, c_min: c_max]
216
 
217
  def voxel_space(image, target):
218
  image = Resize((target[0][1]-target[0][0], target[1][1]-target[1][0], target[2][1]-target[2][0]), mode='trilinear')(np.expand_dims(image, 0))[0]
219
  image = ThresholdIntensity(above = False, threshold = 0.5, cval = 1)(image)
220
  image = ThresholdIntensity(above = True, threshold = 0.5, cval = 0)(image)
221
+
222
  return image
223
 
224
  def stitch(org_shape, cropped, roi):
 
227
 
228
  return holder
229
 
230
+ def post_process(left_mask, right_mask, preprocess_dump, lung_filter, threshold):
231
+ left_mask = (left_mask >= threshold).astype(int)
232
+ right_mask = (right_mask >= threshold).astype(int)
233
 
234
  left = remove_pad(left_mask, preprocess_dump['left_lung'].squeeze(0).squeeze(0).numpy())
235
  right = remove_pad(right_mask, preprocess_dump['right_lung'].squeeze(0).squeeze(0).numpy())
lungtumormask/mask.py CHANGED
@@ -15,7 +15,7 @@ def load_model():
15
  model.eval()
16
  return model
17
 
18
- def mask(image_path, save_path, lung_filter):
19
  print("Loading model...")
20
  model = load_model()
21
 
@@ -27,7 +27,7 @@ def mask(image_path, save_path, lung_filter):
27
  right = model(preprocess_dump['right_lung']).squeeze(0).squeeze(0).detach().numpy()
28
 
29
  print("Post-processing image...")
30
- inferred = post_process(left, right, preprocess_dump, lung_filter).astype("uint8")
31
 
32
  print(f"Storing segmentation at {save_path}")
33
  nimage = nibabel.Nifti1Image(inferred, preprocess_dump['org_affine'])
 
15
  model.eval()
16
  return model
17
 
18
+ def mask(image_path, save_path, lung_filter, threshold):
19
  print("Loading model...")
20
  model = load_model()
21
 
 
27
  right = model(preprocess_dump['right_lung']).squeeze(0).squeeze(0).detach().numpy()
28
 
29
  print("Post-processing image...")
30
+ inferred = post_process(left, right, preprocess_dump, lung_filter, threshold).astype("uint8")
31
 
32
  print(f"Storing segmentation at {save_path}")
33
  nimage = nibabel.Nifti1Image(inferred, preprocess_dump['org_affine'])