# MIT License
#
# Copyright (c) 2022 Alireza Mohammadshahi
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""RQUGE metric."""

from contextlib import contextmanager

import datasets
import evaluate
from rquge_score.scorer import RQUGE


@contextmanager
def filter_logging_context():
    """Suppress the expected transformers warning that is emitted while the QA
    and span-scorer checkpoints are being initialized."""

    def filter_log(record):
        return "This IS expected if you are initializing" not in record.msg

    logger = datasets.utils.logging.get_logger("transformers.modeling_utils")
    logger.addFilter(filter_log)
    try:
        yield
    finally:
        logger.removeFilter(filter_log)


_CITATION = """\
@misc{mohammadshahi2022rquge,
    title={RQUGE: Reference-Free Metric for Evaluating Question Generation by Answering the Question},
    author={Alireza Mohammadshahi and Thomas Scialom and Majid Yazdani and Pouya Yanki and Angela Fan and James Henderson and Marzieh Saeidi},
    year={2022},
    eprint={2211.01482},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""

_DESCRIPTION = """\
RQUGE is a Reference-free QUestion Generation Evaluation metric that scores a candidate question
without requiring access to a reference question. Given the corresponding context and answer span,
the metric computes an acceptability score by applying a general question-answering module, followed
by a span scorer. More details can be found in the paper (https://arxiv.org/abs/2211.01482, ACL 2023).
"""

_KWARGS_DESCRIPTION = """
RQUGE metric to compute the acceptability of a generated question, given the context and answer.

Args:
    generated_questions (list of str): Generated/candidate questions.
    contexts (list of str): Corresponding contexts.
    answers (list of str): Corresponding reference answers.
    qa_model (str): Path to the QA model (local path or HF model hub),
        default: 'allenai/unifiedqa-v2-t5-large-1363200'.
    sp_model (str): Path to the span-scorer model (local path or HF model hub),
        default: 'alirezamsh/quip-512-mocha'.
    verbose (bool): If True, print intermediate status updates.
    device (str): Device on which the QA and span-scorer models are allocated,
        default: 'cpu'.

Returns:
    mean_score: Average RQUGE score over all instances.
    instance_score: List of per-instance RQUGE scores.
Examples:

    >>> generated_questions = ["how is the weather?"]
    >>> contexts = ["the weather is sunny"]
    >>> answers = ["sunny"]
    >>> rquge = evaluate.load("rquge")
    >>> results = rquge.compute(generated_questions=generated_questions, contexts=contexts, answers=answers)
    >>> print([round(v, 2) for v in results["instance_score"]])
    [5.0]
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class RQUGEScore(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="https://github.com/alirezamshi/RQUGE",
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "generated_questions": datasets.Value("string", id="sequence"),
                        "contexts": datasets.Value("string", id="sequence"),
                        "answers": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            codebase_urls=["https://github.com/alirezamshi/RQUGE"],
            reference_urls=[
                "https://github.com/alirezamshi/RQUGE",
                "https://arxiv.org/abs/2211.01482",
            ],
        )

    def _compute(
        self,
        generated_questions,
        contexts,
        answers,
        qa_model="allenai/unifiedqa-v2-t5-large-1363200",
        sp_model="alirezamsh/quip-512-mocha",
        verbose=False,
        device="cpu",
    ):
        # Load the QA model and span scorer once, then score each
        # (context, question, answer) triple individually.
        rquge_model = RQUGE(sp_scorer_path=sp_model, qa_model_path=qa_model, device=device)

        output = []
        total = 0
        for context, question, answer in zip(contexts, generated_questions, answers):
            score = rquge_model.scorer(context, question, answer)
            total += score
            output.append(score)

        if verbose:
            print(f"Average RQUGE score is {total / len(output)}")

        return {
            "mean_score": total / len(output),
            "instance_score": output,
        }
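

# A minimal end-to-end sketch of using this metric, assuming the module is
# resolvable by `evaluate.load` (either published under the name "rquge" or
# saved locally and loaded by path). Loading downloads the UnifiedQA and QuIP
# checkpoints, so this needs network access and sizeable memory.
if __name__ == "__main__":
    rquge = evaluate.load("rquge")  # or evaluate.load("path/to/rquge.py") for a local script
    results = rquge.compute(
        generated_questions=["how is the weather?"],
        contexts=["the weather is sunny"],
        answers=["sunny"],
        device="cpu",  # switch to "cuda:0" if a GPU is available
    )
    print(results["mean_score"])      # average RQUGE score over all instances
    print(results["instance_score"])  # per-instance scores, e.g. [5.0]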