# Scrape residue from the hosting page (kept as comments so the file parses):
# Spaces:
# Runtime error
# Runtime error
from typing import Any, Dict, Union

import torch
from transformers import Pipeline
from transformers.modeling_outputs import ModelOutput
from transformers.pipelines.base import GenericTensor
class NERPredictorPipe(Pipeline):
    """Token-classification (NER) pipeline around a CRF-style model.

    The wrapped model is expected to return ``(loss, (best_path, _))`` when
    called, where ``best_path`` is a batch of per-token label-id sequences
    (e.g. Viterbi decoding output) — TODO confirm against the model class.
    ``postprocess`` decodes BIOES/BIOMES-style tag sequences into entity
    span dicts.
    """

    def _sanitize_parameters(self, **kwargs):
        """Route call-time kwargs to preprocess / forward / postprocess.

        Forwards ``max_length`` to :meth:`preprocess` when supplied.
        (Previously all kwargs were silently dropped, so callers could
        never change the 512-token default through the pipeline API.)
        """
        preprocess_kwargs = {}
        if "max_length" in kwargs:
            preprocess_kwargs["max_length"] = kwargs["max_length"]
        return preprocess_kwargs, {}, {}

    def __token_preprocess(self, input, tokenizer, max_length=512):
        # Tokenize to fixed-length PyTorch tensors: pad to max_length and
        # truncate anything longer.
        return tokenizer(
            input,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

    def preprocess(self, sentence: Union[str, list], max_length: int = 512) -> Dict[str, GenericTensor]:
        """Tokenize ``sentence`` and move every tensor to the pipeline device.

        Adds an ``input_mask`` tensor that is 1 on padding positions and 0 on
        real tokens (the inverse of a usual attention mask).
        NOTE(review): assumes the model expects a *padding* mask under this
        key — confirm; if it expects an attention mask, the ``~`` is a bug.
        """
        input_tensors = self.__token_preprocess(
            sentence, self.tokenizer, max_length=max_length
        )
        input_tensors["input_mask"] = (~(input_tensors["input_ids"] > 0)).long()
        for key in input_tensors:
            if input_tensors[key] is not None:
                input_tensors[key] = input_tensors[key].to(self.device)
        return input_tensors

    def _forward(self, input_tensors: Dict[str, GenericTensor]) -> ModelOutput:
        """Run the model without gradients.

        Returns ``(input_ids_as_lists, best_path)`` for ``postprocess``.
        """
        self.model.eval()
        with torch.no_grad():
            _, (best_path, _) = self.model(**input_tensors)
        return (input_tensors["input_ids"].tolist(), best_path)

    def __format_output(self, start, end, text, label):
        # One entity span in the output schema.
        return {
            "text": text,
            "start": start,
            "end": end,
            "label": label,
        }

    def __entity_text(self, input_ids, start, end):
        # Recover the surface text for label positions [start, end].
        # The +1 offset skips a leading special token ([CLS]) that the label
        # sequence apparently does not cover — TODO confirm with the model.
        return "".join(
            self.tokenizer.convert_ids_to_tokens(input_ids[start + 1:end + 2])
        )

    def postprocess(self, model_outputs: ModelOutput) -> Any:
        """Decode label-id sequences into entity spans.

        Returns one list per input sequence; each element is a dict with
        keys ``text``, ``start``, ``end``, ``label``.
        """
        batch_slices = []
        input_ids_list, label_ids_list = model_outputs
        for input_ids, label_ids in zip(input_ids_list, label_ids_list):
            slices = []
            labels = [self.model.config.id2tag[str(id_)] for id_ in label_ids]
            past = "O"   # type of the span currently being built
            start = -1
            end = -1
            for i, label in enumerate(labels):
                if label.startswith("B-"):
                    # New entity begins: flush any open span first.
                    if start != -1 and end != -1:
                        slices.append(self.__format_output(
                            start, end,
                            self.__entity_text(input_ids, start, end), past,
                        ))
                    start = i
                    end = i
                    past = "-".join(label.split("-")[1:])
                elif label.startswith(("I-", "M-", "E-")):
                    cur = "-".join(label.split("-")[1:])
                    if cur != past:
                        # Entity type changed mid-span: flush and restart.
                        if start != -1 and end != -1:
                            slices.append(self.__format_output(
                                start, end,
                                self.__entity_text(input_ids, start, end), past,
                            ))
                        start = i
                        past = cur
                    end = i
                elif label.startswith("S-"):
                    # Single-token entity: flush any open span first.
                    if start != -1 and end != -1:
                        slices.append(self.__format_output(
                            start, end,
                            self.__entity_text(input_ids, start, end), past,
                        ))
                    # BUG FIX: label the singleton with its OWN type; the
                    # original used `past`, i.e. the previous entity's type
                    # (or "O" when no entity preceded it).
                    slices.append(self.__format_output(
                        i, i,
                        self.__entity_text(input_ids, i, i),
                        "-".join(label.split("-")[1:]),
                    ))
                    start = -1
                    end = -1
                    past = "O"
            # Flush a span left open at the end of the sequence.
            if start != -1 and end != -1:
                slices.append(self.__format_output(
                    start, end,
                    self.__entity_text(input_ids, start, end), past,
                ))
            batch_slices.append(slices)
        return batch_slices