# Forecast4Muses / Model / Model6 / model6_inference.py
# Hanxiao Xiang · upload · b328990 · 13.5 kB
# (file-viewer residue: raw · history · blame · No virus)
"""old name: test_runtime_model6.py"""
import json
import os
import subprocess
import sys
import warnings
from time import time
from typing import Union, Tuple, Any
import pandas as pd
from mmdet.apis import inference_detector
from mmdet.apis import init_detector as det_init_detector
from mmpose.apis import inference_topdown
from mmpose.apis import init_model as pose_init_model
from mmpretrain import ImageClassificationInferencer
from mmpretrain.utils import register_all_modules
from .extensions.vis_pred_save import save_result
# Register mmpretrain's components (from mmpretrain.utils) before any model
# or config below is built.
register_all_modules()

# Wall-clock reference taken at import time; the __main__ block reads `st`
# again to report the total time cost.
st = time()
ist = st
print('==Start==')

# Device for the module-level detector and classifier created further down.
# The pose models instead use the `device` argument of main()/mmpose_inference().
DEVICE = 'cpu'

# Directory that contains this file; the detection and classification
# artefacts are resolved against it, not against the working directory.
abs_path = os.path.dirname(os.path.abspath(__file__))

# Model6_0: MMYOLO clothes detector (config, checkpoint).
yolo_config, yolo_checkpoint = (
    os.path.join(abs_path, 'Model6_0_ClothesDetection/mmyolo/configs/custom_dataset/yolov6_s_fast.py'),
    os.path.join(abs_path, 'Model6_0_ClothesDetection/mmyolo/work_dirs/yolov6_s_df2_0.4/epoch_64.pth'),
)

# Model6_2: MMPretrain profile classifier (config, checkpoint).
pretrain_config, pretrain_checkpoint = (
    os.path.join(abs_path, 'Model6_2_ProfileRecogition/mmpretrain/configs/resnext101_4xb32_2048e_3c_noF.py'),
    os.path.join(abs_path, 'Model6_2_ProfileRecogition/mmpretrain/work_dirs/'
                           'resnext101_4xb32_2048e_3c_noF/best_accuracy_top1_epoch_1520.pth'),
)
# Per-category DeepFashion2 top-down heatmap (ResNet-50) keypoint models.
# Each entry maps a clothes category (the same names, in the same order, as
# get_bbox_results_by_classes uses) to (run name, epoch of the best-PCK
# checkpoint). The run name is both the config file stem and the work_dir
# sub-directory, so a config and its checkpoint cannot drift apart.
#
# Paths are resolved against this file's directory, like yolo_config and
# pretrain_config. Previously they were relative to the current working
# directory ('Model/Model6/...'), which only worked when launched from the
# repository root.
_POSE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Model6_1_ClothesKeyPoint')
_POSE_CONFIG_DIR = os.path.join(
    _POSE_ROOT, 'mmpose_1_x/configs/fashion_2d_keypoint/topdown_heatmap/deepfashion2')
