# llava/train/llava_trainer_eval.py
import json
import subprocess

from llava.train.llava_trainer import LLaVATrainer
class LLaVAEvalTrainer(LLaVATrainer):
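    """LLaVATrainer variant whose `evaluate` shells out to lmms_eval.

    Instead of running an in-process evaluation loop, it launches
    `accelerate launch -m lmms_eval ...` as a subprocess, recovers the
    results file path from the process output, and returns a flat
    {"<task>_<metric>": value} dict.
    """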
    def evaluate(self, evaluate_args):
        # Build the lmms_eval command launched through `accelerate`.
        cmd = f"accelerate launch --num_processes {evaluate_args.eval_num_processes} -m lmms_eval \
            --model {evaluate_args.model} \
            --model_args {evaluate_args.model_args} \
            --tasks {evaluate_args.task_names} \
            --batch_size {evaluate_args.batch_size} \
            --log_samples_suffix {evaluate_args.log_samples_suffix} \
            --output_path {evaluate_args.output_path}"
        # Append the optional flags only when they are set.
        if evaluate_args.limit:
            cmd += f" --limit {evaluate_args.limit}"
        if evaluate_args.num_fewshot:
            cmd += f" --num_fewshot {evaluate_args.num_fewshot}"
        if evaluate_args.gen_kwargs != "":
            cmd += f" --gen_kwargs {evaluate_args.gen_kwargs}"
        if evaluate_args.log_samples:
            cmd += " --log_samples"
        else:
            raise ValueError("Please log samples so that the result can be parsed")
        # Run the evaluation as a subprocess and capture its output for parsing.
        results = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        # lmms_eval prints "Saved samples to <path>.json"; recover that path from
        # stdout, falling back to stderr if the message was logged there.
        try:
            result_file_index_start = results.stdout.index("Saved samples to ")
            result_file_index_end = results.stdout.index(".json")
            result_file_index_start += len("Saved samples to ")
            file = results.stdout[result_file_index_start:result_file_index_end]
        except ValueError:
            result_file_index_start = results.stderr.index("Saved samples to ")
            result_file_index_end = results.stderr.index(".json")
            result_file_index_start += len("Saved samples to ")
            file = results.stderr[result_file_index_start:result_file_index_end]
        # The aggregated metrics live in results.json in the same directory
        # as the saved samples file.
        file = file.split("/")[:-1]
        file = "/".join(file) + "/results.json"
        with open(file, "r") as f:
            lmms_eval_results = json.load(f)
        # Flatten per-task metrics into {"<task>_<metric>": value}, skipping
        # aliases and stderr entries.
        result_dict = {}
        tasks_list = evaluate_args.task_names.split(",")
        for task in tasks_list:
            task_results = lmms_eval_results["results"][task]
            for k, v in task_results.items():
                if k != "alias" and "stderr" not in k:
                    metric = k.split(",")[0]
                    result_dict[f"{task}_{metric}"] = v
        return result_dict

"""def evaluate(self, evaluate_args):
initialize_tasks()
tasks_list = evaluate_args.task_names.split(",")
result_dict = {}
results = evaluator.simple_evaluate(
model=evaluate_args.model,
model_args=evaluate_args.model_args,
tasks=tasks_list,
num_fewshot=evaluate_args.num_fewshot,
batch_size=evaluate_args.batch_size,
device=evaluate_args.device,
limit=evaluate_args.limit,
check_integrity=evaluate_args.check_integrity,
show_task_to_terminal=evaluate_args.show_task_to_terminal,
log_samples=evaluate_args.log_samples,
gen_kwargs=evaluate_args.gen_kwargs,
cli_args=evaluate_args,
)
for task in tasks_list:
task_results = results["results"][task]
for k,v in task_results.items():
if k != "alias" and "stderr" not in k:
metric = k.split(",")[0]
result_dict[f"{task}_{metric}"] = v
return result_dict"""