import numpy as np

from seg.models.utils import NO_OBJ, INSTANCE_OFFSET_HB


def vpq_eval(element, num_classes=-1, max_ins=INSTANCE_OFFSET_HB, ign_id=NO_OBJ):
    """Accumulate panoptic-quality statistics for one prediction/ground-truth pair.

    Both id maps encode a segment as ``category * max_ins + instance``; the
    category of a segment is recovered with ``segment_id // max_ins``.

    Args:
        element: ``(pred_ids, gt_ids)`` pair of integer id arrays of equal shape.
        num_classes: Number of categories; must be given explicitly (the
            default ``-1`` is rejected).
        max_ins: Divisor separating category and instance in a segment id.
        ign_id: Category id marking ignored / void ground-truth regions.

    Returns:
        Tuple ``(iou_per_class, tp_per_class, fn_per_class, fp_per_class)`` of
        float64 arrays, each of length ``num_classes + 1``.
    """
    assert num_classes != -1, "num_classes must be specified"

    pred_ids, gt_ids = element

    # A (gt, pred) segment pair is packed into one number as
    # ``gt_id * offset + pred_id``. ``offset`` is a float, so the packed ids are
    # float64; they stay exact while they remain below 2**53.
    offset = 1e7  # 1e7 > 200 * max_ins
    assert offset > num_classes * max_ins, "offset too small for num_classes * max_ins"

    num_cat = num_classes + 1
    iou_per_class = np.zeros(num_cat, dtype=np.float64)
    tp_per_class = np.zeros(num_cat, dtype=np.float64)
    fn_per_class = np.zeros(num_cat, dtype=np.float64)
    fp_per_class = np.zeros(num_cat, dtype=np.float64)

    def _ids_to_counts(id_array):
        """Map every distinct id in ``id_array`` to its pixel count."""
        ids, counts = np.unique(id_array, return_counts=True)
        return dict(zip(ids, counts))

    pred_areas = _ids_to_counts(pred_ids)
    gt_areas = _ids_to_counts(gt_ids)

    # The void segment is the ignored category with instance id 0.
    void_id = ign_id * max_ins
    # Every ground-truth segment that belongs to the ignored category.
    ign_ids = {gt_id for gt_id in gt_areas if (gt_id // max_ins) == ign_id}

    int_ids = gt_ids.astype(np.uint64) * offset + pred_ids.astype(np.uint64)
    int_areas = _ids_to_counts(int_ids)

    def prediction_void_overlap(pred_id):
        """Area of ``pred_id`` that falls on the ground-truth void segment."""
        void_int_id = void_id * offset + pred_id
        return int_areas.get(void_int_id, 0)

    def prediction_ignored_overlap(pred_id):
        """Area of ``pred_id`` that falls on any ignored ground-truth segment."""
        total_ignored_overlap = 0
        for _ign_id in ign_ids:
            int_id = _ign_id * offset + pred_id
            total_ignored_overlap += int_areas.get(int_id, 0)
        return total_ignored_overlap

    gt_matched = set()
    pred_matched = set()

    # True positives: same-category segment pairs with IoU > 0.5.
    for int_id, int_area in int_areas.items():
        gt_id = int(int_id // offset)
        gt_cat = int(gt_id // max_ins)
        pred_id = int(int_id % offset)
        pred_cat = int(pred_id // max_ins)
        if gt_cat != pred_cat:
            continue
        # The union excludes the part of the prediction lying on void pixels.
        # NOTE(review): a void-vs-void pair is not skipped here and can yield a
        # zero or negative union -- confirm predictions never carry the void id.
        union = (
            gt_areas[gt_id]
            + pred_areas[pred_id]
            - int_area
            - prediction_void_overlap(pred_id)
        )
        iou = int_area / union
        if iou > 0.5:
            tp_per_class[gt_cat] += 1
            iou_per_class[gt_cat] += iou
            gt_matched.add(gt_id)
            pred_matched.add(pred_id)

    # False negatives: unmatched ground-truth segments outside the ignored category.
    for gt_id in gt_areas:
        if gt_id in gt_matched:
            continue
        cat_id = gt_id // max_ins
        if cat_id == ign_id:
            continue
        fn_per_class[cat_id] += 1

    # False positives: unmatched predictions, unless more than half of their
    # area lies on ignored ground truth.
    for pred_id in pred_areas:
        if pred_id in pred_matched:
            continue
        if (prediction_ignored_overlap(pred_id) / pred_areas[pred_id]) > 0.5:
            continue
        cat = pred_id // max_ins
        fp_per_class[cat] += 1

    return iou_per_class, tp_per_class, fn_per_class, fp_per_class


def stq(element, num_classes=19, max_ins=10000, ign_id=NO_OBJ, num_things=8,
        label_divisor=1e4, ins_divisor=1e7):
    """Compute the per-frame id arrays used to accumulate STQ statistics.

    Panoptic ids encode a segment as ``category * max_ins + instance``.
    Categories with an id below ``num_things`` are treated as things, and a
    thing pixel whose ground-truth instance id is 0 is treated as crowd.

    Args:
        element: ``(y_pred, y_true)`` pair of panoptic id arrays of equal shape.
        num_classes: Number of semantic classes; the ignored category is
            remapped to this index for the semantic evaluation.
        max_ins: Divisor separating category and instance in a panoptic id.
        ign_id: Category id marking ignored regions.
        num_things: Number of thing classes (ids ``0 .. num_things - 1``).
        label_divisor: Multiplier packing (label, prediction) semantic pairs.
        ins_divisor: Multiplier packing (label, prediction) panoptic id pairs.

    Returns:
        Tuple ``(semantic_ids, seq_preds, seg_labels, intersection_ids)``:
        flattened packed semantic pairs, predicted ids on predicted non-crowd
        thing pixels, ground-truth ids on ground-truth non-crowd thing pixels,
        and packed id pairs where both masks hold. With the default float
        divisors the packed arrays are float64.
    """
    y_pred, y_true = element
    y_true = y_true.astype(np.int64)
    y_pred = y_pred.astype(np.int64)

    # Semantic evaluation: remap the ignored category to index ``num_classes``.
    semantic_label = y_true // max_ins
    semantic_prediction = y_pred // max_ins
    semantic_label = np.where(semantic_label != ign_id, semantic_label, num_classes)
    semantic_prediction = np.where(
        semantic_prediction != ign_id, semantic_prediction, num_classes
    )
    semantic_ids = (
        np.reshape(semantic_label, [-1]) * label_divisor
        + np.reshape(semantic_prediction, [-1])
    )

    # Instance evaluation.
    instance_label = y_true % max_ins

    label_mask = np.less(semantic_label, num_things)
    # The prediction mask must come from the *predicted* semantics; deriving it
    # from the label would make both masks identical and count stuff
    # predictions inside ground-truth things as tracked instances.
    prediction_mask = np.less(semantic_prediction, num_things)

    # Crowd regions are excluded from both masks so that id assignment inside
    # them is neither rewarded nor punished.
    is_crowd = np.logical_and(instance_label == 0, label_mask)
    not_crowd = np.logical_not(is_crowd)
    label_mask = np.logical_and(label_mask, not_crowd)
    prediction_mask = np.logical_and(prediction_mask, not_crowd)

    seq_preds = y_pred[prediction_mask]
    seg_labels = y_true[label_mask]

    non_crowd_intersection = np.logical_and(label_mask, prediction_mask)
    intersection_ids = (
        y_true[non_crowd_intersection] * ins_divisor
        + y_pred[non_crowd_intersection]
    )
    return semantic_ids, seq_preds, seg_labels, intersection_ids