# OMG_Seg / seg/models/utils/offline_video_metrics.py
# Author: HarborYuan -- commit b34d1d6 ("add omg code"), 4.05 kB
import numpy as np
from seg.models.utils import NO_OBJ, INSTANCE_OFFSET_HB
def vpq_eval(element, num_classes=-1, max_ins=INSTANCE_OFFSET_HB, ign_id=NO_OBJ):
    """Compute per-class panoptic-quality statistics for one prediction/GT pair.

    Ids in both maps are decoded as ``category = id // max_ins``. A
    ground-truth segment and a predicted segment of the same category are
    matched when their IoU is greater than 0.5. The prediction's overlap with
    the void id (``ign_id * max_ins``) is removed from the IoU union.

    Ground-truth segments of category ``ign_id`` never count as false
    negatives. Unmatched predictions whose area lies more than 50% on
    ground-truth segments of category ``ign_id`` do not count as false
    positives.

    Args:
        element: Tuple ``(pred_ids, gt_ids)`` of numpy id arrays. They are
            combined element-wise, so they are expected to have the same
            shape.
        num_classes: Number of categories; must be passed explicitly. The
            returned arrays have ``num_classes + 1`` entries.
        max_ins: Factor separating the category from the instance in an id.
        ign_id: Ground-truth category treated as ignored/void.

    Returns:
        ``(iou_per_class, tp_per_class, fn_per_class, fp_per_class)``, each a
        float64 array of length ``num_classes + 1``. They hold the summed IoU
        of matched pairs and the true-positive, false-negative and
        false-positive counts per category.
    """
    assert num_classes != -1
    pred_ids, gt_ids = element
    # A (gt_id, pred_id) pair is packed into one key as gt_id * offset + pred_id.
    # The offset is an int, so packing and unpacking are exact integer
    # arithmetic. A float offset (1e7) would make every key a float64, which
    # is exact only below 2**53.
    offset = int(1e7)  # 1e7 > 200 * max_ins
    assert offset > num_classes * max_ins
    num_cat = num_classes + 1
    iou_per_class = np.zeros(num_cat, dtype=np.float64)
    tp_per_class = np.zeros(num_cat, dtype=np.float64)
    fn_per_class = np.zeros(num_cat, dtype=np.float64)
    fp_per_class = np.zeros(num_cat, dtype=np.float64)

    def _ids_to_counts(id_array):
        # Map each distinct id to the number of elements carrying it.
        ids, counts = np.unique(id_array, return_counts=True)
        return dict(zip(ids, counts))

    pred_areas = _ids_to_counts(pred_ids)
    gt_areas = _ids_to_counts(gt_ids)

    void_id = ign_id * max_ins
    # All ground-truth segments (any instance) of the ignored category.
    ign_ids = {gt_id for gt_id in gt_areas if (gt_id // max_ins) == ign_id}

    # Pack in signed int64 rather than uint64, so negative ids do not wrap
    # around and the // and % decoding below stays exact.
    int_ids = gt_ids.astype(np.int64) * offset + pred_ids.astype(np.int64)
    int_areas = _ids_to_counts(int_ids)

    def prediction_void_overlap(pred_id):
        # Area of pred_id lying on the exact void id.
        void_int_id = void_id * offset + pred_id
        return int_areas.get(void_int_id, 0)

    def prediction_ignored_overlap(pred_id):
        # Area of pred_id lying on any ground-truth segment of the ignored category.
        total_ignored_overlap = 0
        for _ign_id in ign_ids:
            int_id = _ign_id * offset + pred_id
            total_ignored_overlap += int_areas.get(int_id, 0)
        return total_ignored_overlap

    gt_matched = set()
    pred_matched = set()
    for int_id, int_area in int_areas.items():
        gt_id = int(int_id // offset)
        gt_cat = int(gt_id // max_ins)
        pred_id = int(int_id % offset)
        pred_cat = int(pred_id // max_ins)
        if gt_cat != pred_cat:
            continue
        # The prediction's overlap with void is excluded from the union.
        union = (
            gt_areas[gt_id] + pred_areas[pred_id] - int_area -
            prediction_void_overlap(pred_id)
        )
        iou = int_area / union
        if iou > 0.5:
            tp_per_class[gt_cat] += 1
            iou_per_class[gt_cat] += iou
            gt_matched.add(gt_id)
            pred_matched.add(pred_id)

    for gt_id in gt_areas:
        if gt_id in gt_matched:
            continue
        cat_id = gt_id // max_ins
        # Segments of the ignored category never count as false negatives.
        if cat_id == ign_id:
            continue
        fn_per_class[cat_id] += 1

    for pred_id in pred_areas:
        if pred_id in pred_matched:
            continue
        # Predictions that mostly cover ignored ground truth are not penalised.
        if (prediction_ignored_overlap(pred_id) / pred_areas[pred_id]) > 0.5:
            continue
        cat = pred_id // max_ins
        fp_per_class[cat] += 1

    return iou_per_class, tp_per_class, fn_per_class, fp_per_class
def stq(element, num_classes=19, max_ins=10000, ign_id=NO_OBJ, num_things=8, label_divisor=1e4, ins_divisor=1e7):
    """Collect the per-frame id statistics used by Segmentation and Tracking Quality.

    Ids in both maps are decoded as ``category = id // max_ins`` and
    ``instance = id % max_ins``.

    Args:
        element: Tuple ``(y_pred, y_true)`` of numpy id maps. They are
            combined element-wise, so they are expected to have the same
            shape.
        num_classes: Index that the ignored category ``ign_id`` is remapped to
            in the semantic ids.
        max_ins: Factor separating the category from the instance in an id.
        ign_id: Category treated as ignored in both maps.
        num_things: Categories below this value are treated as "thing"
            classes for the instance (association) statistics.
        label_divisor: Multiplier that packs each (semantic label, semantic
            prediction) pair into one id.
        ins_divisor: Multiplier that packs each (gt id, predicted id) pair
            into one intersection id.

    Returns:
        Tuple ``(semantic_ids, seq_preds, seg_labels, intersection_ids)``:
            semantic_ids: flattened ``label * label_divisor + prediction``
                semantic ids. These are float64 with the default float divisor.
            seq_preds: predicted ids on predicted-thing pixels that are not GT
                crowd.
            seg_labels: ground-truth ids on non-crowd GT-thing pixels.
            intersection_ids: ``gt_id * ins_divisor + pred_id`` on pixels
                valid for both masks.
    """
    y_pred, y_true = element
    y_true = y_true.astype(np.int64)
    y_pred = y_pred.astype(np.int64)

    # Semantic part: decode the categories, remap the ignored category to
    # num_classes, and pack each (label, prediction) pair into one id.
    semantic_label = y_true // max_ins
    semantic_prediction = y_pred // max_ins
    semantic_label = np.where(semantic_label != ign_id,
                              semantic_label, num_classes)
    semantic_prediction = np.where(semantic_prediction != ign_id,
                                   semantic_prediction, num_classes)
    semantic_ids = (np.reshape(semantic_label, [-1]) * label_divisor
                    + np.reshape(semantic_prediction, [-1]))

    # Instance part: only "thing" categories take part in the association term.
    instance_label = y_true % max_ins
    label_mask = np.less(semantic_label, num_things)
    # The prediction mask is built from the *predicted* semantics. The old
    # code used semantic_label here, which made it identical to label_mask.
    prediction_mask = np.less(semantic_prediction, num_things)
    # Ground-truth thing pixels with instance id 0 are treated as crowd.
    is_crowd = np.logical_and(instance_label == 0, label_mask)
    # Crowd pixels are removed from both masks, so predictions there are not
    # penalised.
    label_mask = np.logical_and(label_mask, np.logical_not(is_crowd))
    prediction_mask = np.logical_and(prediction_mask, np.logical_not(is_crowd))

    seq_preds = y_pred[prediction_mask]
    seg_labels = y_true[label_mask]

    non_crowd_intersection = np.logical_and(label_mask, prediction_mask)
    intersection_ids = (y_true[non_crowd_intersection] * ins_divisor
                        + y_pred[non_crowd_intersection])
    return semantic_ids, seq_preds, seg_labels, intersection_ids