Prove_KCL / utils /textual_entailment_module.py
Jongmo's picture
Upload 25 files
a5bbcdb verified
raw
history blame contribute delete
No virus
2.91 kB
import json
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import re
from transformers import BertTokenizer, BertForSequenceClassification
# Constants and paths
HOME = Path('/users/k2031554')

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

# Maximum tokenised length for a (claim, evidence) pair.
MAX_LEN = 512

# Label names, indexed by the classifier's output column.
CLASSES = ['SUPPORTS', 'REFUTES', 'NOT ENOUGH INFO']

# Names of the verdict-aggregation strategies.
METHODS = ['WEIGHTED_SUM', 'MALON']
def process_sent(sentence):
    """Normalise bracket/quote tokens in a tokenised sentence into plain text.

    The substitutions are applied in this order (order matters):
      * remove ``LSB ... RSB`` spans (shortest match, single line);
      * remove ``LRB`` ... ``RRB`` pairs with only whitespace between them;
      * replace each remaining ``LRB`` with ``(`` and ``RRB`` with ``)``;
      * collapse ``--`` to ``-``;
      * replace the quote tokens ````` `` ````` and ``''`` with ``"``.

    Args:
        sentence: The sentence string to clean.

    Returns:
        The cleaned sentence string.
    """
    # Raw strings: the previous non-raw literals relied on "\s" being an
    # unrecognised escape, which Python warns about and will eventually reject.
    substitutions = (
        (r"LSB.*?RSB", ""),
        (r"LRB\s*?RRB", ""),
        # Group 1 / group 2 capture surrounding whitespace and re-emit it, so
        # only the LRB/RRB token itself is replaced.
        (r"(\s*?)LRB((\s*?))", r"\1(\2"),
        (r"(\s*?)RRB((\s*?))", r"\1)\2"),
        (r"--", "-"),
        (r"``", '"'),
        (r"''", '"'),
    )
    for pattern, replacement in substitutions:
        sentence = re.sub(pattern, replacement, sentence)
    return sentence
class TextualEntailmentModule:
    """Score claim/evidence pairs with a fine-tuned BERT sequence classifier.

    Output columns of the classifier are mapped to label names through the
    module-level ``CLASSES`` list (index 0 = SUPPORTS, 1 = REFUTES,
    2 = NOT ENOUGH INFO).
    """

    def __init__(
        self,
        model_path='base/models/BERT_FEVER_v4_model_PBT',
        tokenizer_path='base/models/BERT_FEVER_v4_tok_PBT'
    ):
        """Load the tokenizer and classifier, and move the model to ``DEVICE``.

        Args:
            model_path: Path or identifier passed to
                ``BertForSequenceClassification.from_pretrained``.
            tokenizer_path: Path or identifier passed to
                ``BertTokenizer.from_pretrained``.
        """
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.model.to(DEVICE)

    def get_batch_scores(self, claims, evidence):
        """Return softmax class scores for each aligned (claim, evidence) pair.

        Args:
            claims: Iterable of claim strings.
            evidence: Iterable of evidence strings; item ``i`` is paired with
                ``claims[i]``.

        Returns:
            NumPy array of shape ``(number_of_pairs, num_labels)``; row ``i``
            is the softmax over the model logits for pair ``i``. An empty
            input yields an empty ``(0, num_labels)`` array.

        Raises:
            ValueError: If ``claims`` and ``evidence`` differ in length.
        """
        claims = list(claims)
        evidence = list(evidence)
        # zip() would silently drop unmatched items; fail loudly instead.
        if len(claims) != len(evidence):
            raise ValueError(
                "claims and evidence must have the same length "
                f"(got {len(claims)} and {len(evidence)})"
            )
        if not claims:
            # An empty batch cannot be tokenised/run through the model; return
            # an empty score matrix with the expected number of columns.
            return np.empty((0, self.model.config.num_labels), dtype=np.float32)

        encodings = self.tokenizer(
            list(zip(claims, evidence)),  # tokenised as sentence pairs
            max_length=MAX_LEN,
            # NOTE(review): token_type_ids are neither returned nor passed to
            # the model; presumably this matches how the checkpoint was
            # fine-tuned — confirm before changing.
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        ).to(DEVICE)

        self.model.eval()  # inference mode (e.g. disables dropout)
        with torch.no_grad():
            outputs = self.model(
                input_ids=encodings['input_ids'],
                attention_mask=encodings['attention_mask'],
            )
        return torch.softmax(outputs.logits, dim=1).cpu().numpy()

    def get_label_from_scores(self, scores):
        """Return the ``CLASSES`` label whose score is highest.

        ``np.argmax`` flattens its input, so ``scores`` should be a single row
        of per-class scores (for example one row of ``get_batch_scores``).
        """
        return CLASSES[np.argmax(scores)]

    def get_label_malon(self, score_set):
        """Aggregate several per-evidence score rows into a single verdict.

        Each row in ``score_set`` is reduced to its argmax class index. If any
        row predicts SUPPORTS the verdict is SUPPORTS; otherwise, if any row
        predicts REFUTES the verdict is REFUTES; otherwise (including an empty
        ``score_set``) the verdict is NOT ENOUGH INFO.
        """
        predicted = {int(np.argmax(scores)) for scores in score_set}
        for label_idx in (0, 1):  # SUPPORTS takes precedence over REFUTES
            if label_idx in predicted:
                return CLASSES[label_idx]
        return CLASSES[2]  # NOT ENOUGH INFO