HarborYuan commited on
Commit
9cb6e3b
1 Parent(s): 4a2e688

bugfix state logic

Browse files
Files changed (1) hide show
  1. main.py +37 -18
main.py CHANGED
@@ -84,8 +84,6 @@ IMG_SIZE = 1024
84
 
85
 
86
  def get_points_with_draw(image, img_state, evt: gr.SelectData):
87
- w, h = image.size
88
- assert max(w, h) == IMG_SIZE, f"{w} x {h}"
89
  label = 'Add Mask'
90
 
91
  x, y = evt.index[0], evt.index[1]
@@ -143,23 +141,32 @@ def segment_with_points(
143
  image,
144
  img_state,
145
  ):
146
- assert not img_state.available
 
147
  output_img = img_state.img
148
  h, w = output_img.shape[:2]
149
 
150
  input_points = torch.tensor(img_state.selected_points, dtype=torch.float32, device=device)
151
-
152
  prompts = InstanceData(
153
  point_coords=input_points[None],
154
  )
155
- masks, cls_pred = model.extract_masks(img_state.img_feat, prompts)
156
 
157
- masks = masks[0, 0, :h, :w]
158
- masks = masks > 0.5
 
 
 
159
 
160
- cls_pred = cls_pred[0][0]
161
- scores, indices = torch.topk(cls_pred, 1)
162
- scores, indices = scores.tolist(), indices.tolist()
 
 
 
 
 
 
 
163
  names = []
164
  for ind in indices:
165
  names.append(LVIS_NAMES[ind].replace('_', ' '))
@@ -182,7 +189,8 @@ def segment_with_bbox(
182
  image,
183
  img_state
184
  ):
185
- assert not img_state.available
 
186
  if len(img_state.selected_bboxes) != 2:
187
  return image, None, ""
188
  output_img = img_state.img
@@ -196,18 +204,26 @@ def segment_with_bbox(
196
  max(box_points[0][1], box_points[1][1]),
197
  )
198
  input_bbox = torch.tensor(bbox, dtype=torch.float32, device=device)
199
-
200
  prompts = InstanceData(
201
  bboxes=input_bbox[None],
202
  )
203
- masks, cls_pred = model.extract_masks(img_state.img_feat, prompts)
204
 
205
- masks = masks[0, 0, :h, :w]
206
- masks = masks > 0.5
 
 
 
207
 
208
- cls_pred = cls_pred[0][0]
209
- scores, indices = torch.topk(cls_pred, 1)
210
- scores, indices = scores.tolist(), indices.tolist()
 
 
 
 
 
 
 
211
  names = []
212
  for ind in indices:
213
  names.append(LVIS_NAMES[ind].replace('_', ' '))
@@ -259,6 +275,9 @@ def clear_everything(img_state):
259
 
260
  def clean_prompts(img_state):
261
  img_state.clean()
 
 
 
262
  return Image.fromarray(img_state.img), None, "Please try to click something."
263
 
264
 
 
84
 
85
 
86
  def get_points_with_draw(image, img_state, evt: gr.SelectData):
 
 
87
  label = 'Add Mask'
88
 
89
  x, y = evt.index[0], evt.index[1]
 
141
  image,
142
  img_state,
143
  ):
144
+ if img_state.available:
145
+ return None, None, "State Error, please try again."
146
  output_img = img_state.img
147
  h, w = output_img.shape[:2]
148
 
149
  input_points = torch.tensor(img_state.selected_points, dtype=torch.float32, device=device)
 
150
  prompts = InstanceData(
151
  point_coords=input_points[None],
152
  )
 
153
 
154
+ try:
155
+ masks, cls_pred = model.extract_masks(img_state.img_feat, prompts)
156
+
157
+ masks = masks[0, 0, :h, :w]
158
+ masks = masks > 0.5
159
 
160
+ cls_pred = cls_pred[0][0]
161
+ scores, indices = torch.topk(cls_pred, 1)
162
+ scores, indices = scores.tolist(), indices.tolist()
163
+ except RuntimeError as e:
164
+ if "CUDA out of memory" in str(e):
165
+ img_state.clear()
166
+ print_log(f"CUDA OOM! please try again later", logger='current')
167
+ return None, None, "CUDA OOM, please try again later."
168
+ else:
169
+ raise
170
  names = []
171
  for ind in indices:
172
  names.append(LVIS_NAMES[ind].replace('_', ' '))
 
189
  image,
190
  img_state
191
  ):
192
+ if img_state.available:
193
+ return None, None, "State Error, please try again."
194
  if len(img_state.selected_bboxes) != 2:
195
  return image, None, ""
196
  output_img = img_state.img
 
204
  max(box_points[0][1], box_points[1][1]),
205
  )
206
  input_bbox = torch.tensor(bbox, dtype=torch.float32, device=device)
 
207
  prompts = InstanceData(
208
  bboxes=input_bbox[None],
209
  )
 
210
 
211
+ try:
212
+ masks, cls_pred = model.extract_masks(img_state.img_feat, prompts)
213
+
214
+ masks = masks[0, 0, :h, :w]
215
+ masks = masks > 0.5
216
 
217
+ cls_pred = cls_pred[0][0]
218
+ scores, indices = torch.topk(cls_pred, 1)
219
+ scores, indices = scores.tolist(), indices.tolist()
220
+ except RuntimeError as e:
221
+ if "CUDA out of memory" in str(e):
222
+ img_state.clear()
223
+ print_log(f"CUDA OOM! please try again later", logger='current')
224
+ return None, None, "CUDA OOM, please try again later."
225
+ else:
226
+ raise
227
  names = []
228
  for ind in indices:
229
  names.append(LVIS_NAMES[ind].replace('_', ' '))
 
275
 
276
  def clean_prompts(img_state):
277
  img_state.clean()
278
+ if img_state.img is None:
279
+ img_state.clear()
280
+ return None, None, "Please try to click something."
281
  return Image.fromarray(img_state.img), None, "Please try to click something."
282
 
283