Hongxuan Li committed on
Commit
d21a961
1 Parent(s): 9b73c65

add handler

Browse files
Files changed (2) hide show
  1. .DS_Store +0 -0
  2. handler.py +37 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
handler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import NougatProcessor, VisionEncoderDecoderModel
2
+ import torch.cuda
3
+ import io
4
+ import base64
5
+ from PIL import Image
6
+ from typing import Dict, Any
7
+
8
class EndpointHandler:
    """Custom inference handler that transcribes a document image with Nougat.

    The handler loads a ``NougatProcessor`` and a ``VisionEncoderDecoderModel``
    once at construction time and is then called once per request.
    """

    def __init__(self, path="facebook/nougat-base"):
        """Load the processor and model and place the model on the device.

        Args:
            path: Model id or local directory handed to ``from_pretrained``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = NougatProcessor.from_pretrained(path)
        # The original assigned ``model.to(...)`` using an undefined bare
        # name ``model``; move the freshly loaded model instead.
        self.model = VisionEncoderDecoderModel.from_pretrained(path).to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run Nougat on one base64-encoded image and return the decoded text.

        Args:
            data: Request payload. Recognised keys (all are popped from it):
                ``inputs`` - base64-encoded image bytes (required);
                ``parameters`` - optional mapping of extra keyword arguments
                forwarded to ``model.generate``;
                ``fix_markdown`` - optional flag forwarded to
                ``processor.post_process_generation``.

        Returns:
            ``{"generated_text": <post-processed transcription>}``.

        Raises:
            ValueError: If ``inputs`` is missing (or is not valid base64).
        """
        encoded_image = data.pop("inputs", None)
        # A request without "parameters" must not crash the ``**`` expansion.
        parameters = data.pop("parameters", None) or {}
        fix_markdown = data.pop("fix_markdown", None)
        if encoded_image is None:
            raise ValueError("Missing image.")

        binary_data = base64.b64decode(encoded_image)
        # Normalise to RGB so palette / alpha / grayscale uploads all reach
        # the processor as a 3-channel image.
        image = Image.open(io.BytesIO(binary_data)).convert("RGB")
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values

        # Ban the unknown token by default, but let an explicit
        # "bad_words_ids" in the request override it rather than colliding
        # with a duplicate keyword argument.
        generate_kwargs = {
            "bad_words_ids": [[self.processor.tokenizer.unk_token_id]],
            **parameters,
        }
        outputs = self.model.generate(
            inputs=pixel_values.to(self.device),
            **generate_kwargs,
        )

        # Decode the whole batch and take the single sequence. Decoding
        # ``outputs[0]`` would treat every token id as its own sequence and
        # return only the first token's text.
        generated = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
        prediction = self.processor.post_process_generation(
            generated, fix_markdown=fix_markdown
        )

        return {"generated_text": prediction}