import json
import os
import warnings
from os import path as osp

import cv2
import mmcv
import numpy as np

from .dataset_info import DatasetInfo


def _resolve_output_dirs(out_dir, vis_out_dir, pred_out_dir):
    """Return ``(vis_dir, pred_dir)``; an entry is ``None`` when disabled.

    Explicit ``vis_out_dir`` / ``pred_out_dir`` win. Otherwise, when
    ``out_dir`` is set, they default to ``<out_dir>/visualizations`` and
    ``<out_dir>/predictions``. Every enabled directory is created, including
    missing parents.
    """
    vis_dir = vis_out_dir or None
    pred_dir = pred_out_dir or None
    if out_dir:
        vis_dir = vis_dir or osp.join(out_dir, 'visualizations')
        pred_dir = pred_dir or osp.join(out_dir, 'predictions')
    for directory in (vis_dir, pred_dir):
        if directory:
            os.makedirs(directory, exist_ok=True)
    return vis_dir, pred_dir


def _load_image(img):
    """Return ``(image_array, source_path)``; ``source_path`` is None for arrays.

    Raises:
        TypeError: If ``img`` is neither a str nor a numpy array.
    """
    if isinstance(img, str):
        return mmcv.imread(img), img
    if isinstance(img, np.ndarray):
        # Copy so that drawing never mutates the caller's array.
        return img.copy(), None
    raise TypeError('img must be a filename or numpy array, '
                    f'but got {type(img)}')


def _iter_valid_poses(poses):
    """Yield ``(label, entry)`` pairs from ``poses`` that carry pose results."""
    for label, entry in poses.items():
        # Skip empty entries (and empty result lists) instead of failing on [0].
        if not entry or len(entry['pose_results']) == 0:
            continue
        yield label, entry


def _collect_bboxes(poses, with_score):
    """Gather the boxes of every class into the inputs of imshow_det_bboxes.

    Returns:
        tuple: ``(bboxes, labels, class_names, scores)``. Each box gets its own
        label index, and its class name is the owning ``poses`` key. When
        ``with_score`` is true, the score is appended as the last column of
        each box. ``scores`` holds the plain-Python score of every box.
    """
    bboxes, labels, class_names, scores = [], [], [], []
    for label, entry in _iter_valid_poses(poses):
        instances = entry['pose_results'][0].gt_instances
        for box, score in zip(instances.bboxes, instances.bbox_scores):
            if with_score:
                # x1, y1, x2, y2 -> x1, y1, x2, y2, score
                box = np.append(box, values=score)
            labels.append(len(labels))
            class_names.append(label)
            bboxes.append(box)
            scores.append(score.tolist())
    if bboxes:
        bbox_array = np.array(bboxes)
    else:
        # np.array([]) is 1-D, which imshow_det_bboxes rejects; keep it 2-D.
        bbox_array = np.zeros((0, 5 if with_score else 4), dtype=np.float32)
    return bbox_array, np.array(labels, dtype=np.int64), class_names, scores


def _to_int_colors(colors):
    """Convert colour rows to tuples of Python ints, as OpenCV expects."""
    return [tuple(int(x) for x in color) for color in colors]


def _resolve_dataset_info(dataset_info, pose_model):
    """Return the caller's dataset_info, else the one from ``pose_model.cfg``.

    Raises:
        ValueError: If neither source provides a dataset_info.
    """
    if (dataset_info is None and hasattr(pose_model, 'cfg')
            and 'dataset_info' in pose_model.cfg):
        dataset_info = DatasetInfo(pose_model.cfg.dataset_info)
    if dataset_info is None:
        warnings.warn(
            'dataset is deprecated.'
            'Please set `dataset_info` in the config.'
            'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
            DeprecationWarning)
        raise ValueError(
            'dataset_info is not specified or set in the config file.')
    return dataset_info


def _draw_pose(img, keypoints, keypoint_scores, dataset_info, kpt_score_thr,
               radius, thickness):
    """Draw one class's keypoints and skeleton links onto ``img`` in place."""
    skeleton = dataset_info.skeleton
    kpt_colors = _to_int_colors(dataset_info.pose_kpt_color)
    link_colors = _to_int_colors(dataset_info.pose_link_color)

    # Per instance: keypoint index -> point, for keypoints whose score passes
    # kpt_score_thr.
    visible = [
        {k: point for k, point in enumerate(inst_kpts)
         if inst_scores[k] >= kpt_score_thr}
        for inst_kpts, inst_scores in zip(keypoints, keypoint_scores)
    ]

    # Draw all circles first and then all links, so links render on top.
    for inst in visible:
        for k, point in inst.items():
            cv2.circle(img, (int(point[0]), int(point[1])), radius,
                       kpt_colors[k], -1)
    for inst in visible:
        for link_idx, link in enumerate(skeleton):
            start, end = link[0], link[1]
            # A link is drawn only when both of its endpoints are visible.
            if start in inst and end in inst:
                p0, p1 = inst[start], inst[end]
                cv2.line(img, (int(p0[0]), int(p0[1])),
                         (int(p1[0]), int(p1[1])), link_colors[link_idx],
                         thickness)


