TwT-6's picture
Upload 2667 files
256a159 verified
raw
history blame contribute delete
No virus
2.62 kB
# flake8: noqa: E501
import os.path as osp
import mmengine
from opencompass.utils import dataset_abbr_from_cfg
def get_outdir(cfg, time_str):
    """Create the timestamped summary directory and locate the results folder.

    Args:
        cfg (ConfigDict): The running config; only ``cfg['work_dir']`` is read.
        time_str (str): Current time, used as the name of the output directory.

    Returns:
        tuple: ``(output_dir, results_folder)`` where ``output_dir`` is
        ``<work_dir>/summary/<time_str>`` (created if missing) and
        ``results_folder`` is ``<work_dir>/results`` (not created here).
    """
    root = cfg['work_dir']
    # The summary file path is only used to derive its parent directory;
    # the file itself is not written by this function.
    summary_file = osp.join(root, 'summary', f'summary_{time_str}.txt')
    summary_dir = osp.dirname(summary_file)
    out_dir = osp.join(summary_dir, f'{time_str}')
    mmengine.mkdir_or_exist(out_dir)
    return out_dir, osp.join(root, 'results')
def get_judgeanswer_and_reference(dataset, subdir_path, post_process):
    """Extract judgements (scores) and references.

    Loads ``<dataset_abbr>.json`` from ``subdir_path`` if it exists. Otherwise
    the sharded files ``<dataset_abbr>_0.json``, ``<dataset_abbr>_1.json``, ...
    are loaded in order until the first missing index, and their entries are
    merged under consecutive integer keys.

    Args:
        dataset (ConfigDict): Dataset config.
        subdir_path (str): Model path in results dir.
        post_process (function): The pre-defined extract function. It is
            called with each entry's ``'prediction'`` value; a ``None``
            return marks the judgement as not extractable and the entry is
            skipped.

    Returns:
        tuple: ``(judged_answers, references)`` -- two lists of equal length
        holding the extracted judgements and the matching ``'gold'`` values.

    Raises:
        AssertionError: If no results were found, or if no judgement could
            be extracted from any entry.
    """
    dataset_abbr = dataset_abbr_from_cfg(dataset)
    filename = osp.join(subdir_path, dataset_abbr + '.json')
    partial_filename = osp.join(subdir_path, dataset_abbr + '_0.json')

    if osp.exists(osp.realpath(filename)):
        result = mmengine.load(filename)
    else:
        # Fall back to the sharded files. A separate loop variable is used so
        # that ``filename`` still names the full file in the message below.
        result = {}
        part_index = 0
        part_path = partial_filename
        while osp.exists(osp.realpath(part_path)):
            for entry in mmengine.load(part_path).values():
                # Re-key entries with a running counter so shards never
                # overwrite each other.
                result[len(result)] = entry
            part_index += 1
            part_path = osp.join(subdir_path,
                                 f'{dataset_abbr}_{part_index}.json')

    if len(result) == 0:
        message = ('There are no results for ' + filename + ' or ' +
                   partial_filename)
        print('*' * 100)
        print(message)
        print('*' * 100)
        # Raised explicitly (instead of ``assert``) so the check is not
        # stripped when Python runs with -O.
        raise AssertionError(message)

    judged_answers = []
    references = []
    for entry in result.values():
        judgement = post_process(entry['prediction'])
        if judgement is not None:
            judged_answers.append(judgement)
            references.append(entry['gold'])

    if len(judged_answers) != len(result):
        print(
            f'Among {len(result)} judgements, successfully extracted {len(judged_answers)} judgements, please check!'
        )
    if len(judged_answers) == 0:
        message = 'There are no extracted judgements, please change your judge model or check your prompt!!!'
        print('*' * 100)
        print(message)
        print('*' * 100)
        # Same exception type the original ``assert`` produced, but it
        # survives -O.
        raise AssertionError(message)
    return judged_answers, references