HarborYuan commited on
Commit
f1435cf
1 Parent(s): 9cb6e3b

add device switch

Browse files
Files changed (1) hide show
  1. main.py +14 -0
main.py CHANGED
@@ -75,6 +75,15 @@ class IMGState:
75
  self.selected_points_labels = []
76
  self.selected_bboxes = []
77
 
 
 
 
 
 
 
 
 
 
78
  @property
79
  def available(self):
80
  return self.available_to_set
@@ -152,7 +161,9 @@ def segment_with_points(
152
  )
153
 
154
  try:
 
155
  masks, cls_pred = model.extract_masks(img_state.img_feat, prompts)
 
156
 
157
  masks = masks[0, 0, :h, :w]
158
  masks = masks > 0.5
@@ -209,7 +220,9 @@ def segment_with_bbox(
209
  )
210
 
211
  try:
 
212
  masks, cls_pred = model.extract_masks(img_state.img_feat, prompts)
 
213
 
214
  masks = masks[0, 0, :h, :w]
215
  masks = masks > 0.5
@@ -257,6 +270,7 @@ def extract_img_feat(img, img_state):
257
  img_tensor = F.pad(img_tensor, (0, IMG_SIZE - new_w, 0, IMG_SIZE - new_h), 'constant', 0)
258
  feat_dict = model.extract_feat(img_tensor)
259
  img_state.set_img(img_numpy, feat_dict)
 
260
  print_log(f"Successfully generated the image feats.", logger='current')
261
  except RuntimeError as e:
262
  if "CUDA out of memory" in str(e):
 
75
  self.selected_points_labels = []
76
  self.selected_bboxes = []
77
 
78
+ def to_device(self, device=device):
79
+ if self.img_feat is not None:
80
+ for k in self.img_feat:
81
+ if isinstance(self.img_feat[k], torch.Tensor):
82
+ self.img_feat[k] = self.img_feat[k].to(device)
83
+ else:
84
+ for i in range(len(self.img_feat[k])):
85
+ self.img_feat[k][i] = self.img_feat[k][i].to(device)
86
+
87
  @property
88
  def available(self):
89
  return self.available_to_set
 
161
  )
162
 
163
  try:
164
+ img_state.to_device()
165
  masks, cls_pred = model.extract_masks(img_state.img_feat, prompts)
166
+ img_state.to_device('cpu')
167
 
168
  masks = masks[0, 0, :h, :w]
169
  masks = masks > 0.5
 
220
  )
221
 
222
  try:
223
+ img_state.to_device()
224
  masks, cls_pred = model.extract_masks(img_state.img_feat, prompts)
225
+ img_state.to_device('cpu')
226
 
227
  masks = masks[0, 0, :h, :w]
228
  masks = masks > 0.5
 
270
  img_tensor = F.pad(img_tensor, (0, IMG_SIZE - new_w, 0, IMG_SIZE - new_h), 'constant', 0)
271
  feat_dict = model.extract_feat(img_tensor)
272
  img_state.set_img(img_numpy, feat_dict)
273
+ img_state.to_device('cpu')
274
  print_log(f"Successfully generated the image feats.", logger='current')
275
  except RuntimeError as e:
276
  if "CUDA out of memory" in str(e):