andreped commited on
Commit
7c1e417
1 Parent(s): 9026856

Setup demo app [no ci]

Browse files
demo/app.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.gui import WebUI
2
+
3
+
4
+ def main():
5
+ print("Launching demo...")
6
+
7
+ # cwd = "/Users/andreped/workspace/LungTumorMask/" # local testing -> macOS
8
+ cwd = "/home/user/app/" # production -> docker
9
+
10
+ class_name = "tumor"
11
+
12
+ # initialize and run app
13
+ app = WebUI(model_name=model_name, class_name=class_name, cwd=cwd)
14
+ app.run()
15
+
16
+
17
+ if __name__ == "__main__":
18
+ main()
demo/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ lungtumormask @ git+https://github.com/vemundfredriksen/LungTumorMask.git
2
+ gradio==3.32.0
demo/src/__init__.py ADDED
File without changes
demo/src/compute.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ def run_model(input_path):
2
+ from lungtumormask import mask
3
+ mask.mask(input_path, "./output.nii.gz", lung_filter=True, threshold=0.5, radius=1, batch_size=1)
demo/src/convert.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nibabel as nib
2
+ from nibabel.processing import resample_to_output
3
+ from skimage.measure import marching_cubes
4
+
5
+
6
+ def nifti_to_glb(path, output="prediction.obj"):
7
+ # load NIFTI into numpy array
8
+ image = nib.load(path)
9
+ resampled = resample_to_output(image, [1, 1, 1], order=1)
10
+ data = resampled.get_fdata().astype("uint8")
11
+
12
+ # extract surface
13
+ verts, faces, normals, values = marching_cubes(data, 0)
14
+ faces += 1
15
+
16
+ with open(output, 'w') as thefile:
17
+ for item in verts:
18
+ thefile.write("v {0} {1} {2}\n".format(item[0],item[1],item[2]))
19
+
20
+ for item in normals:
21
+ thefile.write("vn {0} {1} {2}\n".format(item[0],item[1],item[2]))
22
+
23
+ for item in faces:
24
+ thefile.write("f {0}//{0} {1}//{1} {2}//{2}\n".format(item[0],item[1],item[2]))
demo/src/gui.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from .utils import load_ct_to_numpy, load_pred_volume_to_numpy
3
+ from .compute import run_model
4
+ from .convert import nifti_to_glb
5
+
6
+
7
+ class WebUI:
8
+ def __init__(self, class_name:str = None, cwd:str = None):
9
+ # global states
10
+ self.images = []
11
+ self.pred_images = []
12
+
13
+ # @TODO: This should be dynamically set based on chosen volume size
14
+ self.nb_slider_items = 100
15
+
16
+ self.class_name = class_name
17
+ self.cwd = cwd
18
+
19
+ # define widgets not to be rendered immediantly, but later on
20
+ self.slider = gr.Slider(1, self.nb_slider_items, value=1, step=1, label="Which 2D slice to show")
21
+ self.volume_renderer = gr.Model3D(
22
+ clear_color=[0.0, 0.0, 0.0, 0.0],
23
+ label="3D Model",
24
+ visible=True,
25
+ elem_id="model-3d",
26
+ ).style(height=512)
27
+
28
+ def combine_ct_and_seg(self, img, pred):
29
+ return (img, [(pred, self.class_name)])
30
+
31
+ def upload_file(self, file):
32
+ return file.name
33
+
34
+ def load_mesh(self, mesh_file_name):
35
+ path = mesh_file_name.name
36
+ run_model(path)
37
+ nifti_to_glb("prediction-livermask.nii")
38
+ self.images = load_ct_to_numpy(path)
39
+ self.pred_images = load_pred_volume_to_numpy("./prediction-livermask.nii")
40
+ self.slider = self.slider.update(value=2)
41
+ return "./prediction.obj"
42
+
43
+ def get_img_pred_pair(self, k):
44
+ k = int(k) - 1
45
+ out = [gr.AnnotatedImage.update(visible=False)] * self.nb_slider_items
46
+ out[k] = gr.AnnotatedImage.update(self.combine_ct_and_seg(self.images[k], self.pred_images[k]), visible=True)
47
+ return out
48
+
49
+ def run(self):
50
+ css="""
51
+ #model-3d {
52
+ height: 512px;
53
+ }
54
+ #model-2d {
55
+ height: 512px;
56
+ margin: auto;
57
+ }
58
+ """
59
+ with gr.Blocks(css=css) as demo:
60
+
61
+ with gr.Row():
62
+ file_output = gr.File(
63
+ file_types=[".nii", ".nii.nz"],
64
+ file_count="single"
65
+ ).style(full_width=False, size="sm")
66
+ file_output.upload(self.upload_file, file_output, file_output)
67
+
68
+ run_btn = gr.Button("Run analysis").style(full_width=False, size="sm")
69
+ run_btn.click(
70
+ fn=lambda x: self.load_mesh(x),
71
+ inputs=file_output,
72
+ outputs=self.volume_renderer
73
+ )
74
+
75
+ with gr.Row():
76
+ gr.Examples(
77
+ examples=[self.cwd + "test-volume.nii"],
78
+ inputs=file_output,
79
+ outputs=file_output,
80
+ fn=self.upload_file,
81
+ cache_examples=True,
82
+ )
83
+
84
+ with gr.Row():
85
+ with gr.Box():
86
+ image_boxes = []
87
+ for i in range(self.nb_slider_items):
88
+ visibility = True if i == 1 else False
89
+ t = gr.AnnotatedImage(visible=visibility, elem_id="model-2d")\
90
+ .style(color_map={self.class_name: "#ffae00"}, height=512, width=512)
91
+ image_boxes.append(t)
92
+
93
+ self.slider.change(self.get_img_pred_pair, self.slider, image_boxes)
94
+
95
+ with gr.Box():
96
+ self.volume_renderer.render()
97
+
98
+ with gr.Row():
99
+ self.slider.render()
100
+
101
+ # sharing app publicly -> share=True: https://gradio.app/sharing-your-app/
102
+ # inference times > 60 seconds -> need queue(): https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
103
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)
demo/src/utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nibabel as nib
2
+ import numpy as np
3
+
4
+
5
+ def load_ct_to_numpy(data_path):
6
+ if type(data_path) != str:
7
+ data_path = data_path.name
8
+
9
+ image = nib.load(data_path)
10
+ data = image.get_fdata()
11
+
12
+ data = np.rot90(data, k=1, axes=(0, 1))
13
+
14
+ data[data < -150] = -150
15
+ data[data > 250] = 250
16
+
17
+ data = data - np.amin(data)
18
+ data = data / np.amax(data) * 255
19
+ data = data.astype("uint8")
20
+
21
+ print(data.shape)
22
+ return [data[..., i] for i in range(data.shape[-1])]
23
+
24
+
25
+ def load_pred_volume_to_numpy(data_path):
26
+ if type(data_path) != str:
27
+ data_path = data_path.name
28
+
29
+ image = nib.load(data_path)
30
+ data = image.get_fdata()
31
+
32
+ data = np.rot90(data, k=1, axes=(0, 1))
33
+
34
+ data[data > 0] = 1
35
+ data = data.astype("uint8")
36
+
37
+ print(data.shape)
38
+ return [data[..., i] for i in range(data.shape[-1])]