Hanxiao Xiang
upload
b328990
raw
history blame
No virus
8.05 kB
import json
import os
import warnings
from os import path as osp

import cv2
import mmcv
import numpy as np

from .dataset_info import DatasetInfo
def save_result(img,
                poses,
                img_name=None,
                radius=4,
                thickness=1,
                bbox_score_thr=None,
                kpt_score_thr=0.3,
                bbox_color='green',
                dataset_info=None,
                show=False,
                out_dir=None,
                vis_out_dir=None,
                pred_out_dir=None,):
    """Visualize the detection and pose results on the image.

    Bounding boxes (labelled with their class name) are drawn first through
    ``mmcv.imshow_det_bboxes``. Then, for every instance, keypoints whose
    score is at least ``kpt_score_thr`` are drawn as filled circles, and
    skeleton links whose two end points are both drawn are added as lines.
    The result can optionally be shown, written as an image, and dumped as
    a prediction JSON file.

    Args:
        img (str | np.ndarray): Image filename or loaded image. A passed-in
            array is copied, so the caller's array is not drawn on.
        poses (dict[dict]): Maps a class name to a dict with keys
            ``'pose_model'`` and ``'pose_results'``. Only ``pose_results[0]``
            is read: its ``gt_instances`` supply ``bboxes`` / ``bbox_scores``
            and its ``pred_instances`` supply ``keypoints`` /
            ``keypoint_scores``. Empty entries are skipped.
        img_name (str): Name used for output files when ``img`` is an array
            (falls back to ``'demo.jpg'``). Ignored when ``img`` is a path.
        radius (int): Radius of keypoint circles.
        thickness (int): Thickness of skeleton lines.
        bbox_score_thr (float): When not None, each box row passed to mmcv
            carries its score and this value is used as mmcv's
            ``score_thr``; when None, ``score_thr=0`` is used.
        kpt_score_thr (float): Keypoints scoring below this value are not
            drawn, nor are links touching them.
        bbox_color (str | tuple[int]): Color of bounding boxes.
        dataset_info (DatasetInfo): Provides ``skeleton``,
            ``pose_kpt_color`` and ``pose_link_color`` for every class. When
            None, each class uses the ``dataset_info`` from its own pose
            model's config.
        show (bool): Whether to display the image. Default False.
        out_dir (str): Base output directory. Visualizations go to
            ``${out_dir}/visualizations`` unless ``vis_out_dir`` is given,
            and predictions to ``${out_dir}/predictions`` unless
            ``pred_out_dir`` is given. Missing directories are created.
            Default None.
        vis_out_dir (str): Directory for the visualization image.
            Default None.
        pred_out_dir (str): Directory for the prediction JSON. Each record
            holds ``keypoints``, ``keypoint_scores``, ``bbox`` (a one-element
            list with the box coordinates, without score) and
            ``bbox_score``. Default None.

    Returns:
        np.ndarray: The annotated image.

    Raises:
        TypeError: If ``img`` is neither a str nor an np.ndarray.
        ValueError: If ``dataset_info`` is None and a pose model's config
            does not provide one.
    """
    vis_dir, pred_dir = _resolve_out_dirs(out_dir, vis_out_dir, pred_out_dir)

    img_path = None
    if isinstance(img, str):
        img_path = img
        img = mmcv.imread(img)
    elif isinstance(img, np.ndarray):
        img = img.copy()
    else:
        raise TypeError('img must be a filename or numpy array, '
                        f'but got {type(img)}')

    # Every pass below walks the classes in the same dict order, so the i-th
    # bbox/score and the i-th keypoint set are paired by position.
    # NOTE(review): assumes gt_instances and pred_instances list instances in
    # the same order with the same count - confirm with the caller.
    entries = [(label, v) for label, v in poses.items() if v]

    det_bboxes, raw_bboxes, bbox_scores, class_names = _collect_bboxes(
        entries, with_score=bbox_score_thr is not None)
    img = mmcv.imshow_det_bboxes(
        img,
        det_bboxes,
        # One label per box, so class_names[i] is the text shown on box i.
        np.arange(len(class_names)),
        class_names=class_names,
        score_thr=bbox_score_thr if bbox_score_thr is not None else 0,
        bbox_color=bbox_color,
        text_color='white',
        show=False,
    )

    keypoints_list = []
    keypoint_scores_list = []
    for _, v in entries:
        pose_model = v['pose_model']
        instances = v['pose_results'][0].pred_instances
        keypoints = instances.keypoints
        keypoint_scores = instances.keypoint_scores
        keypoints_list.extend(ks.tolist() for ks in keypoints)
        keypoint_scores_list.extend(kss.tolist() for kss in keypoint_scores)
        # Resolved per class so one model's dataset info never leaks into
        # another class's drawing.
        skeleton, kpt_colors, link_colors = _resolve_pose_style(
            dataset_info, pose_model)
        _draw_pose(img, keypoints, keypoint_scores, skeleton, kpt_colors,
                   link_colors, kpt_score_thr, radius, thickness)

    if show:
        mmcv.imshow(img, wait_time=0)

    if img_path is None:
        img_path = img_name if img_name is not None else 'demo.jpg'
    base_name = osp.basename(img_path)

    if vis_dir:
        mmcv.imwrite(img, osp.join(vis_dir, base_name))
    if pred_dir:
        # e.g. "frame.jpg" -> "frame.json"
        out_file = osp.join(pred_dir, osp.splitext(base_name)[0] + '.json')
        _dump_predictions(out_file, raw_bboxes, bbox_scores,
                          keypoints_list, keypoint_scores_list)
    return img