_POSE_WORK_DIR = os.path.join(_POSE_ROOT, 'work_dirs_1-x')
_POSE_RUNS = {
    'short_sleeved_shirt': ('td_hm_res50_4xb32-60e_deepfashion2_short_sleeved_shirt_256x192', 50),
    'long_sleeved_shirt': ('td_hm_res50_4xb64-120e_deepfashion2_long_sleeved_shirt_256x192', 60),
    'short_sleeved_outwear': ('td_hm_res50_4xb8-150e_deepfashion2_short_sleeved_outwear_256x192', 120),
    'long_sleeved_outwear': ('td_hm_res50_4xb16-120e_deepfashion2_long_sleeved_outwear_256x192', 100),
    'vest': ('td_hm_res50_4xb64-120e_deepfashion2_vest_256x192', 90),
    'sling': ('td_hm_res50_4xb64-120e_deepfashion2_sling_256x192', 60),
    'shorts': ('td_hm_res50_4xb64-210e_deepfashion2_shorts_256x192', 160),
    'trousers': ('td_hm_res50_4xb64-60e_deepfashion2_trousers_256x192', 30),
    'skirt': ('td_hm_res50_4xb64-120e_deepfashion2_skirt_256x192', 110),
    'short_sleeved_dress': ('td_hm_res50_4xb64-150e_deepfashion2_short_sleeved_dress_256x192', 100),
    'long_sleeved_dress': ('td_hm_res50_4xb16-150e_deepfashion2_long_sleeved_dress_256x192', 120),
    'vest_dress': ('td_hm_res50_4xb64-150e_deepfashion2_vest_dress_256x192', 80),
    'sling_dress': ('td_hm_res50_4xb64-210e_deepfashion2_sling_dress_256x192', 140),
}
# category -> absolute path of the mmpose config file
pose_configs = {
    label: os.path.join(_POSE_CONFIG_DIR, run + '.py')
    for label, (run, _) in _POSE_RUNS.items()
}
# category -> absolute path of the best-PCK checkpoint of that run
pose_checkpoints = {
    label: os.path.join(_POSE_WORK_DIR, run, f'best_PCK_epoch_{epoch}.pth')
    for label, (run, epoch) in _POSE_RUNS.items()
}
# Build the detector and the classifier once, at import time, on DEVICE, and
# report how long each load took.
start_load = time()
yolo_inferencer = det_init_detector(yolo_config, yolo_checkpoint, device=DEVICE)
print(f'==The model loading time of MMYolo is {time() - start_load}s==')

start_load = time()
pretrain_inferencer = ImageClassificationInferencer(
    model=pretrain_config,
    pretrained=pretrain_checkpoint,
    device=DEVICE,
)
print(f'==The model loading time of MMPretrain is {time() - start_load}s==')
def get_bbox_results_by_classes(result, score_thr: float = 0.3) -> dict:
    """Group the confident detections of one image by clothes category.

    :param result: the result of mmyolo inference; only ``result.pred_instances``
        (with parallel ``bboxes``, ``labels`` and ``scores``) is read
    :param score_thr: detections whose score is not strictly greater than this
        are dropped (defaults to 0.3, the previously hard-coded value)
    :return: a dict with all 13 category names as keys (always present, empty
        list when nothing was detected), each mapped to a list of
        ``[x1, y1, x2, y2]`` boxes as plain Python numbers
    """
    # Position i is the category of detector label i.
    # NOTE(review): assumed to match the class order of the mmyolo
    # custom_dataset config (yolov6_s_fast.py) — confirm there.
    class_names = (
        'short_sleeved_shirt',
        'long_sleeved_shirt',
        'short_sleeved_outwear',
        'long_sleeved_outwear',
        'vest',
        'sling',
        'shorts',
        'trousers',
        'skirt',
        'short_sleeved_dress',
        'long_sleeved_dress',
        'vest_dress',
        'sling_dress',
    )
    bbox_results_by_classes = {name: [] for name in class_names}

    pred_instances = result.pred_instances
    # Plain boolean-mask indexing. The former `x[[mask]]` form relied on legacy
    # "sequence treated as tuple" indexing, which NumPy >= 1.23 no longer accepts.
    keep = pred_instances.scores > score_thr
    labels = pred_instances.labels[keep]
    bboxes = pred_instances.bboxes[keep]

    # `.tolist()` yields Python numbers instead of 0-d tensors tied to the
    # inference device, so downstream consumers can turn them into arrays freely.
    for label, (x1, y1, x2, y2) in zip(labels.tolist(), bboxes.tolist()):
        bbox_results_by_classes[class_names[int(label)]].append([x1, y1, x2, y2])
    return bbox_results_by_classes
