import torch
import requests
from PIL import Image, ImageTransform
from transformers import AutoImageProcessor, ViTModel
from utils.config import Config
from src.ocr import OCRDetector


class ViT:
    """Wraps the google/vit-base-patch16-224-in21k encoder for feature extraction."""

    def __init__(self) -> None:
        self.processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
        self.model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.model.to(Config.device)

    def extraction(self, image_url):
        """Return patch-level ViT features and a matching all-ones attention mask."""
        if image_url.startswith(("http://", "https://")):
            image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
        else:
            image = Image.open(image_url).convert("RGB")

        inputs = self.processor(image, return_tensors="pt").to(Config.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        last_hidden_states = outputs.last_hidden_state  # (1, num_patches + 1, hidden_size)
        attention_mask = torch.ones(last_hidden_states.shape[:2])

        return last_hidden_states.to(Config.device), attention_mask.to(Config.device)

    def pooling_extraction(self, image):
        """Return the pooled [CLS] feature as a length-1 sequence, plus its mask."""
        image_inputs = self.processor(image, return_tensors="pt").to(Config.device)

        with torch.no_grad():
            image_outputs = self.model(**image_inputs)
        # pooler_output: (batch, hidden) -> (batch, 1, hidden) so downstream code
        # can treat it as a sequence of length 1.
        image_pooler_output = image_outputs.pooler_output.unsqueeze(1)
        image_attention_mask = torch.ones(image_pooler_output.shape[:2])

        return image_pooler_output.to(Config.device), image_attention_mask.to(Config.device)
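
# Usage sketch for ViT (the image path below is hypothetical). For
# google/vit-base-patch16-224-in21k, extraction returns features of shape
# (1, 197, 768): 196 patch tokens plus one [CLS] token.
#   vit = ViT()
#   features, mask = vit.extraction("sample.jpg")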

class OCR:
    """Runs OCRDetector and groups detected word boxes into paragraphs."""

    def __init__(self) -> None:
        self.ocr_detector = OCRDetector()

    def extraction(self, image_dir):
        """Return (ocr_content, groups_box, paragraph_boxes) for the image at image_dir."""
        ocr_results = self.ocr_detector.text_detector(image_dir)
        if not ocr_results:
            print("No OCR results detected")
            return "", [], []

        ocrs = self.post_process(ocr_results)
        if not ocrs:
            return "", [], []

        # Reverse the detector's output order before grouping.
        ocrs.reverse()

        boxes = [ocr["box"] for ocr in ocrs]
        texts = [ocr["text"] for ocr in ocrs]

        groups_box, groups_text, paragraph_boxes = OCR.group_boxes(boxes, texts)
        for group_text in groups_text:
            print("OCR: ", group_text)

        # Join each paragraph's words, then join paragraphs with the T5 sentinel
        # <extra_id_0>, lowercase, and collapse whitespace.
        texts = [" ".join(group_text) for group_text in groups_text]
        ocr_content = "<extra_id_0>".join(texts)
        ocr_content = " ".join(ocr_content.lower().split())
        ocr_content = "<extra_id_0>" + ocr_content

        return ocr_content, groups_box, paragraph_boxes

    def post_process(self, ocr_results):
        """Lowercase detected texts and keep their boxes.

        Size/shape filters (very short texts, vertical boxes with h > w,
        tiny regions with w*h < 300) are currently disabled.
        """
        ocrs = []
        for result in ocr_results:
            ocrs.append({"text": result["text"].lower(), "box": result["box"]})
        return ocrs

    @staticmethod
    def cut_image_polygon(image, box):
        """Crop a quadrilateral text region, padded by roughly h // 7, to a w x h image."""
        (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
        w = x2 - x1
        h = y4 - y1
        scl = h // 7  # padding proportional to text height
        new_box = [max(x1 - scl, 0), max(y1 - scl, 0)], [x2 + scl, y2 - scl], [x3 + scl, y3 + scl], [x4 - scl, y4 + scl]
        (x1, y1), (x2, y2), (x3, y3), (x4, y4) = new_box
        # QuadTransform expects the source quad as an 8-tuple ordered
        # top-left, bottom-left, bottom-right, top-right.
        transform = [x1, y1, x4, y4, x3, y3, x2, y2]
        return image.transform((w, h), ImageTransform.QuadTransform(transform))
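
    # Usage sketch (hypothetical coordinates; box ordered TL, TR, BR, BL):
    #   crop = OCR.cut_image_polygon(img, [(10, 10), (110, 10), (110, 40), (10, 40)])
    # yields a 100 x 30 crop of the padded text region.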


    @staticmethod
    def check_point_in_rectangle(box, point, padding_divisor):
        """Check whether point lies inside box grown by (width // padding_divisor)."""
        (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
        x_min = min(x1, x4)
        x_max = max(x2, x3)

        padding = (x_max - x_min) // padding_divisor
        x_min -= padding
        x_max += padding

        y_min = min(y1, y2) - padding
        y_max = max(y3, y4) + padding

        x, y = point
        return x_min <= x <= x_max and y_min <= y <= y_max

    @staticmethod
    def check_rectangle_overlap(rec1, rec2, padding_divisor):
        """Two boxes overlap if any corner of one lies inside the padded other."""
        for point in rec1:
            if OCR.check_point_in_rectangle(rec2, point, padding_divisor):
                return True

        for point in rec2:
            if OCR.check_point_in_rectangle(rec1, point, padding_divisor):
                return True

        return False
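
    # Worked example (hypothetical numbers): with padding_divisor=4, a box
    # spanning x in [0, 100] is padded by (100 - 0) // 4 = 25 px on each side,
    # so a neighbouring box on the same line starting at x = 110 has corners
    # inside the padded rectangle and the two are reported as overlapping.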

    @staticmethod
    def group_boxes(boxes, texts):
        """Greedily merge overlapping word boxes into paragraph groups.

        Returns (groups, groups_text, paragraph_boxes); each paragraph box is
        the running union of its members' corners.
        """
        groups = []
        groups_text = []
        paragraph_boxes = []
        processed = []
        boxes_cp = boxes.copy()
        for i, (box, text) in enumerate(zip(boxes_cp, texts)):
            if i in processed:
                continue
            processed.append(i)

            (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
            groups.append([box])
            groups_text.append([text])
            for j, (box2, text2) in enumerate(zip(boxes_cp[i + 1:], texts[i + 1:])):
                if j + i + 1 in processed:
                    continue
                # Shrink the padding as the group grows so large paragraphs do
                # not swallow every nearby box.
                padding_divisor = len(groups[-1]) * 4
                if OCR.check_rectangle_overlap(box, box2, padding_divisor):
                    (xx1, yy1), (xx2, yy2), (xx3, yy3), (xx4, yy4) = box2
                    processed.append(j + i + 1)
                    groups[-1].append(box2)
                    groups_text[-1].append(text2)
                    # Expand the running paragraph box: top corners keep the
                    # smaller y, bottom corners the larger; left corners keep
                    # the smaller x, right corners the larger.
                    x1, y1 = min(x1, xx1), min(y1, yy1)
                    x2, y2 = max(x2, xx2), min(y2, yy2)
                    x3, y3 = max(x3, xx3), max(y3, yy3)
                    x4, y4 = min(x4, xx4), max(y4, yy4)
                    box = [(x1, y1), (x2, y2), (x3, y3), (x4, y4)]

            paragraph_boxes.append(box)
        return groups, groups_text, paragraph_boxes
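

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: the image path is hypothetical and
    # assumes the repo's Config and OCRDetector are set up.
    vit = ViT()
    features, mask = vit.extraction("sample.jpg")
    print(features.shape, mask.shape)

    ocr = OCR()
    ocr_content, groups_box, paragraph_boxes = ocr.extraction("sample.jpg")
    print(ocr_content)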