from transformers import Pipeline
from typing import Dict, Any, Union
from transformers.pipelines.base import GenericTensor
from transformers.modeling_outputs import ModelOutput
import torch


class NERPredictorPipe(Pipeline):
    """NER pipeline that decodes a CRF model's best label paths into entity spans.

    The model is expected to return ``(loss_like, (best_path, _))`` from its
    forward pass, and its config must carry an ``id2tag`` mapping with *string*
    keys.  Labels follow a BIOES/BMES-style scheme: ``B-``/``I-``/``M-``/``E-``
    for multi-token entities, ``S-`` for single-token entities, ``O`` outside.
    """

    def _sanitize_parameters(self, **kwargs):
        # No extra kwargs are routed to preprocess/_forward/postprocess.
        return {}, {}, {}

    def __token_preprocess(self, text, tokenizer, max_length=512):
        """Tokenize `text`, padded/truncated to `max_length`, as PyTorch tensors."""
        return tokenizer(
            text,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

    def preprocess(self, sentence: Union[str, list], max_length: int = 512) -> Dict[str, GenericTensor]:
        """Tokenize the input and move all tensors to the pipeline's device.

        Adds an ``input_mask`` tensor that is 1 on padding positions
        (``input_ids <= 0``) and 0 on real tokens.
        NOTE(review): this is the *inverse* of the usual HF ``attention_mask``;
        assumed to be what the underlying CRF model expects — confirm.
        """
        input_tensors = self.__token_preprocess(sentence, self.tokenizer, max_length=max_length)
        input_tensors["input_mask"] = (~(input_tensors["input_ids"] > 0)).long()
        for key in input_tensors:
            if input_tensors[key] is not None:
                input_tensors[key] = input_tensors[key].to(self.device)
        return input_tensors

    def _forward(self, input_tensors: Dict[str, GenericTensor]) -> ModelOutput:
        """Run the model without gradients.

        Returns a pair ``(input_ids as nested lists, CRF best paths)`` that
        ``postprocess`` consumes.
        """
        self.model.eval()
        with torch.no_grad():
            # Model returns (loss_like, (best_path, _)); only best_path is kept.
            _, (best_path, _) = self.model(**input_tensors)
        return (input_tensors["input_ids"].tolist(), best_path)

    def __format_output(self, start, end, text, label):
        """Pack one decoded entity into its result dict."""
        return {
            "text": text,
            "start": start,
            "end": end,
            "label": label
        }

    def postprocess(self, model_outputs: ModelOutput) -> Any:
        """Decode ``(input_ids, best_path)`` batches into lists of entity dicts.

        Returns one list of ``{"text", "start", "end", "label"}`` dicts per
        sequence in the batch.  ``start``/``end`` are 0-based positions in the
        label sequence; the text slice is offset by +1 into ``input_ids``
        (NOTE(review): assumed to skip a prepended [CLS] token — confirm
        against the tokenizer).
        """
        batch_slices = []
        input_ids_list, label_ids_list = model_outputs
        for input_ids, label_ids in zip(input_ids_list, label_ids_list):
            slices = []
            # id2tag keys are strings, hence str() on each predicted id.
            labels = [self.model.config.id2tag[str(label_id)] for label_id in label_ids]

            past = "O"      # type of the currently-open entity, "O" if none
            start = end = -1

            def flush():
                # Emit the currently-open entity span (if any) and reset state.
                nonlocal start, end, past
                if start != -1 and end != -1:
                    text = ''.join(self.tokenizer.convert_ids_to_tokens(
                        input_ids[start + 1:end + 2]))
                    slices.append(self.__format_output(start, end, text, past))
                start = end = -1
                past = "O"

            for i, label in enumerate(labels):
                if label.startswith("B-"):
                    flush()
                    start = end = i
                    past = "-".join(label.split("-")[1:])
                elif label.startswith(("I-", "M-", "E-")):
                    cur = "-".join(label.split("-")[1:])
                    if cur != past:
                        # Type mismatch: close the previous entity, start anew here.
                        flush()
                        start = i
                        past = cur
                    end = i
                elif label.startswith("S-"):
                    flush()
                    # BUG FIX: a single-token entity is labeled with its OWN tag
                    # type; the original used `past` (the previous entity's type,
                    # "O" when no entity was open).
                    tag = "-".join(label.split("-")[1:])
                    token_text = ''.join(self.tokenizer.convert_ids_to_tokens(
                        input_ids[i + 1:i + 2]))
                    slices.append(self.__format_output(i, i, token_text, tag))
                else:
                    # BUG FIX: "O" closes any open entity so that a later
                    # same-type I-/M-/E- tag cannot merge across the gap.
                    flush()
            flush()  # emit an entity still open at end of sequence
            batch_slices.append(slices)
        return batch_slices