def mmyolo_inference(img: Union[str, list], model) -> tuple:
    """Run the detector on ``img`` and measure how long the call takes.

    :param img: an image path, or a list, forwarded unchanged to ``inference_detector``
    :param model: the detector to run (the module-level ``yolo_inferencer`` in main)
    :return: ``(result, seconds)`` — the raw detector output and the wall-clock
        duration of the ``inference_detector`` call
    """
    started = time()
    result = inference_detector(model, img)
    return result, time() - started
def mmpose_inference(person_results: dict, use_bbox: bool,
                     mmyolo_cfg_path: str, mmyolo_ckf_path: str,
                     img: str, output_path_root: str, save=True, device='cpu') -> float:
    """Estimate clothes keypoints for every category that has detected boxes.

    :param person_results: category name -> list of boxes (the output of
        ``get_bbox_results_by_classes``); categories with no boxes are skipped
    :param use_bbox: if True, run the category's top-down pose model on the given
        boxes. If False, ``MMPoseInferencer`` runs its own detector on the whole
        image instead, and a warning is emitted because the result may be wrong.
    :param mmyolo_cfg_path: the file path of the mmyolo config (only read when
        ``use_bbox`` is False)
    :param mmyolo_ckf_path: the file path of the mmyolo checkpoint (only read when
        ``use_bbox`` is False)
    :param img: the path of the image to inference
    :param output_path_root: the root path of the output
    :param save: whether to save the inference result, including the image and the
        predicted json file. If `save` is False, nothing is written and
        `output_path_root` is ignored.
    :param device: the device the pose models run on
    :return: seconds spent on pose estimation (saving is not included)
    """
    mmpose_st = time()
    # One slot per category; filled only for categories that had boxes (use_bbox=True).
    poses = {label: {} for label in (
        'short_sleeved_shirt', 'long_sleeved_shirt', 'short_sleeved_outwear',
        'long_sleeved_outwear', 'vest', 'sling', 'shorts', 'trousers', 'skirt',
        'short_sleeved_dress', 'long_sleeved_dress', 'vest_dress', 'sling_dress',
    )}
    for label, person_result in person_results.items():
        if len(person_result) == 0:
            continue
        pose_config = pose_configs[label]
        pose_checkpoint = pose_checkpoints[label]
        if not use_bbox:
            from mmpose.apis import MMPoseInferencer
            warnings.warn('use_bbox is False, '
                          'which means using MMPoseInferencer to inference the pose results without use_bbox '
                          'and may be wrong')
            inferencer = MMPoseInferencer(
                pose2d=pose_config,
                pose2d_weights=pose_checkpoint,
                device=device,  # previously ignored on this path
                det_model=mmyolo_cfg_path,
                det_weights=mmyolo_ckf_path,
            )
            # Honour `save` / `output_path_root`. This path used to write to a
            # hard-coded 'upload_to_web_tmp' unconditionally.
            call_kwargs = {'out_dir': output_path_root} if save else {}
            # inferencer(...) returns a generator; `next` runs inference on the image.
            # NOTE(review): the result is not stored in `poses`, so save_result
            # below receives an empty entry for this category.
            next(inferencer(img, return_vis=True, **call_kwargs))
        else:
            # NOTE(review): the pose model is rebuilt from disk on every call;
            # caching it per (config, checkpoint, device) would speed up batch runs
            # at the cost of keeping up to 13 models in memory.
            pose_model = pose_init_model(pose_config, pose_checkpoint, device=device)
            pose_results = inference_topdown(pose_model, img, person_result, bbox_format='xyxy')
            poses[label]['pose_results'] = pose_results
            poses[label]['pose_model'] = pose_model
    mmpose_et = time()
    if save:
        save_result(img, poses, out_dir=output_path_root)
    return mmpose_et - mmpose_st
