shramay-palta commited on
Commit
aa20953
1 Parent(s): 22b944c

Create demo_t5_qa_pipe.py

Browse files
Files changed (1) hide show
  1. demo_t5_qa_pipe.py +39 -0
demo_t5_qa_pipe.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from transformers import Text2TextGenerationPipeline
5
+
6
+ class DemoT5QAPipeline(Text2TextGenerationPipeline):
7
+ def _forward(self, model_inputs, **generate_kwargs):
8
+ if self.framework == "pt":
9
+ in_b, input_length = model_inputs["input_ids"].shape
10
+ elif self.framework == "tf":
11
+ in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy()
12
+
13
+ self.check_inputs(
14
+ input_length,
15
+ generate_kwargs.get("min_length", self.model.config.min_length),
16
+ generate_kwargs.get("max_length", self.model.config.max_length),
17
+ )
18
+ outputs = self.model.generate(**model_inputs, **generate_kwargs, return_dict_in_generate=True, output_scores=True, max_new_tokens=75)
19
+
20
+ # Code from the parent class
21
+ output_ids = outputs.sequences
22
+ out_b = output_ids.shape[0]
23
+ if self.framework == "pt":
24
+ output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
25
+ elif self.framework == "tf":
26
+ output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:]))
27
+
28
+ output_sequences = outputs.sequences
29
+ output_scores = outputs.scores
30
+ return {"output_ids": output_ids, "output_sequences": output_sequences, "output_scores": output_scores}
31
+
32
+ def postprocess(self, model_outputs):
33
+ guess_text = super().postprocess(model_outputs)[0]['generated_text']
34
+
35
+ transition_scores = self.model.compute_transition_scores(model_outputs['output_sequences'], model_outputs['output_scores'], normalize_logits=True)
36
+ log_probs = np.round(np.exp(transition_scores.cpu().numpy()), 3)[0]
37
+ guess_prob = np.product(log_probs)
38
+
39
+ return {'guess': guess_text, 'confidence': guess_prob}