# (extraction metadata — file size, blob hash, and line-number gutter — commented out; not part of the source)
import re
from typing import List, Tuple
import pathlib
import torch
from transformers import BertTokenizer
from utils.sentence_retrieval_model import sentence_retrieval_model
# Directory containing this module; useful for resolving relative resources.
THIS_DIR = pathlib.Path(__file__).parent.absolute()

# Shared configuration for the retrieval model: tokenizer/model paths,
# encoding limits, and runtime flags.
ARGS = dict(
    batch_size=32,
    bert_pretrain='base/bert_base',
    checkpoint='base/model.best.32.pt',
    dropout=0.6,
    bert_hidden_dim=768,
    max_len=384,
    cuda=torch.cuda.is_available(),
)

# Warn at import time when inference will run on CPU only.
if not ARGS['cuda']:
    print('CUDA NOT AVAILABLE')
def process_sent(sentence):
    """Normalize FEVER-style token artifacts in *sentence*.

    Drops LSB...RSB placeholder spans, removes empty LRB/RRB pairs,
    restores real parentheses from LRB/RRB markers, and normalizes
    double dashes and LaTeX-style backtick quotes to plain ASCII.

    Args:
        sentence: raw sentence string.

    Returns:
        The cleaned sentence string.
    """
    # Raw strings: the originals were non-raw and contained "\s", an
    # invalid escape sequence (SyntaxWarning on CPython 3.12+).
    sentence = re.sub(r"LSB.*?RSB", "", sentence)    # drop bracketed spans entirely
    sentence = re.sub(r"LRB\s*?RRB", "", sentence)   # drop empty parenthesis pairs
    # The original used a redundant nested group ((\s*?)); a single group
    # yields identical matches and identical \1/\2 replacement semantics.
    sentence = re.sub(r"(\s*?)LRB(\s*?)", r"\1(\2", sentence)   # LRB -> (
    sentence = re.sub(r"(\s*?)RRB(\s*?)", r"\1)\2", sentence)   # RRB -> )
    sentence = re.sub(r"--", "-", sentence)
    sentence = re.sub(r"``", '"', sentence)
    sentence = re.sub(r"''", '"', sentence)
    return sentence
class SentenceRetrievalModule():
    """Scores sentence pairs with a fine-tuned BERT sentence-retrieval model.

    Loads the tokenizer and model checkpoint configured in ARGS and exposes
    a single scoring entry point, `score_sentence_pairs`.
    """

    def __init__(self, max_len=None):
        """Load the tokenizer and checkpointed model.

        Args:
            max_len: optional override for the tokenizer's maximum sequence
                length; defaults to ARGS['max_len'].
        """
        # Work on a private copy: the original wrote the override back into
        # the shared module-level ARGS, so constructing one instance silently
        # changed the configuration of every other instance.
        self.args = dict(ARGS)
        if max_len:
            self.args['max_len'] = max_len
        self.tokenizer = BertTokenizer.from_pretrained(self.args['bert_pretrain'], do_lower_case=False)
        self.model = sentence_retrieval_model(self.args)
        # Load on CPU first so the checkpoint opens even without a GPU,
        # then move the model to CUDA when available.
        state = torch.load(self.args['checkpoint'], map_location=torch.device('cpu'))
        self.model.load_state_dict(state['model'])
        if self.args['cuda']:
            self.model = self.model.cuda()

    def score_sentence_pairs(self, inputs: List[Tuple[str, str]]):
        """Return one relevance score per (sentence_a, sentence_b) pair.

        Args:
            inputs: list of 2-tuples of raw sentence strings.

        Returns:
            List of float scores, in the same order as `inputs`.
        """
        # `pair` instead of the original loop variable, which shadowed the
        # builtin `input`; hint fixed from Tuple[str] (a 1-tuple) to the
        # pair type actually consumed here.
        cleaned = [(process_sent(pair[0]), process_sent(pair[1])) for pair in inputs]
        encodings = self.tokenizer(
            cleaned,
            padding='max_length',
            truncation='longest_first',
            max_length=self.args['max_len'],
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        inp = encodings['input_ids']
        msk = encodings['attention_mask']
        seg = encodings['token_type_ids']
        if self.args['cuda']:
            inp = inp.cuda()
            msk = msk.cuda()
            seg = seg.cuda()
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(inp, msk, seg).tolist()
        # Sanity check: the model must emit exactly one score per input pair.
        assert len(outputs) == len(inputs)
        return outputs