# Source: uploaded by linh-truong, commit c1b4f26 ("init"), 2.8 kB, virus scan: clean.
import io
import logging

import numpy as np
import requests
from paddleocr import PaddleOCR
from PIL import Image, ImageTransform
from vietocr.tool.config import Cfg
from vietocr.tool.predictor import Predictor

from utils.config import Config
class OCRDetector:
    """Two-stage OCR: PaddleOCR finds text boxes, VietOCR reads each box.

    Models are configured from ``Config`` (``device``, ``ocr_path``,
    ``vietocr_threshold``).
    """

    # Per-box recognition failures are skipped (best effort) but logged.
    _logger = logging.getLogger(__name__)

    def __init__(self) -> None:
        # Use the GPU for PaddleOCR only when the configured device is NOT the
        # CPU (the previous expression was inverted).
        use_gpu = str(Config.device) != "cpu"
        self.paddle_ocr = PaddleOCR(lang='en',
                                    use_angle_cls=False,
                                    use_gpu=use_gpu,
                                    show_log=False)
        vietocr_config = Cfg.load_config_from_name('vgg_transformer')
        vietocr_config['weights'] = Config.ocr_path
        # Weights come from Config.ocr_path, so skip the pretrained CNN download.
        vietocr_config['cnn']['pretrained'] = False
        vietocr_config['device'] = Config.device
        vietocr_config['predictor']['beamsearch'] = False
        self.viet_ocr = Predictor(vietocr_config)

    def find_box(self, image):
        """Detect text boxes with PaddleOCR (detection only, no recognition).

        Args:
            image: anything ``PaddleOCR.ocr`` accepts, e.g. a file path or an
                image array (PaddleOCR assumes 3-channel arrays are BGR).

        Returns:
            An int ``np.ndarray`` of box corners, presumably shaped (N, 4, 2)
            with corners ordered top-left, top-right, bottom-right, bottom-left
            (PaddleOCR's convention; confirm for your version). An empty array
            of shape (0, 4, 2) is returned when nothing is detected.
        """
        result = self.paddle_ocr.ocr(image, cls=False, rec=False)
        # One entry per input image; that entry is None (or empty) when no text
        # is found, which np.array(...).astype(int) cannot convert.
        boxes = result[0] if result else None
        if boxes is None or len(boxes) == 0:
            return np.empty((0, 4, 2), dtype=int)
        return np.array(boxes).astype(int)

    def cut_image_polygon(self, image, box):
        """Warp a (possibly skewed) quadrilateral text region into a rectangle.

        The quad is grown by ``h // 7`` pixels outward at every corner so glyph
        edges are not clipped. The grown corners are clamped to the image
        bounds, then mapped onto a ``w x h`` image with a PIL quad transform.

        Args:
            image: PIL image the box coordinates refer to.
            box: four (x, y) corners ordered top-left, top-right,
                bottom-right, bottom-left.

        Returns:
            A new PIL image of size (w, h). Here w is the horizontal extent of
            the box's top edge and h is the vertical extent of its left edge.

        Raises:
            ValueError: if the box has a non-positive width or height.
        """
        (x1, y1), (x2, y2), (x3, y3), (x4, y4) = box
        w = int(x2 - x1)
        h = int(y4 - y1)
        if w <= 0 or h <= 0:
            raise ValueError(f"degenerate text box {box!r}: width={w}, height={h}")
        pad = h // 7
        img_w, img_h = image.size

        def clamp(x, y):
            # Keep every grown corner inside the image, not just the top-left.
            return min(max(x, 0), img_w), min(max(y, 0), img_h)

        top_left = clamp(x1 - pad, y1 - pad)
        top_right = clamp(x2 + pad, y2 - pad)
        bottom_right = clamp(x3 + pad, y3 + pad)
        bottom_left = clamp(x4 - pad, y4 + pad)
        # QuadTransform takes the source corners as an 8-tuple ordered
        # top-left, bottom-left, bottom-right, top-right.
        quad = (*top_left, *bottom_left, *bottom_right, *top_right)
        return image.transform((w, h), ImageTransform.QuadTransform(quad))

    def vietnamese_text(self, boxes, image):
        """Recognise the text inside each box with VietOCR.

        Args:
            boxes: iterable of 4-corner boxes (e.g. the output of ``find_box``).
            image: PIL image the boxes refer to.

        Returns:
            A list of ``{"text", "score", "box"}`` dicts for the boxes whose
            recognition probability exceeds ``Config.vietocr_threshold``.
            Boxes that cannot be cropped or recognised are skipped and logged.
        """
        results = []
        for box in boxes:
            try:
                cut_image = self.cut_image_polygon(image, box)
                text, score = self.viet_ocr.predict(cut_image, return_prob=True)
            except Exception:
                # Best effort: one bad box must not abort the whole image. This
                # does not swallow KeyboardInterrupt/SystemExit and keeps the cause.
                self._logger.warning("Skipping text box %s", box, exc_info=True)
                continue
            if score > Config.vietocr_threshold:
                results.append({"text": text,
                                "score": score,
                                "box": box})
        return results

    def text_detector(self, image_path):
        """Run detection and recognition on a local image or an HTTP(S) URL.

        Args:
            image_path: file path (str or path-like) or an ``http(s)://`` URL.

        Returns:
            The list produced by ``vietnamese_text``. Returns ``None`` when no
            box is detected or when no box passes the recognition threshold.
        """
        image = self._load_image(image_path)
        # Detect on the same pixels we crop from. Passing the path made
        # PaddleOCR re-read (and re-download) the file with OpenCV. OpenCV
        # applies EXIF orientation and PIL does not, so boxes could land in a
        # rotated frame. PaddleOCR expects BGR arrays, hence the channel flip.
        bgr = np.ascontiguousarray(np.asarray(image)[:, :, ::-1])
        boxes = self.find_box(bgr)
        if len(boxes) == 0:
            return None
        results = self.vietnamese_text(boxes, image)
        return results or None

    @staticmethod
    def _load_image(image_path, timeout=30):
        """Open a local path or HTTP(S) URL as an RGB PIL image, closing the source."""
        if isinstance(image_path, str) and image_path.startswith(("http://", "https://")):
            response = requests.get(image_path, timeout=timeout)
            response.raise_for_status()
            # .content is decoded (gzip etc.), unlike the raw stream used before.
            source = Image.open(io.BytesIO(response.content))
        else:
            source = Image.open(image_path)
        with source:
            # convert() always returns a new, fully loaded image, so it
            # outlives the closed source.
            return source.convert("RGB")