license: apache-2.0
library_name: transformers

Omni-Judge

Introduction

Omni-Judge is an open-source mathematical evaluation model. Given a problem, its reference answer and a model-generated solution, it judges whether the solution is correct. High-level mathematical problems and their solutions are complex, so rule-based evaluation methods are hard to design. Omni-Judge automates the assessment in the same way as using GPT-4 as a judge, but more efficiently and at lower cost. For usage details, see the Quickstart section below.

Omni-Judge can be applied to various mathematical reasoning benchmarks, such as our proposed Omni-MATH.

Model Details

Omni-Judge is built on meta-llama/Llama-3.1-8B-Instruct and is instruction-tuned on evaluation data produced by GPT-4o. Training used 17,618 examples for 2 epochs. Omni-Judge's judgements closely track those of GPT-4o. We built an internal evaluation set of 2,220 test samples from queries the model had not seen during training, and on this set Omni-Judge agrees with GPT-4o approximately 91% of the time.

Quickstart

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "KbsdJames/Omni-Judge", 
    device_map="auto", 
    torch_dtype=torch.bfloat16, 
)
tokenizer = AutoTokenizer.from_pretrained(
    "KbsdJames/Omni-Judge", 
    trust_remote_code=True
)

# set terminators for decoding
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

question = "Compute the smallest positive integer $n$ for which $\\sqrt{100+\\sqrt{n}}+\\sqrt{100-\\sqrt{n}}$ is an integer."
reference_answer = "6156"
student_solution = "Let's approach this step-by-step:\n\n1) Let's denote $\\sqrt{100+\\sqrt{n}}+\\sqrt{100-\\sqrt{n}} = k$, where k is an integer.\n\n2) Square both sides:\n   $(\\sqrt{100+\\sqrt{n}}+\\sqrt{100-\\sqrt{n}})^2 = k^2$\n\n3) Expand the left side:\n   $100+\\sqrt{n} + 100-\\sqrt{n} + 2\\sqrt{(100+\\sqrt{n})(100-\\sqrt{n})} = k^2$\n\n4) Simplify:\n   $200 + 2\\sqrt{10000-n} = k^2$\n\n5) Rearrange:\n   $\\sqrt{10000-n} = \\frac{k^2-200}{2}$\n\n6) For this to be true, $\\frac{k^2-200}{2}$ must be a perfect square. Let's call it $m^2$:\n   $\\frac{k^2-200}{2} = m^2$\n\n7) Multiply both sides by 2:\n   $k^2-200 = 2m^2$\n\n8) Rearrange:\n   $k^2 = 2m^2 + 200$\n\n9) For the smallest positive n, we need the smallest k > 14 (since $\\sqrt{100+\\sqrt{n}}+\\sqrt{100-\\sqrt{n}} > 2\\sqrt{100} = 20$) that makes this equation true.\n\n10) Testing values:\n    For k = 15: $15^2 = 225 = 2m^2 + 200$\n                $25 = 2m^2$\n                $m^2 = 12.5$ (not an integer)\n\n    For k = 16: $16^2 = 256 = 2m^2 + 200$\n                $56 = 2m^2$\n                $m^2 = 28$ (integer)\n\n11) So, the smallest valid k is 16, and m = $\\sqrt{28}$\n\n12) Now, going back to step 5:\n    $\\sqrt{10000-n} = \\frac{16^2-200}{2} = 28$\n\n13) Solve for n:\n    $10000-n = 28^2 = 784$\n    $n = 10000 - 784 = 9216$\n\nTherefore, the smallest positive integer n for which $\\sqrt{100+\\sqrt{n}}+\\sqrt{100-\\sqrt{n}}$ is an integer is 9216."

# pre-process
formatted_context = tokenizer.get_context(
    question,
    reference_answer,
    student_solution,
)
model_inputs = tokenizer(formatted_context, return_tensors="pt")
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs["attention_mask"]

# do inference
pred = model.generate(
    input_ids=input_ids.to(model.device),
    attention_mask=attention_mask.to(model.device),
    do_sample=False,
    num_return_sequences=1,
    max_new_tokens=300,
)[0].cpu().tolist()

# post-process
pred = pred[len(input_ids[0].cpu().tolist()):]
for terminator in terminators:
    if terminator in pred:
        pred = pred[:pred.index(terminator)]
response = tokenizer.decode(pred, skip_special_tokens=True)
pred_truth = tokenizer.parse_response(response)

# If parsing the response fails, answer/judgement/justification will be None.
# We count such cases as prediction errors.
# Sampling multiple responses may help in this case.

print("answer:", pred_truth["answer"])
# >>> answer: 9216
print("judgement:", pred_truth["judgement"])
# >>> judgement: FALSE
print("justification:", pred_truth["justification"])
# >>> justification: The student's answer of 9216 does not match the reference answer of 6156. The student's solution involves a detailed process of finding the smallest positive integer n that satisfies the given condition, but the final result is incorrect. The discrepancy indicates that the student's answer does not share the same meaning as the reference answer.
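The following sanity check is independent of Omni-Judge. It uses a brute-force search to confirm that the reference answer 6156 is correct. It also shows that the student's 9216 makes the expression an integer (k = 16) but is not the smallest such n, which is consistent with the FALSE judgement. The helper name `expression_value_if_integer` is hypothetical. The search uses exact integer arithmetic (`math.isqrt`) rather than floating point, so no tolerance is needed.

```python
import math
from typing import Optional


def expression_value_if_integer(n: int) -> Optional[int]:
    """Return k if sqrt(100+sqrt(n)) + sqrt(100-sqrt(n)) equals the integer k, else None.

    Squaring the sum S gives S^2 = 200 + 2*sqrt(10000 - n). For integer n, S is an
    integer exactly when 10000 - n is a perfect square s^2 and 200 + 2*s is a
    perfect square k^2.
    """
    if not 1 <= n <= 10000:  # sqrt(100 - sqrt(n)) requires n <= 10000
        return None
    r = 10000 - n
    s = math.isqrt(r)
    if s * s != r:  # sqrt(10000 - n) must be an integer
        return None
    k_squared = 200 + 2 * s
    k = math.isqrt(k_squared)
    return k if k * k == k_squared else None


# Smallest positive n for which the expression is an integer.
smallest = next(n for n in range(1, 10001) if expression_value_if_integer(n) is not None)
print(smallest, expression_value_if_integer(smallest))  # 6156 18
print(expression_value_if_integer(9216))  # 16
```

Only two values of s work here: s = 62 (k = 18) and s = 28 (k = 16). The larger s gives the smaller n = 10000 - 62^2 = 6156.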

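The comments in the Quickstart suggest that sampling multiple responses may help when parsing fails. The sketch below shows one way to write that retry loop. It assumes the first attempt is greedy and later attempts sample. `judge_with_retries`, `is_parsed`, `generate_response` and `max_attempts` are hypothetical names, not part of the Omni-Judge tokenizer. Generation and parsing are passed in as callables. In practice, `generate_response` would wrap the Quickstart's `model.generate(...)` call and post-processing (with `do_sample=True` on retries), and `parse` would be `tokenizer.parse_response`.

```python
from typing import Callable, Optional

FIELDS = ("answer", "judgement", "justification")


def is_parsed(result: Optional[dict]) -> bool:
    """A parse counts as successful only if every field is present and not None."""
    return result is not None and all(result.get(key) is not None for key in FIELDS)


def judge_with_retries(
    generate_response: Callable[[bool], str],
    parse: Callable[[str], Optional[dict]],
    max_attempts: int = 3,
) -> Optional[dict]:
    """Try greedy decoding first, then fall back to sampled generations.

    generate_response(do_sample) returns the decoded judge response; parse maps it
    to a dict with 'answer', 'judgement' and 'justification'. Returns the first
    fully parsed result, or None if every attempt fails (a prediction error).
    """
    for attempt in range(max_attempts):
        response = generate_response(attempt > 0)  # greedy only on the first try
        result = parse(response)
        if is_parsed(result):
            return result
    return None
```

Keeping the first attempt greedy preserves the Quickstart's deterministic behaviour whenever parsing succeeds. Sampling only adds cost in the failure case.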
Evaluation

Treating GPT-4o's judgements as the gold standard, we report Omni-Judge's performance below, broken down by source.

For a fair comparison, the training and test sets contain different questions.
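One way to keep training and test questions disjoint is to split by question rather than by example. Every (solution, judgement) pair for a given question then lands on the same side of the split. The sketch below is a minimal illustration of that idea, assuming each example is a dict with a hypothetical `question` key. It is not necessarily the procedure used to build the Omni-Judge data.

```python
import random
from typing import Dict, List, Tuple


def split_by_question(
    examples: List[Dict], test_fraction: float = 0.1, seed: int = 0
) -> Tuple[List[Dict], List[Dict]]:
    """Split examples so that no question text appears in both train and test."""
    questions = sorted({ex["question"] for ex in examples})  # sorted for reproducibility
    rng = random.Random(seed)
    rng.shuffle(questions)
    n_test = max(1, round(len(questions) * test_fraction))
    test_questions = set(questions[:n_test])
    train = [ex for ex in examples if ex["question"] not in test_questions]
    test = [ex for ex in examples if ex["question"] in test_questions]
    return train, test
```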

The results are shown below:

| Source | Parsing success (%) | Consistency with GPT-4o (%) |
|---|---|---|
| MetaLlama-3.1-70B-instruct | 99.76 | 82.19 |
| DeepSeek-Coder-V2 | 100 | 94.01 |
| Qwen2.5-MATH-7b-Instruct | 100 | 90.69 |
| OpenAI o1-preview | 99.78 | 91.28 |
| OpenAI o1-mini | 100 | 91.78 |
| Mathstral-7B-v0.1 | 100 | 95.79 |
| NuminaMATH-72B-COT | 100 | 90.44 |
| Qwen2.5-MATH-72b-Instruct | 100 | 93.30 |
| All | 99.94 | 91.26 |
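Presumably, "parsing success" is the share of Omni-Judge responses from which a judgement could be extracted, and "consistency" is the share of cases where Omni-Judge's judgement matches GPT-4o's. The sketch below computes both metrics per source under that reading. It also assumes that consistency is measured over successfully parsed responses only, since the card does not state the denominator. The function name and record layout are hypothetical.

```python
from collections import defaultdict
from typing import Dict, Iterable, Optional, Tuple


def per_source_metrics(
    records: Iterable[Tuple[str, Optional[str], str]]
) -> Dict[str, Tuple[float, float]]:
    """Return {source: (parse_success_pct, consistency_pct)}.

    Each record is (source, omni_judgement, gpt4o_judgement). omni_judgement is None
    when the Omni-Judge response could not be parsed. Consistency is computed over
    parsed records only (an assumption, not stated in the card).
    """
    totals: Dict[str, int] = defaultdict(int)
    parsed: Dict[str, int] = defaultdict(int)
    agree: Dict[str, int] = defaultdict(int)
    for source, omni, gpt4o in records:
        totals[source] += 1
        if omni is None:
            continue
        parsed[source] += 1
        # Compare TRUE/FALSE labels case-insensitively.
        agree[source] += int(omni.strip().upper() == gpt4o.strip().upper())
    return {
        source: (
            100.0 * parsed[source] / totals[source],
            100.0 * agree[source] / parsed[source] if parsed[source] else 0.0,
        )
        for source in totals
    }
```

Relabelling every record's source as "All" and calling the function again pools the examples. This differs from averaging the per-source percentages.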

Citation

If you find our work useful, please consider starring our repository and citing our paper.

@misc{gao2024omnimathuniversalolympiadlevel,
      title={Omni-MATH: A Universal Olympiad Level Mathematic Benchmark For Large Language Models}, 
      author={Bofei Gao and Feifan Song and Zhe Yang and Zefan Cai and Yibo Miao and Qingxiu Dong and Lei Li and Chenghao Ma and Liang Chen and Runxin Xu and Zhengyang Tang and Benyou Wang and Daoguang Zan and Shanghaoran Quan and Ge Zhang and Lei Sha and Yichang Zhang and Xuancheng Ren and Tianyu Liu and Baobao Chang},
      year={2024},
      eprint={2410.07985},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2410.07985}, 
}