import json
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import re

from transformers import BertTokenizer, BertForSequenceClassification

# Constants and paths
HOME = Path('/users/k2031554')
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
MAX_LEN = 512
CLASSES = ['SUPPORTS','REFUTES','NOT ENOUGH INFO']
METHODS = ['WEIGHTED_SUM', 'MALON']

def process_sent(sentence):
    """Normalise FEVER-style tokenisation artefacts, e.g.
    'Paris LRB France RRB is a city' -> 'Paris ( France ) is a city'."""
    sentence = re.sub(r"LSB.*?RSB", "", sentence)                # drop square-bracketed spans entirely
    sentence = re.sub(r"LRB\s*?RRB", "", sentence)               # drop empty bracket pairs
    sentence = re.sub(r"(\s*?)LRB((\s*?))", r"\1(\2", sentence)  # LRB -> (
    sentence = re.sub(r"(\s*?)RRB((\s*?))", r"\1)\2", sentence)  # RRB -> )
    sentence = re.sub(r"--", "-", sentence)
    sentence = re.sub(r"``", '"', sentence)
    sentence = re.sub(r"''", '"', sentence)
    return sentence

class TextualEntailmentModule:
    """Wrapper around a BERT sequence classifier fine-tuned for FEVER-style textual entailment."""

    def __init__(
        self,
        model_path='base/models/BERT_FEVER_v4_model_PBT',
        tokenizer_path='base/models/BERT_FEVER_v4_tok_PBT',
    ):
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.model.to(DEVICE)

    #def get_pair_scores(self, claim, evidence):
    #    
    #    encodings = self.tokenizer(
    #        [claim, evidence],
    #        max_length= MAX_LEN,
    #        return_token_type_ids=False,
    #        padding='max_length',
    #        truncation=True,
    #        return_tensors='pt',
    #    ).to(DEVICE)
    #
    #    self.model.eval()
    #    with torch.no_grad():
    #        probs = self.model(
    #            input_ids=encodings['input_ids'],
    #            attention_mask=encodings['attention_mask']
    #        )
    #    
    #    return torch.softmax(probs.logits,dim=1).cpu().numpy()

    def get_batch_scores(self, claims, evidence):
        """Score a batch of (claim, evidence) pairs; returns an (N, 3) array of
        class probabilities ordered as in CLASSES."""
        inputs = list(zip(claims, evidence))

        encodings = self.tokenizer(
            inputs,
            max_length=MAX_LEN,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        ).to(DEVICE)

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(
                input_ids=encodings['input_ids'],
                attention_mask=encodings['attention_mask'],
            )

        return torch.softmax(outputs.logits, dim=1).cpu().numpy()

    def get_label_from_scores(self, scores):
        """Return the class name with the highest probability for a single (claim, evidence) pair."""
        return CLASSES[np.argmax(scores)]

    def get_label_malon(self, score_set):
        """Aggregate per-evidence-pair scores into one verdict: SUPPORTS if any pair
        is predicted SUPPORTS, otherwise REFUTES if any pair is predicted REFUTES,
        otherwise NOT ENOUGH INFO."""
        score_labels = [np.argmax(s) for s in score_set]
        if 1 not in score_labels and 0 not in score_labels:
            return CLASSES[2]  # NOT ENOUGH INFO
        elif 0 in score_labels:
            return CLASSES[0]  # SUPPORTS
        elif 1 in score_labels:
            return CLASSES[1]  # REFUTES
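
# Minimal usage sketch (illustrative only): the claim/evidence strings below are
# made up, and it assumes the fine-tuned FEVER checkpoints referenced by the
# default paths in __init__ are available locally.
if __name__ == '__main__':
    te_module = TextualEntailmentModule()

    claims = ['Paris is the capital of France.']
    evidence = [process_sent('Paris is the capital and most populous city of France.')]

    scores = te_module.get_batch_scores(claims, evidence)
    print(te_module.get_label_from_scores(scores[0]))  # label for the single pair
    print(te_module.get_label_malon(scores))           # aggregated over all evidence pairs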