andreped committed on
Commit
72623c0
1 Parent(s): 13593ad

enable dynamic selection of CUDA GPU only if available, else CPU inference

lungtumormask/__main__.py CHANGED
@@ -17,4 +17,4 @@ def main():
     argsin = sys.argv[1:]
     args = parser.parse_args(argsin)
 
-    mask.mask(args.input, args.output)
+    mask.mask(args.input, args.output)
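For context, the entry point simply forwards the two CLI arguments to mask.mask(). A minimal library-style call might look like the sketch below (the file names are hypothetical; it assumes the package is importable and that input/output are NIfTI paths, which the nibabel usage in mask.py suggests):

from lungtumormask import mask

# Hypothetical input/output paths; the CLI forwards exactly these two
# arguments (args.input, args.output) to mask.mask().
# After this commit, device selection happens inside the package based on
# torch.cuda.is_available(), so the same call works with or without a GPU.
mask.mask("patient_ct.nii.gz", "tumor_mask.nii.gz")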
lungtumormask/dataprocessing.py CHANGED
@@ -10,7 +10,8 @@ from monai.transforms import (Compose, LoadImaged, ToNumpyd, ThresholdIntensityd
 
 def mask_lung(scan_path, batch_size=20):
     model = lungmask.mask.get_model('unet', 'R231')
-    device = torch.device('cuda')
+    if T.cuda.is_available():
+        device = torch.device('cuda')
+        device = torch.device('cuda')
     model.to(device)
 
     scan_dict = {
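For reference, here is a small device-selection sketch that also covers CPU-only machines, where `device` must still end up defined. It is only a sketch: it assumes this module imports torch as `torch`, and the Conv3d below is a stand-in, not the lungmask R231 network.

import torch

# Hypothetical stand-in model; any torch.nn.Module behaves the same way.
model = torch.nn.Conv3d(1, 1, kernel_size=3)

# `device` is always defined: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)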
lungtumormask/mask.py CHANGED
@@ -5,10 +5,14 @@ import torch as T
 import nibabel
 
 def load_model():
+    if T.cuda.is_available():
+        gpu_device = T.device('cuda')
+    else:
+        gpu_device = T.device('cpu')
     model = UNet_double(3, 1, 1, tuple([64, 128, 256, 512, 1024]), tuple([2 for i in range(4)]), num_res_units = 0)
-    state_dict = T.hub.load_state_dict_from_url("https://github.com/VemundFredriksen/LungTumorMask/releases/download/0.0/dc_student.pth", progress=True, map_location=T.device('cuda:0'))
+    state_dict = T.hub.load_state_dict_from_url("https://github.com/VemundFredriksen/LungTumorMask/releases/download/0.0/dc_student.pth", progress=True, map_location=gpu_device)
     #model.load_state_dict(T.load("D:\\OneDrive\\Skole\\Universitet\\10. Semester\\Masteroppgave\\bruk_for_full_model.pth", map_location="cuda:0"))
-    model.load_state_dict(state_dict)
+    model.load_state_dict(state_dict)
     model.eval()
     return model
 
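The key change here is passing the detected device as map_location, which tells torch where to materialise the checkpoint's tensors so the same weights file loads on both CUDA and CPU-only machines. Below is a small local sketch of that pattern, using the `T` alias as in mask.py; the Linear layer and the weights file name are hypothetical stand-ins, not the project's UNet_double or its release checkpoint.

import torch as T

# Pick the device the same way load_model() now does.
device = T.device('cuda') if T.cuda.is_available() else T.device('cpu')

# Hypothetical stand-in network and checkpoint file.
net = T.nn.Linear(4, 2)
T.save(net.state_dict(), 'weights.pth')

# map_location controls where the checkpoint's tensors are loaded,
# so no hard-coded 'cuda:0' is needed anymore.
state_dict = T.load('weights.pth', map_location=device)
net.load_state_dict(state_dict)
net.eval()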