TwT-6's picture
Upload 2667 files
256a159 verified
raw
history blame contribute delete
No virus
5.35 kB
import argparse
import json
import os
import os.path as osp
import random
import time
from typing import List, Sequence
import mmengine
import torch
import torch.distributed as dist
from mmengine.config import Config, ConfigDict
from mmengine.device import get_device
from mmengine.dist import init_dist
from mmengine.evaluator import Evaluator
from mmengine.logging import print_log
from mmengine.model.wrappers import MMDistributedDataParallel
from mmengine.utils import track_iter_progress
from opencompass.registry import MM_MODELS, TASKS
from opencompass.utils import get_logger
def build_model(cfg):
model = MM_MODELS.build(cfg['model'])
load_from = cfg.get('load_from', None)
if load_from is not None:
state_dict = torch.load(cfg['load_from'], map_location='cpu')
if 'model' in state_dict:
state_dict = state_dict['model']
elif 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
msg = model.load_state_dict(state_dict, strict=False)
print_log(msg)
model.to(get_device())
if dist.is_initialized():
model = MMDistributedDataParallel(
model,
device_ids=[int(os.environ['LOCAL_RANK'])],
broadcast_buffers=False)
return model
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
class MultimodalInferTask:
"""Multimodal Inference Task.
This task is used to run the inference process.
"""
def __init__(self, cfg: ConfigDict):
self.num_gpus = cfg.get('num_gpus', 0)
self.num_procs = cfg.get('num_procs', 1)
self.dataloader = cfg.get('dataset')
self.model = cfg.get('model')
self.evaluator = cfg.get('evaluator')
self.cfg = cfg
self.logger = get_logger()
@property
def name(self) -> str:
model_name = self.model['type']
dataset_name = self.dataloader['dataset']['type']
evaluator_name = self.evaluator[0]['type']
return f'{model_name}-{dataset_name}-{evaluator_name}'
def get_log_path(self, file_extension: str = 'json') -> str:
"""Get the path to the log file.
Args:
file_extension (str): The file extension of the log file.
Default: 'json'.
"""
model_name = self.model['type']
dataset_name = self.dataloader['dataset']['type']
evaluator_name = self.evaluator[0]['type']
return osp.join(self.cfg.work_dir, model_name, dataset_name,
f'{evaluator_name}.{file_extension}')
def get_output_paths(self, file_extension: str = 'json') -> List[str]:
"""Get the path to the output file.
Args:
file_extension (str): The file extension of the log file.
Default: 'json'.
"""
model_name = self.model['type']
dataset_name = self.dataloader['dataset']['type']
evaluator_name = self.evaluator[0]['type']
return [
osp.join(self.cfg.work_dir, model_name, dataset_name,
f'{evaluator_name}.{file_extension}')
]
def get_command(self, cfg_path, template):
"""Get the command template for the task.
Args:
cfg_path (str): The path to the config file of the task.
template (str): The template which have '{task_cmd}' to format
the command.
"""
script_path = __file__
if self.num_gpus > 0:
port = random.randint(12000, 32000)
command = (f'torchrun --master_port={port} '
f'--nproc_per_node {self.num_procs} '
f'{script_path} {cfg_path}')
else:
command = f'python {script_path} {cfg_path}'
return template.format(task_cmd=command)
def run(self):
from mmengine.runner import Runner
# only support slurm, pytorch, mpi
init_dist(self.cfg.launcher)
self.logger.info(f'Task {self.name}')
# build dataloader
dataloader = Runner.build_dataloader(self.dataloader)
# build model
model = build_model(self.cfg)
model.eval()
# build evaluator
evaluator = Evaluator(self.evaluator)
for batch in track_iter_progress(dataloader):
if dist.is_initialized():
data_samples = model.module.forward(batch)
else:
data_samples = model.forward(batch)
if not isinstance(data_samples, Sequence):
data_samples = [data_samples]
evaluator.process(data_samples)
metrics = evaluator.evaluate(len(dataloader.dataset))
metrics_file = self.get_output_paths()[0]
mmengine.mkdir_or_exist(osp.split(metrics_file)[0])
with open(metrics_file, 'w') as f:
json.dump(metrics, f)
def parse_args():
parser = argparse.ArgumentParser(description='Model Inferencer')
parser.add_argument('config', help='Config file path')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
cfg = Config.fromfile(args.config)
start_time = time.time()
inferencer = MultimodalInferTask(cfg)
inferencer.run()
end_time = time.time()
get_logger().info(f'time elapsed: {end_time - start_time:.2f}s')