File size: 2,277 Bytes
a5bbcdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import pathlib
import re
from typing import List, Optional, Tuple

import torch
from transformers import BertTokenizer

from utils.sentence_retrieval_model import sentence_retrieval_model


# Absolute directory that contains this module.
# NOTE(review): THIS_DIR is not referenced in the visible code, and the relative
# paths in ARGS below are therefore resolved against the process CWD rather
# than this directory — confirm that is intended.
THIS_DIR = pathlib.Path(__file__).parent.absolute()

# Default inference configuration read by SentenceRetrievalModule
# (tokenizer location, checkpoint path, sequence length, device choice, ...).
ARGS = dict(
    batch_size=32,
    bert_pretrain='base/bert_base',
    checkpoint='base/model.best.32.pt',
    dropout=0.6,
    bert_hidden_dim=768,
    max_len=384,
    cuda=torch.cuda.is_available(),
)

# Emitted once, at import time, when no GPU is visible to torch.
if not ARGS['cuda']:
    print('CUDA NOT AVAILABLE')


def process_sent(sentence):
    """Normalise bracket and quote tokens in ``sentence`` and return the result.

    The substitutions are applied in order:

    1. ``LSB ... RSB`` spans are removed (shortest match, within one line).
    2. ``LRB`` immediately followed (modulo whitespace) by ``RRB`` is removed.
    3. Remaining ``LRB`` / ``RRB`` become ``(`` / ``)``; surrounding
       whitespace is captured and re-emitted unchanged.
    4. ``--`` becomes ``-``; ````` `` ````` and ``''`` become ``"``.

    NOTE(review): the LRB/RRB/LSB/RSB tokens look like Penn-Treebank-style
    bracket escapes from the upstream corpus — confirm against the data source.
    """
    # Raw strings: "\s" in a normal string literal is an invalid escape
    # sequence (DeprecationWarning, SyntaxWarning since Python 3.12). The
    # resulting pattern/replacement text is identical to the non-raw form.
    sentence = re.sub(r"LSB.*?RSB", "", sentence)
    sentence = re.sub(r"LRB\s*?RRB", "", sentence)
    sentence = re.sub(r"(\s*?)LRB((\s*?))", r"\1(\2", sentence)
    sentence = re.sub(r"(\s*?)RRB((\s*?))", r"\1)\2", sentence)
    sentence = re.sub(r"--", "-", sentence)
    sentence = re.sub(r"``", '"', sentence)
    sentence = re.sub(r"''", '"', sentence)
    return sentence

class SentenceRetrievalModule:
    """Score sentence pairs with the BERT-based ``sentence_retrieval_model``.

    The tokenizer is loaded from ``ARGS['bert_pretrain']`` and the model
    weights from the ``'model'`` entry of the ``ARGS['checkpoint']`` file.
    """

    def __init__(self, max_len=None):
        """Load the tokenizer and model weights.

        Args:
            max_len: optional override of the tokenisation length
                (default ``ARGS['max_len']``); falsy values keep the default.
        """
        # Per-instance copy of the configuration. Writing max_len back into the
        # module-level ARGS would silently change the truncation length used by
        # every other instance.
        self.args = dict(ARGS)
        if max_len:
            self.args['max_len'] = max_len

        self.tokenizer = BertTokenizer.from_pretrained(self.args['bert_pretrain'], do_lower_case=False)
        self.model = sentence_retrieval_model(self.args)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(self.args['checkpoint'], map_location=torch.device('cpu'))
        self.model.load_state_dict(checkpoint['model'])
        if self.args['cuda']:
            self.model = self.model.cuda()

    def score_sentence_pairs(self, inputs: List[Tuple[str, str]], batch_size: Optional[int] = None):
        """Return one model score per ``(first, second)`` sentence pair.

        Args:
            inputs: sequence of pairs; element 0 and element 1 of each pair are
                cleaned with ``process_sent`` before tokenisation.
            batch_size: if given, run the model on at most this many pairs per
                forward pass (e.g. ``ARGS['batch_size']``) to bound memory use.
                By default all pairs are scored in a single pass.

        Returns:
            A list with one entry per input pair, obtained via ``.tolist()`` on
            the model output (the element type depends on the output shape of
            ``sentence_retrieval_model``, which is not visible here).

        Raises:
            ValueError: if ``batch_size`` is not positive.
            RuntimeError: if the model yields a different number of results
                than there were input pairs.
        """
        if not inputs:
            # Nothing to score; avoid calling the tokenizer/model on an empty batch.
            return []
        if batch_size is None:
            batch_size = len(inputs)
        elif batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.model.eval()
        batch_scores = []
        with torch.no_grad():
            for start in range(0, len(inputs), batch_size):
                batch_scores.append(self._score_batch(inputs[start:start + batch_size]))
        outputs = torch.cat(batch_scores).tolist()

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(outputs) != len(inputs):
            raise RuntimeError(
                f"model returned {len(outputs)} results for {len(inputs)} input pairs"
            )
        return outputs

    def _score_batch(self, pairs):
        """Clean, tokenise and run the model on one batch; returns a tensor of >= 1 dim."""
        cleaned = [(process_sent(pair[0]), process_sent(pair[1])) for pair in pairs]

        encodings = self.tokenizer(
            cleaned,
            padding='max_length',
            truncation='longest_first',
            max_length=self.args['max_len'],
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        inp = encodings['input_ids']
        msk = encodings['attention_mask']
        seg = encodings['token_type_ids']
        if self.args['cuda']:
            inp, msk, seg = inp.cuda(), msk.cuda(), seg.cuda()

        # NOTE(review): atleast_1d guards against a model that squeezes a
        # single-pair batch to a 0-d tensor, which torch.cat cannot join.
        return torch.atleast_1d(self.model(inp, msk, seg))