File size: 849 Bytes
c96e0a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from transformers import Pipeline, T5ForConditionalGeneration, AutoTokenizer

class PersianTextFormalizerPipeline(Pipeline):
    """Text-to-text pipeline that rewrites informal text into formal text.

    Each input string is prefixed with ``"informal: "``, tokenized, passed to
    ``self.model.generate`` and the first generated sequence is decoded back to
    a string. Judging by the class name and the file's imports, the model is
    presumably a T5 checkpoint fine-tuned for informal->formal Persian
    rewriting -- NOTE(review): confirm against the checkpoint this pipeline is
    registered with.

    Optional call-time keyword arguments:

    * ``max_length`` -- token limit used both to truncate/pad the input and to
      cap the generated output (default 128).
    * ``num_beams`` -- beam width passed to ``generate`` (default 4).
    * ``generate_kwargs`` -- dict of extra keyword arguments forwarded to
      ``self.model.generate``; entries here override ``max_length`` /
      ``num_beams`` (e.g. ``{"do_sample": True, "temperature": 0.7}``).
    * ``second_text`` -- accepted for backward compatibility; it is ignored.
    """

    # Defaults matching the values this pipeline originally hard-coded.
    _DEFAULT_MAX_LENGTH = 128
    _DEFAULT_NUM_BEAMS = 4
    # Task prefix prepended to every input before tokenization.
    _TASK_PREFIX = "informal: "

    def _sanitize_parameters(self, **kwargs):
        """Route call-time kwargs to the preprocess / forward / postprocess steps.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)`` tuple
            of dicts, as required by ``transformers.Pipeline``.
        """
        preprocess_kwargs = {}
        forward_kwargs = {}
        if "second_text" in kwargs:
            # Kept so callers that pass it keep working; preprocess ignores it.
            preprocess_kwargs["second_text"] = kwargs["second_text"]
        if kwargs.get("max_length") is not None:
            # One limit bounds both the (truncated) input and the generated output.
            preprocess_kwargs["max_length"] = kwargs["max_length"]
            forward_kwargs["max_length"] = kwargs["max_length"]
        if kwargs.get("num_beams") is not None:
            forward_kwargs["num_beams"] = kwargs["num_beams"]
        if kwargs.get("generate_kwargs") is not None:
            forward_kwargs["generate_kwargs"] = kwargs["generate_kwargs"]
        return preprocess_kwargs, forward_kwargs, {}

    def preprocess(self, text, second_text=None, max_length=None):
        """Tokenize ``text`` with the task prefix.

        Args:
            text: Input string to formalize.
            second_text: Unused; accepted for backward compatibility.
            max_length: Truncation/padding length in tokens (default 128).

        Returns:
            The tokenizer's ``BatchEncoding`` (PyTorch tensors, including
            ``input_ids`` and ``attention_mask``) moved to ``self.device``.
        """
        if max_length is None:
            max_length = self._DEFAULT_MAX_LENGTH
        # Call the tokenizer directly (not ``encode``) so the attention mask is
        # returned: with padding="max_length" the pad positions must be masked
        # explicitly rather than leaving generate() to infer them.
        inputs = self.tokenizer(
            self._TASK_PREFIX + text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding="max_length",
        )
        return inputs.to(self.device)

    def _forward(self, model_inputs, max_length=None, num_beams=None, generate_kwargs=None):
        """Generate output token ids for the tokenized input.

        Args:
            model_inputs: Mapping produced by ``preprocess`` with ``input_ids``
                and ``attention_mask`` tensors.
            max_length: Maximum generated length (default 128).
            num_beams: Beam width (default 4).
            generate_kwargs: Optional extra arguments for ``generate``; they
                take precedence over ``max_length`` / ``num_beams``.

        Returns:
            Whatever ``self.model.generate`` returns (generated token ids).
        """
        gen_kwargs = {
            "max_length": self._DEFAULT_MAX_LENGTH if max_length is None else max_length,
            "num_beams": self._DEFAULT_NUM_BEAMS if num_beams is None else num_beams,
        }
        # ``temperature`` is deliberately not set by default: without
        # ``do_sample=True`` it has no effect on beam search, and transformers
        # warns about it. Enable sampling explicitly via ``generate_kwargs``.
        if generate_kwargs:
            gen_kwargs.update(generate_kwargs)
        # Pass only the tensors the model consumes; any other key a tokenizer
        # may emit (e.g. token_type_ids) would be rejected by generate().
        return self.model.generate(
            input_ids=model_inputs["input_ids"],
            attention_mask=model_inputs["attention_mask"],
            **gen_kwargs,
        )

    def postprocess(self, model_outputs):
        """Decode the first generated sequence into a string without special tokens."""
        return self.tokenizer.decode(model_outputs[0], skip_special_tokens=True)