minskiter's picture
feat(models): update models and deploy app.py
9f86c43
raw
history blame
No virus
4.55 kB
from transformers import Pipeline
from typing import Dict, Any, Union
from transformers.pipelines.base import GenericTensor
from transformers.modeling_outputs import ModelOutput
import torch
class NERPredictorPipe(Pipeline):
    """NER pipeline for a CRF-decoding token-classification model.

    Flow: tokenize to fixed-length tensors -> run the model to get the best
    label path -> convert BIMES/BIO label sequences into entity spans of the
    form ``{"text", "start", "end", "label"}``.
    """

    def _sanitize_parameters(self, **kwargs):
        # No extra kwargs are forwarded to preprocess/_forward/postprocess.
        return {}, {}, {}

    def __token_preprocess(self, input, tokenizer, max_length=512):
        # Tokenize to PyTorch tensors, padded/truncated to a fixed length.
        return tokenizer(
            input,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

    def preprocess(self, sentence: Union[str, list], max_length=512) -> Dict[str, GenericTensor]:
        """Tokenize one sentence (or a batch) and move tensors to the device."""
        input_tensors = self.__token_preprocess(
            sentence,
            self.tokenizer,
            max_length=max_length,
        )
        # input_mask is 1 on padding positions (input_ids <= 0), 0 on real tokens.
        # NOTE(review): assumes all real token ids are > 0 — confirm for this tokenizer.
        input_tensors["input_mask"] = (~(input_tensors["input_ids"] > 0)).long()
        for key in input_tensors:
            if input_tensors[key] is not None:
                input_tensors[key] = input_tensors[key].to(self.device)
        return input_tensors

    def _forward(self, input_tensors: Dict[str, GenericTensor]) -> Any:
        """Run inference; return (input_ids as nested lists, decoded best path).

        Note: despite the Pipeline convention this does not return a
        ``ModelOutput`` — ``postprocess`` expects exactly this 2-tuple.
        """
        self.model.eval()
        with torch.no_grad():
            # Model returns (loss, (best_path, ...)); only the Viterbi path is kept.
            _, (best_path, _) = self.model(**input_tensors)
        return (input_tensors["input_ids"].tolist(), best_path)

    def __format_output(self, start, end, text, label):
        # Uniform span record consumed by callers of the pipeline.
        return {
            "text": text,
            "start": start,
            "end": end,
            "label": label,
        }

    def __entity_text(self, input_ids, start, end):
        # Recover surface text for token positions [start, end]; the +1 offset
        # skips the leading special token ([CLS]) that the label path excludes.
        return ''.join(self.tokenizer.convert_ids_to_tokens(input_ids[start + 1:end + 2]))

    def __flush(self, slices, input_ids, start, end, label):
        # Append the pending entity span (if any) to `slices`.
        if start != -1 and end != -1:
            slices.append(
                self.__format_output(
                    start, end,
                    self.__entity_text(input_ids, start, end),
                    label,
                )
            )

    def postprocess(self, model_outputs: Any) -> Any:
        """Decode label id paths into per-sentence lists of entity spans.

        Supports B-/I-/M-/E-/S- prefixed tags (BIO / BMES schemes); tag names
        come from ``model.config.id2tag`` which is keyed by *string* ids.
        """
        batch_slices = []
        input_ids_list, label_ids_list = model_outputs
        for input_ids, label_ids in zip(input_ids_list, label_ids_list):
            slices = []
            labels = [self.model.config.id2tag[str(label_id)] for label_id in label_ids]
            past = "O"      # entity type currently being accumulated
            start = -1      # first token index of the pending entity
            end = -1        # last token index of the pending entity
            for i, label in enumerate(labels):
                if label.startswith("B-"):
                    # New entity begins: flush whatever was pending.
                    self.__flush(slices, input_ids, start, end, past)
                    start = i
                    end = i
                    past = "-".join(label.split("-")[1:])
                elif label.startswith(("I-", "M-", "E-")):
                    cur = "-".join(label.split("-")[1:])
                    if cur != past:
                        # Type mismatch: close the pending entity and start a
                        # new one at this token.
                        self.__flush(slices, input_ids, start, end, past)
                        start = i
                        past = cur
                    end = i
                elif label.startswith("S-"):
                    # Single-token entity: flush pending, then emit this token.
                    self.__flush(slices, input_ids, start, end, past)
                    # BUG FIX: label the single-token span with its OWN type
                    # (the S- tag's suffix), not the previous entity's `past`.
                    slices.append(
                        self.__format_output(
                            i, i,
                            self.__entity_text(input_ids, i, i),
                            "-".join(label.split("-")[1:]),
                        )
                    )
                    start = -1
                    end = -1
                    past = "O"
                # NOTE(review): an "O" label leaves the pending entity open, so
                # a later same-type I- tag would extend it across the gap —
                # preserved from the original; confirm this is intended.
            # Flush an entity that runs to the end of the sequence.
            self.__flush(slices, input_ids, start, end, past)
            batch_slices.append(slices)
        return batch_slices