Nicolas Draber commited on
Commit
6d061ac
1 Parent(s): ad60bff

Import space pipeline file

Browse files
Files changed (1) hide show
  1. app.py +178 -0
app.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image
4
+ import streamlit as st
5
+ import tensorflow as tf
6
+ from tensorflow.keras.models import load_model
7
+
8
+ # most of this code has been obtained from Datature's prediction script
9
+ # https://github.com/datature/resources/blob/main/scripts/bounding_box/prediction.py
10
+
11
+ st.set_option('deprecation.showfileUploaderEncoding', False)
12
+
13
+ @st.cache(allow_output_mutation=True)
14
+ def load_model():
15
+ return tf.saved_model.load('./saved_model')
16
+
17
+ def load_label_map(label_map_path):
18
+ """
19
+ Reads label map in the format of .pbtxt and parse into dictionary
20
+ Args:
21
+ label_map_path: the file path to the label_map
22
+ Returns:
23
+ dictionary with the format of {label_index: {'id': label_index, 'name': label_name}}
24
+ """
25
+ label_map = {}
26
+
27
+ with open(label_map_path, "r") as label_file:
28
+ for line in label_file:
29
+ if "id" in line:
30
+ label_index = int(line.split(":")[-1])
31
+ label_name = next(label_file).split(":")[-1].strip().strip('"')
32
+ label_map[label_index] = {"id": label_index, "name": label_name}
33
+ return label_map
34
+
35
+ def predict_class(image, model):
36
+ image = tf.cast(image, tf.float32)
37
+ image = tf.image.resize(image, [150, 150])
38
+ image = np.expand_dims(image, axis = 0)
39
+ return model.predict(image)
40
+
41
+ def plot_boxes_on_img(color_map, classes, bboxes, image_origi, origi_shape):
42
+ for idx, each_bbox in enumerate(bboxes):
43
+ color = color_map[classes[idx]]
44
+
45
+ ## Draw bounding box
46
+ cv2.rectangle(
47
+ image_origi,
48
+ (int(each_bbox[1] * origi_shape[1]),
49
+ int(each_bbox[0] * origi_shape[0]),),
50
+ (int(each_bbox[3] * origi_shape[1]),
51
+ int(each_bbox[2] * origi_shape[0]),),
52
+ color,
53
+ 2,
54
+ )
55
+ ## Draw label background
56
+ cv2.rectangle(
57
+ image_origi,
58
+ (int(each_bbox[1] * origi_shape[1]),
59
+ int(each_bbox[2] * origi_shape[0]),),
60
+ (int(each_bbox[3] * origi_shape[1]),
61
+ int(each_bbox[2] * origi_shape[0] + 15),),
62
+ color,
63
+ -1,
64
+ )
65
+ ## Insert label class & score
66
+ cv2.putText(
67
+ image_origi,
68
+ "Class: {}, Score: {}".format(
69
+ str(category_index[classes[idx]]["name"]),
70
+ str(round(scores[idx], 2)),
71
+ ),
72
+ (int(each_bbox[1] * origi_shape[1]),
73
+ int(each_bbox[2] * origi_shape[0] + 10),),
74
+ cv2.FONT_HERSHEY_SIMPLEX,
75
+ 0.3,
76
+ (0, 0, 0),
77
+ 1,
78
+ cv2.LINE_AA,
79
+ )
80
+ return image_origi
81
+
82
+
83
+ # Webpage code starts here
84
+
85
+ #TODO change this
86
+ st.title('Distribution Grid - Belgium - Equipment detection')
87
+ st.text('made by LabelFlow')
88
+ st.markdown('## Description about your project')
89
+
90
+ with st.spinner('Model is being loaded...'):
91
+ model = load_model()
92
+
93
+ # ask user to upload an image
94
+ file = st.file_uploader("Upload image", type=["jpg", "png"])
95
+
96
+ if file is None:
97
+ st.text('Waiting for upload...')
98
+ else:
99
+ st.text('Running inference...')
100
+ # open image
101
+ test_image = Image.open(file).convert("RGB")
102
+ origi_shape = np.asarray(test_image).shape
103
+ # resize image to default shape
104
+ default_shape = 320
105
+ image_resized = np.array(test_image.resize((default_shape, default_shape)))
106
+
107
+ ## Load color map
108
+ category_index = load_label_map("./label_map.pbtxt")
109
+
110
+ # TODO Add more colors if there are more classes
111
+ # color of each label. check label_map.pbtxt to check the index for each class
112
+ color_map = {
113
+ 1: [69, 109, 42],
114
+ 2: [107, 46, 186],
115
+ 3: [9, 35, 183],
116
+ 4: [27, 1, 30],
117
+ 5: [0, 0, 0],
118
+ 6: [5, 6, 7],
119
+ 7: [11, 5, 12],
120
+ 8: [209, 205, 211],
121
+ 9: [17, 17, 17],
122
+ 10: [101, 242, 50],
123
+ 11: [51, 204, 170],
124
+ 12: [106, 0, 132],
125
+ 13: [7, 111, 153],
126
+ 14: [8, 10, 9],
127
+ 15: [234, 250, 252],
128
+ 16: [58, 68, 30],
129
+ 17: [24, 178, 117],
130
+ 18: [21, 22, 21],
131
+ 19: [53, 104, 83],
132
+ 20: [12, 5, 10],
133
+ 21: [223, 192, 249],
134
+ 22: [234, 234, 234],
135
+ 23: [119, 68, 221],
136
+ 24: [224, 174, 94],
137
+ 25: [140, 74, 116],
138
+ 26: [90, 102, 1],
139
+ 27: [216, 143, 208]
140
+ }
141
+
142
+ ## The model input needs to be a tensor
143
+ input_tensor = tf.convert_to_tensor(image_resized)
144
+ ## The model expects a batch of images, so add an axis with `tf.newaxis`.
145
+ input_tensor = input_tensor[tf.newaxis, ...]
146
+
147
+ ## Feed image into model and obtain output
148
+ detections_output = model(input_tensor)
149
+ num_detections = int(detections_output.pop("num_detections"))
150
+ detections = {key: value[0, :num_detections].numpy() for key, value in detections_output.items()}
151
+ detections["num_detections"] = num_detections
152
+
153
+ ## Filter out predictions below threshold
154
+ # if threshold is higher, there will be fewer predictions
155
+ # TODO change this number to see how the predictions change
156
+ confidence_threshold = 0.6
157
+ indexes = np.where(detections["detection_scores"] > confidence_threshold)
158
+
159
+ ## Extract predicted bounding boxes
160
+ bboxes = detections["detection_boxes"][indexes]
161
+ # there are no predicted boxes
162
+ if len(bboxes) == 0:
163
+ st.error('No boxes predicted')
164
+ # there are predicted boxes
165
+ else:
166
+ st.success('Boxes predicted')
167
+ classes = detections["detection_classes"][indexes].astype(np.int64)
168
+ scores = detections["detection_scores"][indexes]
169
+
170
+ # plot boxes and labels on image
171
+ image_origi = np.array(Image.fromarray(image_resized).resize((origi_shape[1], origi_shape[0])))
172
+ image_origi = plot_boxes_on_img(color_map, classes, bboxes, image_origi, origi_shape)
173
+
174
+ # show image in web page
175
+ st.image(Image.fromarray(image_origi), caption="Image with predictions", width=400)
176
+ st.markdown("### Predicted boxes")
177
+ for idx in range(len((bboxes))):
178
+ st.markdown(f"* Class: {str(category_index[classes[idx]]['name'])}, confidence score: {str(round(scores[idx], 2))}")