def _resolve_out_dirs(out_dir, vis_out_dir, pred_out_dir):
    """Return ``(vis_dir, pred_dir)``, creating them; False means disabled."""
    vis_dir = vis_out_dir or (
        osp.join(out_dir, 'visualizations') if out_dir else False)
    pred_dir = pred_out_dir or (
        osp.join(out_dir, 'predictions') if out_dir else False)
    for directory in (vis_dir, pred_dir):
        if directory:
            # makedirs (not mkdir) so missing parents, including out_dir
            # itself, are created as well.
            os.makedirs(directory, exist_ok=True)
    return vis_dir, pred_dir


def _collect_bboxes(entries, with_score):
    """Gather the boxes, scores and class names of every instance, in order.

    Returns:
        tuple: ``(det_bboxes, raw_bboxes, bbox_scores, class_names)``.
        ``det_bboxes`` is a 2-D array for ``mmcv.imshow_det_bboxes`` whose
        rows are x1, y1, x2, y2, plus the score when ``with_score``.
        ``raw_bboxes`` holds the boxes without score as lists, and
        ``bbox_scores`` holds the scores converted with ``.tolist()``.
    """
    det_rows, raw_bboxes, bbox_scores, class_names = [], [], [], []
    for label, v in entries:
        gt = v['pose_results'][0].gt_instances
        scores = gt.bbox_scores
        for i, bbox in enumerate(gt.bboxes):
            score = scores[i]
            raw_bboxes.append(np.asarray(bbox).tolist())
            # Always recorded: the prediction dump needs a score per box
            # even when no bbox_score_thr was given.
            bbox_scores.append(score.tolist())
            det_rows.append(
                np.append(bbox, values=score) if with_score else bbox)
            class_names.append(label)
    if det_rows:
        det_bboxes = np.array(det_rows)
    else:
        # imshow_det_bboxes expects an (n, 4) or (n, 5) array even for n == 0.
        det_bboxes = np.zeros((0, 5 if with_score else 4), dtype=np.float32)
    return det_bboxes, raw_bboxes, bbox_scores, class_names


def _resolve_pose_style(dataset_info, pose_model):
    """Return ``(skeleton, kpt_colors, link_colors)`` for one pose model.

    An explicit ``dataset_info`` takes precedence. Otherwise it is built from
    ``pose_model.cfg.dataset_info``. Colors are converted to tuples of
    Python ints before being handed to cv2.
    """
    if (dataset_info is None and hasattr(pose_model, 'cfg')
            and 'dataset_info' in pose_model.cfg):
        dataset_info = DatasetInfo(pose_model.cfg.dataset_info)
    if dataset_info is None:
        warnings.warn(
            'dataset is deprecated. '
            'Please set `dataset_info` in the config. '
            'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
            DeprecationWarning)
        raise ValueError(
            'dataset_info is not specified or set in the config file.')
    kpt_colors = [tuple(int(x) for x in color)
                  for color in dataset_info.pose_kpt_color]
    link_colors = [tuple(int(x) for x in color)
                   for color in dataset_info.pose_link_color]
    return dataset_info.skeleton, kpt_colors, link_colors


def _draw_pose(img, keypoints, keypoint_scores, skeleton, kpt_colors,
               link_colors, kpt_score_thr, radius, thickness):
    """Draw the keypoints and skeleton links of every instance onto ``img``."""
    # Per instance: keypoint index -> point, for keypoints passing the threshold.
    visible = []
    for inst_idx, kpts in enumerate(keypoints):
        scores = keypoint_scores[inst_idx]
        visible.append({k: pt for k, pt in enumerate(kpts)
                        if scores[k] >= kpt_score_thr})
    # All circles go down before any line, so links are painted on top.
    for pts in visible:
        for k, pt in pts.items():
            cv2.circle(img, (int(pt[0]), int(pt[1])), radius, kpt_colors[k],
                       -1)
    for pts in visible:
        for link_idx, link in enumerate(skeleton):
            start, end = link[0], link[1]
            if start in pts and end in pts:
                cv2.line(img,
                         (int(pts[start][0]), int(pts[start][1])),
                         (int(pts[end][0]), int(pts[end][1])),
                         link_colors[link_idx], thickness)


def _dump_predictions(out_file, bboxes, bbox_scores, keypoints,
                      keypoint_scores):
    """Write one JSON record per instance (indexed by box) to ``out_file``."""
    preds = []
    for i, bbox in enumerate(bboxes):
        preds.append(dict(
            keypoints=keypoints[i],
            keypoint_scores=keypoint_scores[i],
            bbox=[bbox],
            bbox_score=bbox_scores[i],
        ))
    with open(out_file, 'w') as f:
        json.dump(preds, f)