import json
import subprocess

from llava.train.llava_trainer import LLaVATrainer


class LLaVAEvalTrainer(LLaVATrainer):
    def evaluate(self, evaluate_args):
        # Run lmms_eval as a separate accelerate-launched process and build the CLI
        # command from the evaluation arguments.
        cmd = f"accelerate launch --num_processes {evaluate_args.eval_num_processes} -m lmms_eval \
            --model {evaluate_args.model} \
            --model_args {evaluate_args.model_args} \
            --tasks {evaluate_args.task_names} \
            --batch_size {evaluate_args.batch_size} \
            --log_samples_suffix {evaluate_args.log_samples_suffix} \
            --output_path {evaluate_args.output_path}"
        if evaluate_args.limit:
            cmd += f" --limit {evaluate_args.limit}"
        if evaluate_args.num_fewshot:
            cmd += f" --num_fewshot {evaluate_args.num_fewshot}"
        if evaluate_args.gen_kwargs != "":
            cmd += f" --gen_kwargs {evaluate_args.gen_kwargs}"
        # Sample logging is required: the saved samples path is how we locate results.json.
        if not evaluate_args.log_samples:
            raise ValueError("Please log samples so that the result can be parsed")
        cmd += " --log_samples"

        results = subprocess.run([cmd], shell=True, capture_output=True, text=True)

        # lmms_eval reports "Saved samples to <path>.json"; depending on logging setup the
        # message can appear on stdout or stderr, so fall back to stderr if stdout lacks it.
        try:
            result_file_index_start = results.stdout.index("Saved samples to ") + len("Saved samples to ")
            result_file_index_end = results.stdout.index(".json", result_file_index_start)
            file = results.stdout[result_file_index_start:result_file_index_end]
        except ValueError:
            result_file_index_start = results.stderr.index("Saved samples to ") + len("Saved samples to ")
            result_file_index_end = results.stderr.index(".json", result_file_index_start)
            file = results.stderr[result_file_index_start:result_file_index_end]

        # The aggregated metrics live in results.json in the same directory as the samples file.
        file = "/".join(file.split("/")[:-1]) + "/results.json"
        with open(file, "r") as f:
            lmms_eval_results = json.load(f)

        # Flatten per-task metrics into {"<task>_<metric>": value}, skipping alias and
        # stderr entries.
        result_dict = {}
        tasks_list = evaluate_args.task_names.split(",")
        for task in tasks_list:
            task_results = lmms_eval_results["results"][task]
            for k, v in task_results.items():
                if k != "alias" and "stderr" not in k:
                    metric = k.split(",")[0]
                    result_dict[f"{task}_{metric}"] = v
        return result_dict

    # Earlier in-process implementation, kept for reference:
    """def evaluate(self, evaluate_args):
        initialize_tasks()
        tasks_list = evaluate_args.task_names.split(",")
        result_dict = {}
        results = evaluator.simple_evaluate(
            model=evaluate_args.model,
            model_args=evaluate_args.model_args,
            tasks=tasks_list,
            num_fewshot=evaluate_args.num_fewshot,
            batch_size=evaluate_args.batch_size,
            device=evaluate_args.device,
            limit=evaluate_args.limit,
            check_integrity=evaluate_args.check_integrity,
            show_task_to_terminal=evaluate_args.show_task_to_terminal,
            log_samples=evaluate_args.log_samples,
            gen_kwargs=evaluate_args.gen_kwargs,
            cli_args=evaluate_args,
        )
        for task in tasks_list:
            task_results = results["results"][task]
            for k, v in task_results.items():
                if k != "alias" and "stderr" not in k:
                    metric = k.split(",")[0]
                    result_dict[f"{task}_{metric}"] = v

        return result_dict"""
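
# Minimal usage sketch (an assumption, not part of the upstream file): evaluate() only
# reads attributes from the object it is handed, so any namespace exposing the same
# fields works for a smoke test. The field values below are illustrative placeholders,
# and `trainer` stands in for an already-constructed LLaVAEvalTrainer instance.
#
# from types import SimpleNamespace
#
# eval_args = SimpleNamespace(
#     eval_num_processes=8,
#     model="llava",
#     model_args="pretrained=/path/to/checkpoint",
#     task_names="mme",
#     batch_size=1,
#     log_samples_suffix="eval",
#     output_path="./logs/",
#     limit=None,
#     num_fewshot=None,
#     gen_kwargs="",
#     log_samples=True,
# )
# result_dict = trainer.evaluate(eval_args)  # e.g. {"<task>_<metric>": value, ...}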