# TODO: BEFORE RUNNING: pip install git+https://github.com/gaussalgo/adaptor.git@QA_to_objectives

from adaptor.objectives.question_answering import ExtractiveQA
import json

from adaptor.adapter import Adapter
from adaptor.evaluators.question_answering import BLEUForQA
from adaptor.lang_module import LangModule
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy

# custom classes
from datasets import load_dataset

# Base multilingual encoder shared by all objectives below.
model_name = "bert-base-multilingual-cased"

lang_module = LangModule(model_name)

# Hyperparameters of the adaptation run; training stops once every
# objective reports convergence (see stopping_strategy).
training_arguments = AdaptationArguments(
    output_dir="train_dir",
    learning_rate=4e-5,
    stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
    do_train=True,
    do_eval=True,
    warmup_steps=1000,
    max_steps=100000,
    gradient_accumulation_steps=1,
    eval_steps=1,
    logging_steps=10,
    save_steps=1000,
    num_train_epochs=30,
    evaluation_strategy="steps",
)

# BLEU over validation answers; decides_convergence=True means this
# metric also drives the stopping decision.
val_metrics = [BLEUForQA(decides_convergence=True)]

# Load the Czech SQuAD-like dataset and keep only examples usable for
# extractive QA, i.e. entries whose first answer occurs verbatim in the
# context. The last 200 kept examples become the validation split.
# fix: use a context manager so the file handle is closed, and pin encoding
with open("data/czech_squad.json", encoding="utf-8") as dataset_f:
    squad_dataset = json.load(dataset_f)

questions = []
contexts = []
answers = []
skipped = 0

# the JSON mapping's keys are example ids and are not needed here
for entry in squad_dataset.values():
    answer_text = entry["answers"]["text"][0]  # first gold answer span
    if answer_text in entry["context"]:
        # context length is deliberately not filtered: over-long inputs
        # are truncated by the tokenizer anyway
        questions.append(entry["question"])
        contexts.append(entry["context"])
        answers.append(answer_text)
    else:
        # answer cannot be located in the context -> unusable for
        # extractive QA, count and drop it
        skipped += 1

print("Skipped examples from SQuAD-cs: %s" % skipped)

train_questions = questions[:-200]
val_questions = questions[-200:]

train_answers = answers[:-200]
val_answers = answers[-200:]

train_context = contexts[:-200]
val_context = contexts[-200:]

# Extractive question-answering objective over the Czech data.
# NOTE: the variable keeps its historical "generative" name so that the
# schedule below keeps working.
czech_qa_config = dict(
    texts_or_path=train_questions,
    text_pair_or_path=train_context,
    labels_or_path=train_answers,
    val_texts_or_path=val_questions,
    val_text_pair_or_path=val_context,
    val_labels_or_path=val_answers,
    batch_size=1,
    val_evaluators=val_metrics,
    objective_id="SQUAD-cs",
)
generative_qa_cs = ExtractiveQA(lang_module, **czech_qa_config)

# English SQuAD from the HuggingFace hub. Very long contexts are dropped
# from the training split to keep inputs within a reasonable length.
squad_en = load_dataset("squad")
squad_train = squad_en["train"].filter(lambda entry: len(entry["context"]) < 2000)
squad_val = squad_en["validation"]

# First answer span of each example serves as the gold label.
train_answers_en = [a["text"][0] for a in squad_train["answers"]]
val_answers_en = [a["text"][0] for a in squad_val["answers"]]

# fix: the original rebuilt both answer lists inline in the call below and
# also constructed unused "question: ... context: ..." strings over the
# whole corpus; reuse the precomputed answer lists and drop the dead work.
generative_qa_en = ExtractiveQA(lang_module,
                                texts_or_path=squad_train["question"],
                                text_pair_or_path=squad_train["context"],
                                labels_or_path=train_answers_en,
                                val_texts_or_path=squad_val["question"][:200],
                                val_text_pair_or_path=squad_val["context"][:200],
                                val_labels_or_path=val_answers_en[:200],
                                batch_size=10,
                                val_evaluators=val_metrics,
                                objective_id="SQUAD-en")

# Train both objectives in parallel: steps alternate between the Czech
# and English QA objectives until the stopping strategy configured in
# training_arguments is satisfied.
objectives = [generative_qa_cs, generative_qa_en]
schedule = ParallelSchedule(objectives=objectives, args=training_arguments)

adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()