def mmpretrain_inference(img: Union[str, list], model) -> tuple:
    """Classify ``img`` with a callable inferencer and measure how long it takes.

    :param img: an image path, or a list, passed straight to ``model``
    :param model: the callable inferencer (the module-level ``pretrain_inferencer`` in main)
    :return: ``(cls_result, seconds)`` — whatever ``model(img)`` returned and the
        wall-clock duration of that call
    """
    tic = time()
    cls_result = model(img)
    toc = time()
    elapsed = toc - tic
    return cls_result, elapsed
def main(img_path: str, output_path_root='upload_to_web_tmp', use_bbox=True, device='cpu', test_runtime=False) -> dict:
    """Run detection, keypoint estimation and profile classification on images.

    :param img_path: the path of an image, or of a folder of images. Every regular
        file directly inside a folder is processed, in sorted order; sub-folders
        are skipped.
    :param output_path_root: the root path the pose results are saved under
    :param use_bbox: whether to use the detected boxes to inference the pose results
    :param device: the device the pose models run on. The detector and the
        classifier are the module-level models created on ``DEVICE``.
    :param test_runtime: if True, write per-image stage timings to ``runtimes.csv``
        in the current working directory
    :return: the results of model6_2 (the MMPretrain classifier), keyed by image file name
    :raises ValueError: if ``img_path`` is neither an existing file nor a directory
    """
    if os.path.isdir(img_path):
        # Sorted so runs are reproducible (os.listdir order is arbitrary). Only
        # regular files are kept, because a sub-folder cannot be inferred on.
        candidates = (os.path.join(img_path, name) for name in sorted(os.listdir(img_path)))
        img_paths = [path for path in candidates if os.path.isfile(path)]
    elif os.path.isfile(img_path):
        img_paths = [img_path]
    else:
        raise ValueError(f'img_path must be a path of an image or a folder, got {img_path!r}')

    header = ['img_name',
              'runtime_mmyolo', 'percent1',
              'runtime_mmpose', 'percent2',
              'runtime_mmpretrain', 'percent3',
              'runtime_total']
    runtime_rows = []
    cls_results = {}
    for img in img_paths:
        print(f'==Start to inference {img}==')
        yolo_result, runtime_mmyolo = mmyolo_inference(img, yolo_inferencer)
        print(f'==mmyolo running time is {runtime_mmyolo}s==')
        person_results = get_bbox_results_by_classes(yolo_result)
        runtime_mmpose = mmpose_inference(
            person_results=person_results,
            use_bbox=use_bbox,
            mmyolo_cfg_path=yolo_config,
            mmyolo_ckf_path=yolo_checkpoint,
            img=img,
            output_path_root=output_path_root,
            save=True,
            device=device,
        )
        print(f'mmpose running time is {runtime_mmpose}s')
        cls_result, runtime_mmpretrain = mmpretrain_inference(img, pretrain_inferencer)
        print(f'mmpretrain running time is {runtime_mmpretrain}s')
        img_name = os.path.basename(img)
        cls_results[img_name] = cls_result
        if test_runtime:
            runtime_total = runtime_mmyolo + runtime_mmpose + runtime_mmpretrain
            # Row layout follows `header`: name, (runtime, percent) per stage, total.
            row = [img_name]
            for stage_runtime in (runtime_mmyolo, runtime_mmpose, runtime_mmpretrain):
                row += [stage_runtime, str(round(stage_runtime / runtime_total * 100, 2)) + '%']
            row.append(runtime_total)
            runtime_rows.append(row)
    if test_runtime:
        # `header` is passed only as `columns`. The old code also fed it in as the
        # first data row, which duplicated the header line inside the CSV.
        pd.DataFrame(runtime_rows, columns=header).to_csv('runtimes.csv', index=False)
    return cls_results
if __name__ == "__main__":
    # Run the pipeline on ./data-test/ (relative to the working directory), then
    # report the wall-clock time since `st` was taken at import time.
    # NOTE: this module uses a relative import (`from .extensions...`), so it must
    # be launched as a module (`python -m ...`), not as a plain script.
    main('data-test/')
    rt = time() - st
    print(f'==Total time cost is {rt}s==')