File size: 8,047 Bytes
b328990
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
from .dataset_info import DatasetInfo
import cv2
import mmcv
import numpy as np
import os
from os import path as osp
import json


def save_result(img,
                poses,
                img_name=None,
                radius=4,
                thickness=1,
                bbox_score_thr=None,
                kpt_score_thr=0.3,
                bbox_color='green',
                dataset_info=None,
                show=False,
                out_dir=None,
                vis_out_dir=None,
                pred_out_dir=None,):
    """Draw boxes and poses on an image and optionally save image/predictions.

    Args:
        img (str | np.ndarray): Image filename or loaded image. An array is
            copied before drawing, so the caller's array is not modified.
        poses (dict[dict]): Maps a class label to a dict holding
            ``pose_model`` and ``pose_results``. Boxes and box scores are read
            from ``pose_results[0].gt_instances``; keypoints and keypoint
            scores from ``pose_results[0].pred_instances``. Classes whose
            value is empty are skipped.
        img_name (str): Name used for the output files when ``img`` is an
            array. Falls back to ``'demo.jpg'`` when omitted.
        radius (int): Radius of keypoint circles.
        thickness (int): Thickness of skeleton lines.
        bbox_score_thr (float): The threshold to visualize the bounding
            boxes. When not None, the score is appended to every box
            (x1, y1, x2, y2, score), and that five-value box is also what is
            written to the predictions file.
        kpt_score_thr (float): Keypoints scoring below this are not drawn,
            and skeleton links touching them are skipped.
        bbox_color (str | tuple[int]): Color of bounding boxes.
        dataset_info (DatasetInfo): Dataset info applied to every class. When
            None, it is built per class from ``pose_model.cfg.dataset_info``.
        show (bool): Whether to show the image. Default False.
        out_dir (str): The output directory to save the visualizations and
            predictions results.
            If vis_out_dir is None, visualizations will be saved in
            ${out_dir}/visualizations.
            If pred_out_dir is None, predictions will be saved in
            ${out_dir}/predictions.
            Default None.
        vis_out_dir (str): The output directory to save the visualization
            results. Default None.
        pred_out_dir (str): The output directory to save the predictions
            results (one JSON file per image). Default None.

    Returns:
        np.ndarray: The image with boxes, keypoints and skeletons drawn.

    Raises:
        TypeError: If ``img`` is neither a filename nor a numpy array.
        ValueError: If no dataset info is given and a pose model does not
            carry one in its config.
    """
    # Local import: the module's import block does not provide `warnings`.
    import warnings

    # resolve output locations (False means "do not write")
    vis_out_flag = False if vis_out_dir is None else vis_out_dir
    pred_out_flag = False if pred_out_dir is None else pred_out_dir
    if out_dir:
        if not vis_out_dir:
            vis_out_flag = osp.join(out_dir, 'visualizations')
            # makedirs also creates out_dir itself when it is missing
            os.makedirs(vis_out_flag, exist_ok=True)
        if not pred_out_dir:
            pred_out_flag = osp.join(out_dir, 'predictions')
            os.makedirs(pred_out_flag, exist_ok=True)

    # read image
    img_path = None
    if isinstance(img, str):
        img_path = img
        img = mmcv.imread(img)
    elif isinstance(img, np.ndarray):
        img = img.copy()
    else:
        raise TypeError('img must be a filename or numpy array, '
                        f'but got {type(img)}')

    # collect boxes of all classes
    bbox_list = []
    label_list = []
    class_name_list = []
    bbox_score_list = []
    for label, v in poses.items():
        if not v:
            continue
        instances = v['pose_results'][0].gt_instances
        # One name per class: the drawing routine resolves the caption as
        # class_names[label], so label must index into class_name_list.
        class_idx = len(class_name_list)
        class_name_list.append(label)
        for b, s in zip(instances.bboxes, instances.bbox_scores):
            if bbox_score_thr is not None:
                b = np.append(b, values=s)  # switch to x1, y1, x2, y2, score
            bbox_score_list.append(s.tolist())
            bbox_list.append(b)
            label_list.append(class_idx)
    bbox_list = np.array(bbox_list)
    label_list = np.array(label_list)

    # draw bbox (skipped when nothing was detected: an empty array is 1-D
    # and is not a valid box array for the drawing routine)
    if len(bbox_list) > 0:
        img = mmcv.imshow_det_bboxes(
            img,
            bbox_list,
            label_list,
            class_names=class_name_list,
            score_thr=bbox_score_thr if bbox_score_thr is not None else 0,
            bbox_color=bbox_color,
            text_color='white',
            show=False,
        )

    keypoints_list = []
    keypoint_scores_list = []
    # draw pose of different classes
    for label, v in poses.items():
        if not v:
            continue
        pose_model = v['pose_model']
        pred_instances = v['pose_results'][0].pred_instances
        keypoints = pred_instances.keypoints
        keypoint_scores = pred_instances.keypoint_scores
        keypoints_list.extend(ks.tolist() for ks in keypoints)
        keypoint_scores_list.extend(kss.tolist() for kss in keypoint_scores)

        # Resolve dataset info per class so that one model's skeleton is
        # never reused for another class; an explicit argument still wins.
        class_info = dataset_info
        if (class_info is None and hasattr(pose_model, 'cfg')
                and 'dataset_info' in pose_model.cfg):
            class_info = DatasetInfo(pose_model.cfg.dataset_info)

        if class_info is None:
            warnings.warn(
                'dataset is deprecated.'
                'Please set `dataset_info` in the config.'
                'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
                DeprecationWarning)
            raise ValueError('dataset_info is not specified or set in the config file.')

        skeleton = class_info.skeleton
        # cast color components to plain ints before handing them to cv2
        pose_kpt_color = [
            tuple(int(x) for x in color)
            for color in class_info.pose_kpt_color
        ]
        pose_link_color = [
            tuple(int(x) for x in color)
            for color in class_info.pose_link_color
        ]

        # per instance: keypoint index -> coordinates, confident ones only
        visible_list = [
            {
                k_idx: kpt
                for k_idx, kpt in enumerate(inst_kpts)
                if inst_scores[k_idx] >= kpt_score_thr
            }
            for inst_kpts, inst_scores in zip(keypoints, keypoint_scores)
        ]

        # draw circle (all circles first, so lines end up on top of them)
        for visible in visible_list:
            for k_idx, kpt in visible.items():
                cv2.circle(img, (int(kpt[0]), int(kpt[1])), radius,
                           pose_kpt_color[k_idx], -1)

        # draw line, only when both end points are visible
        for visible in visible_list:
            for l_idx, link in enumerate(skeleton):
                start, end = link[0], link[1]
                if start in visible and end in visible:
                    cv2.line(img,
                             (int(visible[start][0]), int(visible[start][1])),
                             (int(visible[end][0]), int(visible[end][1])),
                             pose_link_color[l_idx], thickness)

    if show:
        mmcv.imshow(img, wait_time=0)
    if img_path is None:
        img_path = img_name if img_name is not None else 'demo.jpg'
    if vis_out_flag:
        out_file = osp.join(vis_out_flag, osp.basename(img_path))
        mmcv.imwrite(img, out_file)
    if pred_out_flag:
        pred_list = [
            dict(
                keypoints=keypoints_list[bbox_idx],
                keypoint_scores=keypoint_scores_list[bbox_idx],
                bbox=[bbox.tolist()],
                bbox_score=bbox_score_list[bbox_idx],
            )
            for bbox_idx, bbox in enumerate(bbox_list)
        ]
        # an explicitly given pred_out_dir may not exist yet
        os.makedirs(pred_out_flag, exist_ok=True)
        # replace .jpg or .png with .json
        out_file = osp.join(pred_out_flag, osp.basename(img_path).rsplit('.', 1)[0] + '.json')
        with open(out_file, 'w') as f:
            json.dump(pred_list, f)

    return img