Spaces:
Sleeping
Sleeping
enable dynamic selection of CUDA GPU only if available, else GPU inference
Browse files- lungtumormask/__main__.py +1 -1
- lungtumormask/dataprocessing.py +2 -1
- lungtumormask/mask.py +6 -2
lungtumormask/__main__.py
CHANGED
@@ -17,4 +17,4 @@ def main():
|
|
17 |
argsin = sys.argv[1:]
|
18 |
args = parser.parse_args(argsin)
|
19 |
|
20 |
-
mask.mask(args.input, args.output)
|
|
|
17 |
argsin = sys.argv[1:]
|
18 |
args = parser.parse_args(argsin)
|
19 |
|
20 |
+
mask.mask(args.input, args.output)
|
lungtumormask/dataprocessing.py
CHANGED
@@ -10,7 +10,8 @@ from monai.transforms import (Compose, LoadImaged, ToNumpyd, ThresholdIntensityd
|
|
10 |
|
11 |
def mask_lung(scan_path, batch_size=20):
|
12 |
model = lungmask.mask.get_model('unet', 'R231')
|
13 |
-
|
|
|
14 |
model.to(device)
|
15 |
|
16 |
scan_dict = {
|
|
|
10 |
|
11 |
def mask_lung(scan_path, batch_size=20):
|
12 |
model = lungmask.mask.get_model('unet', 'R231')
|
13 |
+
if T.cuda.is_available():
|
14 |
+
device = torch.device('cuda')
|
15 |
model.to(device)
|
16 |
|
17 |
scan_dict = {
|
lungtumormask/mask.py
CHANGED
@@ -5,10 +5,14 @@ import torch as T
|
|
5 |
import nibabel
|
6 |
|
7 |
def load_model():
|
|
|
|
|
|
|
|
|
8 |
model = UNet_double(3, 1, 1, tuple([64, 128, 256, 512, 1024]), tuple([2 for i in range(4)]), num_res_units = 0)
|
9 |
-
state_dict = T.hub.load_state_dict_from_url("https://github.com/VemundFredriksen/LungTumorMask/releases/download/0.0/dc_student.pth", progress=True, map_location=
|
10 |
#model.load_state_dict(T.load("D:\\OneDrive\\Skole\\Universitet\\10. Semester\\Masteroppgave\\bruk_for_full_model.pth", map_location="cuda:0"))
|
11 |
-
model.load_state_dict(state_dict)
|
12 |
model.eval()
|
13 |
return model
|
14 |
|
|
|
5 |
import nibabel
|
6 |
|
7 |
def load_model():
|
8 |
+
if T.cuda.is_available():
|
9 |
+
gpu_device = T.device('cuda')
|
10 |
+
else:
|
11 |
+
gpu_device = T.device('cpu')
|
12 |
model = UNet_double(3, 1, 1, tuple([64, 128, 256, 512, 1024]), tuple([2 for i in range(4)]), num_res_units = 0)
|
13 |
+
state_dict = T.hub.load_state_dict_from_url("https://github.com/VemundFredriksen/LungTumorMask/releases/download/0.0/dc_student.pth", progress=True, map_location=gpu_device)
|
14 |
#model.load_state_dict(T.load("D:\\OneDrive\\Skole\\Universitet\\10. Semester\\Masteroppgave\\bruk_for_full_model.pth", map_location="cuda:0"))
|
15 |
+
model.load_state_dict(state_dict)
|
16 |
model.eval()
|
17 |
return model
|
18 |
|