# TODO: BEFORE RUNNING: pip install git+https://github.com/gaussalgo/adaptor.git@QA_to_objectives
"""Train multilingual BERT on Czech + English SQuAD extractive QA in parallel.

Two ExtractiveQA objectives (Czech custom SQuAD + English SQuAD from HF hub)
are scheduled in parallel; training stops once all objectives converge.
"""
import json

from adaptor.adapter import Adapter
from adaptor.evaluators.question_answering import BLEUForQA
from adaptor.lang_module import LangModule
from adaptor.objectives.question_answering import ExtractiveQA
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy
from datasets import load_dataset

model_name = "bert-base-multilingual-cased"
lang_module = LangModule(model_name)

training_arguments = AdaptationArguments(
    output_dir="train_dir",
    learning_rate=4e-5,
    stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
    do_train=True,
    do_eval=True,
    warmup_steps=1000,
    max_steps=100000,
    gradient_accumulation_steps=1,
    # NOTE(review): eval_steps=1 evaluates after every step, which is very
    # expensive with evaluation_strategy="steps" — confirm this is intended.
    eval_steps=1,
    logging_steps=10,
    save_steps=1000,
    num_train_epochs=30,
    evaluation_strategy="steps",
)

val_metrics = [BLEUForQA(decides_convergence=True)]

# Czech SQuAD: keep only examples whose gold answer occurs verbatim in the
# context — extractive QA must be able to locate the answer span. Contexts are
# characters here and over-long ones are truncated by the tokenizer anyway.
with open("data/czech_squad.json", encoding="utf-8") as squad_f:
    squad_dataset = json.load(squad_f)

questions = []
contexts = []
answers = []
skipped = 0

for entry in squad_dataset.values():
    if entry["answers"]["text"][0] in entry["context"]:
        questions.append(entry["question"])
        contexts.append(entry["context"])
        answers.append(entry["answers"]["text"][0])
    else:
        skipped += 1

print("Skipped examples from SQuAD-cs: %s" % skipped)

# Hold out the last 200 examples for validation.
train_questions, val_questions = questions[:-200], questions[-200:]
train_answers, val_answers = answers[:-200], answers[-200:]
train_context, val_context = contexts[:-200], contexts[-200:]

# declaration of extractive question answering objective (Czech)
generative_qa_cs = ExtractiveQA(
    lang_module,
    texts_or_path=train_questions,
    text_pair_or_path=train_context,
    labels_or_path=train_answers,
    val_texts_or_path=val_questions,
    val_text_pair_or_path=val_context,
    val_labels_or_path=val_answers,
    batch_size=1,  # NOTE(review): much smaller than the English objective's 10 — confirm intended
    val_evaluators=val_metrics,
    objective_id="SQUAD-cs",
)

# English SQuAD: drop extremely long contexts (character count) before training.
squad_en = load_dataset("squad")
squad_train = squad_en["train"].filter(lambda entry: len(entry["context"]) < 2000)

# First gold answer per example serves as the extractive label.
train_answers_en = [a["text"][0] for a in squad_train["answers"]]
val_answers_en = [a["text"][0] for a in squad_en["validation"]["answers"]]

generative_qa_en = ExtractiveQA(
    lang_module,
    texts_or_path=squad_train["question"],
    text_pair_or_path=squad_train["context"],
    labels_or_path=train_answers_en,
    val_texts_or_path=squad_en["validation"]["question"][:200],
    val_text_pair_or_path=squad_en["validation"]["context"][:200],
    val_labels_or_path=val_answers_en[:200],
    batch_size=10,
    val_evaluators=val_metrics,
    objective_id="SQUAD-en",
)

# Train both objectives in parallel until all converge.
schedule = ParallelSchedule(
    objectives=[generative_qa_cs, generative_qa_en],
    args=training_arguments,
)

adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()