bthndmn12 commited on
Commit
87757c1
1 Parent(s): 06a19d8

Fixed some bugs

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -17,7 +17,7 @@ model = SamModel.from_pretrained("./checkpoint",local_files_only=True)
17
  def get_bbox(gt_map):
18
 
19
  if gt_map.ndim > 2:
20
- gt_map = gt_map[:, :, 0] # Assuming the mask is the same across all channels
21
 
22
  # Check if the ground truth map is empty
23
  if np.sum(gt_map) == 0:
@@ -57,7 +57,7 @@ def process_image(image_input):
57
  outputs = model(**inputs, multimask_output=False)
58
 
59
  # Process model output
60
- seg_prob = torch.sigmoid(outputs['pred_masks'].squeeze(1))
61
  seg_prob = seg_prob.cpu().numpy().squeeze()
62
  seg = (seg_prob > 0.5).astype(np.uint8)
63
 
 
17
  def get_bbox(gt_map):
18
 
19
  if gt_map.ndim > 2:
20
+ gt_map = gt_map[:, :, 0]
21
 
22
  # Check if the ground truth map is empty
23
  if np.sum(gt_map) == 0:
 
57
  outputs = model(**inputs, multimask_output=False)
58
 
59
  # Process model output
60
+ seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
61
  seg_prob = seg_prob.cpu().numpy().squeeze()
62
  seg = (seg_prob > 0.5).astype(np.uint8)
63