def _write_predictions(pred_dir, img_path, bboxes, bbox_scores, keypoints,
                       keypoint_scores):
    """Dump one JSON record per box to ``<pred_dir>/<image stem>.json``."""
    # NOTE(review): assumes boxes (gt_instances) and keypoints
    # (pred_instances) are in the same order and count -- TODO confirm.
    records = [
        dict(
            keypoints=keypoints[i],
            keypoint_scores=keypoint_scores[i],
            bbox=[bboxes[i].tolist()],
            bbox_score=bbox_scores[i],
        )
        for i in range(len(bboxes))
    ]
    stem = osp.splitext(osp.basename(img_path))[0]
    with open(osp.join(pred_dir, stem + '.json'), 'w') as f:
        json.dump(records, f)


def save_result(img, poses, img_name=None, radius=4, thickness=1,
                bbox_score_thr=None, kpt_score_thr=0.3, bbox_color='green',
                dataset_info=None, show=False, out_dir=None, vis_out_dir=None,
                pred_out_dir=None,):
    """Draw bounding boxes and poses on an image and optionally save them.

    Args:
        img (str | np.ndarray): Image filename or loaded image. An array is
            copied before drawing.
        poses (dict[dict]): Maps a class label to a dict with keys
            ``'pose_model'`` and ``'pose_results'``. Only
            ``pose_results[0]`` is read. Its ``gt_instances`` provides
            ``bboxes`` / ``bbox_scores``, and its ``pred_instances`` provides
            ``keypoints`` / ``keypoint_scores``. Empty entries are skipped.
        img_name (str): Name used for output files when ``img`` is an array.
            Falls back to ``'demo.jpg'``.
        radius (int): Radius of keypoint circles.
        thickness (int): Thickness of skeleton lines.
        bbox_score_thr (float): If given, each box's score is appended and
            boxes scoring at or below this value are not drawn. If None, all
            boxes are drawn.
        kpt_score_thr (float): Keypoints scoring below this are not drawn.
        bbox_color (str | tuple[int]): Color of bounding boxes.
        dataset_info (DatasetInfo): Skeleton and colour source used for every
            class. If None, each class uses ``pose_model.cfg.dataset_info``.
        show (bool): Whether to display the image. Default False.
        out_dir (str): Base output directory. Visualizations default to
            ``${out_dir}/visualizations`` and predictions to
            ``${out_dir}/predictions``. Default None.
        vis_out_dir (str): Directory for the drawn image. Default None.
        pred_out_dir (str): Directory for the JSON predictions. Default None.

    Returns:
        np.ndarray: The image with boxes and poses drawn.

    Raises:
        TypeError: If ``img`` is neither a filename nor a numpy array.
        ValueError: If a class has no dataset_info from either source.
    """
    vis_dir, pred_dir = _resolve_output_dirs(out_dir, vis_out_dir,
                                             pred_out_dir)
    img, img_path = _load_image(img)

    with_score = bbox_score_thr is not None
    bboxes, labels, class_names, bbox_scores = _collect_bboxes(
        poses, with_score)
    img = mmcv.imshow_det_bboxes(
        img,
        bboxes,
        labels,
        class_names=class_names,
        score_thr=bbox_score_thr if with_score else 0,
        bbox_color=bbox_color,
        text_color='white',
        show=False,
    )

    keypoints_list = []
    keypoint_scores_list = []
    for _, entry in _iter_valid_poses(poses):
        instances = entry['pose_results'][0].pred_instances
        keypoints_list.extend(ks.tolist() for ks in instances.keypoints)
        keypoint_scores_list.extend(
            kss.tolist() for kss in instances.keypoint_scores)
        # Resolve per class, so each pose model uses its own skeleton and
        # colours unless the caller supplied dataset_info explicitly.
        cls_info = _resolve_dataset_info(dataset_info, entry['pose_model'])
        _draw_pose(img, instances.keypoints, instances.keypoint_scores,
                   cls_info, kpt_score_thr, radius, thickness)

    if show:
        mmcv.imshow(img, wait_time=0)

    if img_path is None:
        img_path = img_name if img_name is not None else 'demo.jpg'
    if vis_dir:
        mmcv.imwrite(img, osp.join(vis_dir, osp.basename(img_path)))
    if pred_dir:
        _write_predictions(pred_dir, img_path, bboxes, bbox_scores,
                           keypoints_list, keypoint_scores_list)
    return img