From c01cc08925888f9422a8657ee23a0193ac5c2e40 Mon Sep 17 00:00:00 2001 From: SWHL Date: Fri, 28 Jun 2024 09:40:50 +0800 Subject: [PATCH] refactor(rapidocr_paddle): Refactor the entire code --- .../rapidocr_paddle/ch_ppocr_v2_cls/utils.py | 18 +- .../ch_ppocr_v3_det/text_detect.py | 109 +++--- .../rapidocr_paddle/ch_ppocr_v3_det/utils.py | 309 +++++------------- .../ch_ppocr_v3_rec/text_recognize.py | 20 +- .../rapidocr_paddle/ch_ppocr_v3_rec/utils.py | 118 ++++--- python/rapidocr_paddle/main.py | 52 ++- python/rapidocr_paddle/utils/load_image.py | 4 +- .../rapidocr_paddle/utils/parse_parameters.py | 11 +- python/rapidocr_paddle/utils/vis_res.py | 4 +- 9 files changed, 257 insertions(+), 388 deletions(-) diff --git a/python/rapidocr_paddle/ch_ppocr_v2_cls/utils.py b/python/rapidocr_paddle/ch_ppocr_v2_cls/utils.py index 5c75d54ee..6549fcb2b 100644 --- a/python/rapidocr_paddle/ch_ppocr_v2_cls/utils.py +++ b/python/rapidocr_paddle/ch_ppocr_v2_cls/utils.py @@ -11,20 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -class ClsPostProcess: - """Convert between text-label and text-index""" +from typing import List, Tuple + +import numpy as np + - def __init__(self, label_list): - super(ClsPostProcess, self).__init__() +class ClsPostProcess: + def __init__(self, label_list: List[str]): self.label_list = label_list - def __call__(self, preds, label=None): + def __call__(self, preds: np.ndarray) -> List[Tuple[str, float]]: pred_idxs = preds.argmax(axis=1) decode_out = [ (self.label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs) ] - if label is None: - return decode_out - - label = [(self.label_list[idx], 1.0) for idx in label] - return decode_out, label + return decode_out diff --git a/python/rapidocr_paddle/ch_ppocr_v3_det/text_detect.py b/python/rapidocr_paddle/ch_ppocr_v3_det/text_detect.py index 433a0e9d5..7bdde5344 100644 --- a/python/rapidocr_paddle/ch_ppocr_v3_det/text_detect.py +++ b/python/rapidocr_paddle/ch_ppocr_v3_det/text_detect.py @@ -14,34 +14,21 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com -import argparse import time +from typing import Any, Dict, Optional, Tuple -import cv2 import numpy as np -from rapidocr_paddle.utils import PaddleInferSession, read_yaml +from rapidocr_paddle.utils import PaddleInferSession -from .utils import DBPostProcess, create_operators, transform +from .utils import DBPostProcess, DetPreProcess class TextDetector: - def __init__(self, config): - pre_process_list = { - "DetResizeForTest": { - "limit_side_len": config.get("limit_side_len", 736), - "limit_type": config.get("limit_type", "min"), - }, - "NormalizeImage": { - "std": [0.229, 0.224, 0.225], - "mean": [0.485, 0.456, 0.406], - "scale": "1./255.", - "order": "hwc", - }, - "ToCHWImage": None, - "KeepKeys": {"keep_keys": ["image", "shape"]}, - } - self.preprocess_op = create_operators(pre_process_list) + def __init__(self, config: Dict[str, Any]): + limit_side_len = config.get("limit_side_len", 736) + limit_type = config.get("limit_type", "min") + self.preprocess_op = DetPreProcess(limit_side_len, limit_type) post_process = { "thresh": config.get("thresh", 0.3), @@ -55,31 +42,41 @@ def __init__(self, config): self.infer = PaddleInferSession(config) - def __call__(self, img): + def __call__(self, img: np.ndarray) -> Tuple[Optional[np.ndarray], float]: + start_time = time.perf_counter() + if img is None: raise ValueError("img is None") - ori_im_shape = img.shape[:2] - - data = {"image": img} - data = transform(data, self.preprocess_op) - img, shape_list = data - if img is None: + ori_img_shape = img.shape[0], img.shape[1] + prepro_img = self.preprocess_op(img) + if prepro_img is None: return None, 0 - img = np.expand_dims(img, axis=0).astype(np.float32) - shape_list = np.expand_dims(shape_list, axis=0) + preds = self.infer(prepro_img)[0] + dt_boxes, dt_boxes_scores = self.postprocess_op(preds, ori_img_shape) + dt_boxes = self.filter_tag_det_res(dt_boxes, ori_img_shape) + elapse = time.perf_counter() - start_time + return dt_boxes, elapse - starttime = time.time() - preds = self.infer(img)[0] - post_result = self.postprocess_op(preds, shape_list) + def filter_tag_det_res( + self, dt_boxes: np.ndarray, image_shape: Tuple[int, int] + ) -> np.ndarray: + img_height, img_width = image_shape + dt_boxes_new = [] + for box in dt_boxes: + box = self.order_points_clockwise(box) + box = self.clip_det_res(box, img_height, img_width) - dt_boxes = post_result[0]["points"] - dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im_shape) - elapse = time.time() - starttime - return dt_boxes, elapse + rect_width = int(np.linalg.norm(box[0] - box[1])) + rect_height = int(np.linalg.norm(box[0] - box[3])) + if rect_width <= 3 or rect_height <= 3: + continue + + dt_boxes_new.append(box) + return np.array(dt_boxes_new) - def order_points_clockwise(self, pts): + def order_points_clockwise(self, pts: np.ndarray) -> np.ndarray: """ reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py @@ -104,42 +101,10 @@ def order_points_clockwise(self, pts): rect = np.array([tl, tr, br, bl], dtype="float32") return rect - def clip_det_res(self, points, img_height, img_width): + def clip_det_res( + self, points: np.ndarray, img_height: int, img_width: int + ) -> np.ndarray: for pno in range(points.shape[0]): points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) return points - - def filter_tag_det_res(self, dt_boxes, image_shape): - img_height, img_width = image_shape[:2] - dt_boxes_new = [] - for box in dt_boxes: - box = self.order_points_clockwise(box) - box = self.clip_det_res(box, img_height, img_width) - rect_width = int(np.linalg.norm(box[0] - box[1])) - rect_height = int(np.linalg.norm(box[0] - box[3])) - if rect_width <= 3 or rect_height <= 3: - continue - dt_boxes_new.append(box) - dt_boxes = np.array(dt_boxes_new) - return dt_boxes - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--config_path", type=str, default="config.yaml") - parser.add_argument("--image_path", type=str, default=None) - args = parser.parse_args() - - config = read_yaml(args.config_path) - - text_detector = TextDetector(config) - - img = cv2.imread(args.image_path) - dt_boxes, elapse = text_detector(img) - - from utils import draw_text_det_res - - src_im = draw_text_det_res(dt_boxes, args.image_path) - cv2.imwrite("det_results.jpg", src_im) - print("The det_results.jpg has been saved in the current directory.") diff --git a/python/rapidocr_paddle/ch_ppocr_v3_det/utils.py b/python/rapidocr_paddle/ch_ppocr_v3_det/utils.py index e1155c1f6..b274abd4a 100644 --- a/python/rapidocr_paddle/ch_ppocr_v3_det/utils.py +++ b/python/rapidocr_paddle/ch_ppocr_v3_det/utils.py @@ -1,22 +1,7 @@ -""" -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com -import sys +from typing import List, Optional, Tuple import cv2 import numpy as np @@ -24,116 +9,49 @@ from shapely.geometry import Polygon -class NormalizeImage: - """normalize image such as substract mean, divide std""" +class DetPreProcess: + def __init__(self, limit_side_len: int = 736, limit_type: str = "min"): + self.mean = np.array([0.485, 0.456, 0.406]) + self.std = np.array([0.229, 0.224, 0.225]) + self.scale = 1 / 255.0 - def __init__(self, scale=None, mean=None, std=None, order="chw"): - if isinstance(scale, str): - scale = eval(scale) + self.limit_side_len = limit_side_len + self.limit_type = limit_type - self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) - mean = mean if mean is not None else [0.485, 0.456, 0.406] - std = std if std is not None else [0.229, 0.224, 0.225] - - shape = (3, 1, 1) if order == "chw" else (1, 1, 3) - self.mean = np.array(mean).reshape(shape).astype("float32") - self.std = np.array(std).reshape(shape).astype("float32") - - def __call__(self, data): - img = np.array(data["image"]).astype(np.float32) - data["image"] = (img * self.scale - self.mean) / self.std - return data - - -class ToCHWImage: - """convert hwc image to chw image""" - - def __init__(self): - pass - - def __call__(self, data): - img = np.array(data["image"]) - data["image"] = img.transpose((2, 0, 1)) - return data - - -class KeepKeys: - def __init__(self, keep_keys): - self.keep_keys = keep_keys - - def __call__(self, data): - data_list = [] - for key in self.keep_keys: - data_list.append(data[key]) - return data_list - - -class DetResizeForTest: - def __init__(self, **kwargs): - self.resize_type = 0 - if "image_shape" in kwargs: - self.image_shape = kwargs["image_shape"] - self.resize_type = 1 - elif "limit_side_len" in kwargs: - self.limit_side_len = kwargs.get("limit_side_len", 736) - self.limit_type = kwargs.get("limit_type", "min") - - if "resize_long" in kwargs: - self.resize_type = 2 - self.resize_long = kwargs.get("resize_long", 960) - else: - self.limit_side_len = kwargs.get("limit_side_len", 736) - self.limit_type = kwargs.get("limit_type", "min") - - def __call__(self, data): - img = data["image"] - src_h, src_w = img.shape[:2] + def __call__(self, img: np.ndarray) -> Optional[np.ndarray]: + resized_img = self.resize(img) + if resized_img is None: + return None - if self.resize_type == 0: - img, [ratio_h, ratio_w] = self.resize_image_type0(img) - elif self.resize_type == 2: - img, [ratio_h, ratio_w] = self.resize_image_type2(img) - else: - img, [ratio_h, ratio_w] = self.resize_image_type1(img) + img = self.normalize(resized_img) + img = self.permute(img) + img = np.expand_dims(img, axis=0).astype(np.float32) + return img - data["image"] = img - data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w]) - return data + def normalize(self, img: np.ndarray) -> np.ndarray: + return (img.astype("float32") * self.scale - self.mean) / self.std - def resize_image_type1(self, img): - resize_h, resize_w = self.image_shape - ori_h, ori_w = img.shape[:2] # (h, w, c) - ratio_h = float(resize_h) / ori_h - ratio_w = float(resize_w) / ori_w - img = cv2.resize(img, (int(resize_w), int(resize_h))) - return img, [ratio_h, ratio_w] + def permute(self, img: np.ndarray) -> np.ndarray: + return img.transpose((2, 0, 1)) - def resize_image_type0(self, img): - """ - resize image to a size multiple of 32 which is required by the network - args: - img(array): array with shape [h, w, c] - return(tuple): - img, (ratio_h, ratio_w) - """ - limit_side_len = self.limit_side_len + def resize(self, img: np.ndarray) -> Optional[np.ndarray]: + """resize image to a size multiple of 32 which is required by the network""" h, w = img.shape[:2] - # limit the max side if self.limit_type == "max": - if max(h, w) > limit_side_len: + if max(h, w) > self.limit_side_len: if h > w: - ratio = float(limit_side_len) / h + ratio = float(self.limit_side_len) / h else: - ratio = float(limit_side_len) / w + ratio = float(self.limit_side_len) / w else: ratio = 1.0 else: - if min(h, w) < limit_side_len: + if min(h, w) < self.limit_side_len: if h < w: - ratio = float(limit_side_len) / h + ratio = float(self.limit_side_len) / h else: - ratio = float(limit_side_len) / w + ratio = float(self.limit_side_len) / w else: ratio = 1.0 @@ -145,72 +63,16 @@ def resize_image_type0(self, img): try: if int(resize_w) <= 0 or int(resize_h) <= 0: - return None, (None, None) + return None img = cv2.resize(img, (int(resize_w), int(resize_h))) - except: - print(img.shape, resize_w, resize_h) - sys.exit(0) - - ratio_h = resize_h / float(h) - ratio_w = resize_w / float(w) - return img, [ratio_h, ratio_w] - - def resize_image_type2(self, img): - h, w = img.shape[:2] - - resize_w = w - resize_h = h - - # Fix the longer side - if resize_h > resize_w: - ratio = float(self.resize_long) / resize_h - else: - ratio = float(self.resize_long) / resize_w - - resize_h = int(resize_h * ratio) - resize_w = int(resize_w * ratio) - - max_stride = 128 - resize_h = (resize_h + max_stride - 1) // max_stride * max_stride - resize_w = (resize_w + max_stride - 1) // max_stride * max_stride - img = cv2.resize(img, (int(resize_w), int(resize_h))) - ratio_h = resize_h / float(h) - ratio_w = resize_w / float(w) - - return img, [ratio_h, ratio_w] - - -def transform(data, ops=None): - """transform""" - if ops is None: - ops = [] - - for op in ops: - data = op(data) - if data is None: - return None - return data - + except Exception as exc: + raise ResizeImgError from exc -def create_operators(op_param_dict): - """ - create operators based on the config - """ - ops = [] - for op_name, param in op_param_dict.items(): - if param is None: - param = {} - op = eval(op_name)(**param) - ops.append(op) - return ops + return img -def draw_text_det_res(dt_boxes, img_path): - src_im = cv2.imread(img_path) - for box in dt_boxes: - box = np.array(box).astype(np.int32).reshape(-1, 2) - cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) - return src_im +class ResizeImgError(Exception): + pass class DBPostProcess: @@ -218,12 +80,12 @@ class DBPostProcess: def __init__( self, - thresh=0.3, - box_thresh=0.7, - max_candidates=1000, - unclip_ratio=2.0, - score_mode="fast", - use_dilation=False, + thresh: float = 0.3, + box_thresh: float = 0.7, + max_candidates: int = 1000, + unclip_ratio: float = 2.0, + score_mode: str = "fast", + use_dilation: bool = False, ): self.thresh = thresh self.box_thresh = box_thresh @@ -232,18 +94,33 @@ def __init__( self.min_size = 3 self.score_mode = score_mode + self.dilation_kernel = None if use_dilation: self.dilation_kernel = np.array([[1, 1], [1, 1]]) - else: - self.dilation_kernel = None - def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + def __call__( + self, pred: np.ndarray, ori_shape: Tuple[int, int] + ) -> Tuple[np.ndarray, List[float]]: + src_h, src_w = ori_shape + pred = pred[:, 0, :, :] + segmentation = pred > self.thresh + + mask = segmentation[0] + if self.dilation_kernel is not None: + mask = cv2.dilate( + np.array(segmentation[0]).astype(np.uint8), self.dilation_kernel + ) + boxes, scores = self.boxes_from_bitmap(pred[0], mask, src_w, src_h) + return boxes, scores + + def boxes_from_bitmap( + self, pred: np.ndarray, bitmap: np.ndarray, dest_width: int, dest_height: int + ) -> Tuple[np.ndarray, List[float]]: """ - _bitmap: single map with shape (1, H, W), + bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1} """ - bitmap = _bitmap height, width = bitmap.shape outs = cv2.findContours( @@ -256,45 +133,35 @@ def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): num_contours = min(len(contours), self.max_candidates) - boxes = [] - scores = [] + boxes, scores = [], [] for index in range(num_contours): contour = contours[index] points, sside = self.get_mini_boxes(contour) if sside < self.min_size: continue - points = np.array(points) + if self.score_mode == "fast": score = self.box_score_fast(pred, points.reshape(-1, 2)) else: score = self.box_score_slow(pred, contour) + if self.box_thresh > score: continue - box = self.unclip(points).reshape(-1, 1, 2) + box = self.unclip(points) box, sside = self.get_mini_boxes(box) if sside < self.min_size + 2: continue - box = np.array(box) box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) box[:, 1] = np.clip( np.round(box[:, 1] / height * dest_height), 0, dest_height ) - boxes.append(box.astype(np.int16)) + boxes.append(box.astype(np.int32)) scores.append(score) - return np.array(boxes, dtype=np.int16), scores + return np.array(boxes, dtype=np.int32), scores - def unclip(self, box): - unclip_ratio = self.unclip_ratio - poly = Polygon(box) - distance = poly.area * unclip_ratio / poly.length - offset = pyclipper.PyclipperOffset() - offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) - expanded = np.array(offset.Execute(distance)) - return expanded - - def get_mini_boxes(self, contour): + def get_mini_boxes(self, contour: np.ndarray) -> Tuple[np.ndarray, float]: bounding_box = cv2.minAreaRect(contour) points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) @@ -305,6 +172,7 @@ def get_mini_boxes(self, contour): else: index_1 = 1 index_4 = 0 + if points[3][1] > points[2][1]: index_2 = 2 index_3 = 3 @@ -312,10 +180,13 @@ def get_mini_boxes(self, contour): index_2 = 3 index_3 = 2 - box = [points[index_1], points[index_2], points[index_3], points[index_4]] + box = np.array( + [points[index_1], points[index_2], points[index_3], points[index_4]] + ) return box, min(bounding_box[1]) - def box_score_fast(self, bitmap, _box): + @staticmethod + def box_score_fast(bitmap: np.ndarray, _box: np.ndarray) -> float: h, w = bitmap.shape[:2] box = _box.copy() xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1) @@ -329,10 +200,8 @@ def box_score_fast(self, bitmap, _box): cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0] - def box_score_slow(self, bitmap, contour): - """ - box_score_slow: use polyon mean score as the mean score - """ + def box_score_slow(self, bitmap: np.ndarray, contour: np.ndarray) -> float: + """use polyon mean score as the mean score""" h, w = bitmap.shape[:2] contour = contour.copy() contour = np.reshape(contour, (-1, 2)) @@ -350,23 +219,11 @@ def box_score_slow(self, bitmap, contour): cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1) return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0] - def __call__(self, pred, shape_list): - pred = pred[:, 0, :, :] - segmentation = pred > self.thresh - - boxes_batch = [] - for batch_index in range(pred.shape[0]): - src_h, src_w, ratio_h, ratio_w = shape_list[batch_index] - if self.dilation_kernel is not None: - mask = cv2.dilate( - np.array(segmentation[batch_index]).astype(np.uint8), - self.dilation_kernel, - ) - else: - mask = segmentation[batch_index] - boxes, scores = self.boxes_from_bitmap( - pred[batch_index], mask, src_w, src_h - ) - - boxes_batch.append({"points": boxes}) - return boxes_batch + def unclip(self, box: np.ndarray) -> np.ndarray: + unclip_ratio = self.unclip_ratio + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)).reshape((-1, 1, 2)) + return expanded diff --git a/python/rapidocr_paddle/ch_ppocr_v3_rec/text_recognize.py b/python/rapidocr_paddle/ch_ppocr_v3_rec/text_recognize.py index 853f9e295..45249883a 100644 --- a/python/rapidocr_paddle/ch_ppocr_v3_rec/text_recognize.py +++ b/python/rapidocr_paddle/ch_ppocr_v3_rec/text_recognize.py @@ -15,7 +15,7 @@ import math import time from pathlib import Path -from typing import List +from typing import List, Tuple, Union import cv2 import numpy as np @@ -27,16 +27,18 @@ class TextRecognizer: def __init__(self, config): - self.session = PaddleInferSession(config, mode="rec") + self.rec_image_shape = config["rec_img_shape"] + self.rec_batch_num = config["rec_batch_num"] dict_path = str(Path(__file__).parent / "ppocr_keys_v1.txt") self.character_dict_path = config.get("rec_keys_path", dict_path) - self.postprocess_op = CTCLabelDecode(self.character_dict_path) + self.postprocess_op = CTCLabelDecode(character_path=self.character_dict_path) - self.rec_batch_num = config["rec_batch_num"] - self.rec_image_shape = config["rec_img_shape"] + self.infer = PaddleInferSession(config, mode="rec") - def __call__(self, img_list: List[np.ndarray]): + def __call__( + self, img_list: Union[np.ndarray, List[np.ndarray]] + ) -> Tuple[List[Tuple[str, float]], float]: if isinstance(img_list, np.ndarray): img_list = [img_list] @@ -47,7 +49,7 @@ def __call__(self, img_list: List[np.ndarray]): indices = np.argsort(np.array(width_list)) img_num = len(img_list) - rec_res = [["", 0.0]] * img_num + rec_res = [("", 0.0)] * img_num batch_num = self.rec_batch_num elapse = 0 @@ -66,7 +68,7 @@ def __call__(self, img_list: List[np.ndarray]): norm_img_batch = np.concatenate(norm_img_batch).astype(np.float32) starttime = time.time() - preds = self.session(norm_img_batch)[0] + preds = self.infer(norm_img_batch)[0] rec_result = self.postprocess_op(preds) for rno, one_res in enumerate(rec_result): @@ -74,7 +76,7 @@ def __call__(self, img_list: List[np.ndarray]): elapse += time.time() - starttime return rec_res, elapse - def resize_norm_img(self, img, max_wh_ratio): + def resize_norm_img(self, img: np.ndarray, max_wh_ratio: float) -> np.ndarray: img_channel, img_height, img_width = self.rec_image_shape assert img_channel == img.shape[2] diff --git a/python/rapidocr_paddle/ch_ppocr_v3_rec/utils.py b/python/rapidocr_paddle/ch_ppocr_v3_rec/utils.py index b590f5d70..7d9be4836 100644 --- a/python/rapidocr_paddle/ch_ppocr_v3_rec/utils.py +++ b/python/rapidocr_paddle/ch_ppocr_v3_rec/utils.py @@ -1,75 +1,101 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com +from pathlib import Path +from typing import List, Optional, Tuple, Union + import numpy as np class CTCLabelDecode: - """Convert between text-label and text-index""" + def __init__( + self, + character: Optional[List[str]] = None, + character_path: Union[str, Path, None] = None, + ): + self.character = self.get_character(character, character_path) + self.dict = {char: i for i, char in enumerate(self.character)} - def __init__(self, character_dict_path): - super(CTCLabelDecode, self).__init__() + def __call__(self, preds: np.ndarray) -> List[Tuple[str, float]]: + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) + return text - self.character_str = [] - assert character_dict_path is not None, "character_dict_path should not be None" + def get_character( + self, + character: Optional[List[str]] = None, + character_path: Union[str, Path, None] = None, + ) -> List[str]: + if character is None and character_path is None: + raise ValueError("character must not be None") - if isinstance(character_dict_path, str): - with open(character_dict_path, "rb") as fin: - lines = fin.readlines() - for line in lines: - line = line.decode("utf-8").strip("\n").strip("\r\n") - self.character_str.append(line) - else: - self.character_str = character_dict_path - self.character_str.append(" ") + character_list = None + if character: + character_list = character - dict_character = self.add_special_char(self.character_str) - self.character = dict_character + if character_path: + character_list = self.read_character_file(character_path) - self.dict = {} - for i, char in enumerate(dict_character): - self.dict[char] = i + if character_list is None: + raise ValueError("character must not be None") - def __call__(self, preds, label=None): - preds_idx = preds.argmax(axis=2) - preds_prob = preds.max(axis=2) - text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) - if label is None: - return text - label = self.decode(label) - return text, label + character_list = self.insert_special_char( + character_list, " ", len(character_list) + ) + character_list = self.insert_special_char(character_list, "blank", 0) + return character_list - def add_special_char(self, dict_character): - dict_character = ["blank"] + dict_character - return dict_character + @staticmethod + def read_character_file(character_path: Union[str, Path]) -> List[str]: + character_list = [] + with open(character_path, "rb") as f: + lines = f.readlines() + for line in lines: + line = line.decode("utf-8").strip("\n").strip("\r\n") + character_list.append(line) + return character_list - def get_ignored_tokens(self): - return [0] # for ctc blank + @staticmethod + def insert_special_char( + character_list: List[str], special_char: str, loc: int = -1 + ) -> List[str]: + character_list.insert(loc, special_char) + return character_list - def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + def decode( + self, + text_index: np.ndarray, + text_prob: Optional[np.ndarray] = None, + is_remove_duplicate: bool = False, + ) -> List[Tuple[str, float]]: """convert text-index into text-label.""" - result_list = [] ignored_tokens = self.get_ignored_tokens() batch_size = len(text_index) for batch_idx in range(batch_size): - char_list = [] - conf_list = [] - for idx in range(len(text_index[batch_idx])): - if text_index[batch_idx][idx] in ignored_tokens: + char_list, conf_list = [], [] + cur_pred_ids = text_index[batch_idx] + for idx, cur_idx in enumerate(cur_pred_ids): + if cur_idx in ignored_tokens: continue + if is_remove_duplicate: # only for predict - if ( - idx > 0 - and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx] - ): + if idx > 0 and cur_pred_ids[idx - 1] == cur_idx: continue - char_list.append(self.character[int(text_index[batch_idx][idx])]) - if text_prob is not None: - conf_list.append(text_prob[batch_idx][idx]) - else: + + char_list.append(self.character[int(cur_idx)]) + + if text_prob is None: conf_list.append(1) + else: + conf_list.append(text_prob[batch_idx][idx]) + text = "".join(char_list) result_list.append((text, np.mean(conf_list if any(conf_list) else [0]))) return result_list + + @staticmethod + def get_ignored_tokens() -> List[int]: + return [0] # for ctc blank diff --git a/python/rapidocr_paddle/main.py b/python/rapidocr_paddle/main.py index 410639fc4..262e299c2 100644 --- a/python/rapidocr_paddle/main.py +++ b/python/rapidocr_paddle/main.py @@ -3,7 +3,7 @@ # @Contact: liekkaskono@163.com import copy from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import cv2 import numpy as np @@ -15,6 +15,7 @@ LoadImage, UpdateParameters, VisRes, + get_logger, init_args, read_yaml, update_model_path, @@ -22,6 +23,7 @@ root_dir = Path(__file__).resolve().parent DEFAULT_CFG_PATH = root_dir / "config.yaml" +logger = get_logger("RapidOCR") class RapidOCR: @@ -60,7 +62,7 @@ def __call__( use_cls: Optional[bool] = None, use_rec: Optional[bool] = None, **kwargs, - ): + ) -> Tuple[Optional[List[List[Union[Any, str]]]], Optional[List[float]]]: use_det = self.use_det if use_det is None else use_det use_cls = self.use_cls if use_cls is None else use_cls use_rec = self.use_rec if use_rec is None else use_rec @@ -120,9 +122,8 @@ def maybe_add_letterbox(self, img: np.ndarray) -> Tuple[np.ndarray, int]: return img, 0 def auto_text_det( - self, - img: np.ndarray, - ) -> Tuple[Optional[np.ndarray], float, Optional[List[np.ndarray]]]: + self, img: np.ndarray + ) -> Tuple[Optional[List[np.ndarray]], float]: dt_boxes, det_elapse = self.text_det(img) if dt_boxes is None or len(dt_boxes) < 1: return None, 0.0 @@ -130,8 +131,10 @@ def auto_text_det( dt_boxes = self.sorted_boxes(dt_boxes) return dt_boxes, det_elapse - def get_crop_img_list(self, img, dt_boxes): - def get_rotate_crop_image(img, points): + def get_crop_img_list( + self, img: np.ndarray, dt_boxes: List[np.ndarray] + ) -> List[np.ndarray]: + def get_rotate_crop_image(img: np.ndarray, points: np.ndarray) -> np.ndarray: img_crop_width = int( max( np.linalg.norm(points[0] - points[1]), @@ -144,14 +147,14 @@ def get_rotate_crop_image(img, points): np.linalg.norm(points[1] - points[2]), ) ) - pts_std = np.float32( + pts_std = np.array( [ [0, 0], [img_crop_width, 0], [img_crop_width, img_crop_height], [0, img_crop_height], ] - ) + ).astype(np.float32) M = cv2.getPerspectiveTransform(points, pts_std) dst_img = cv2.warpPerspective( img, @@ -173,7 +176,7 @@ def get_rotate_crop_image(img, points): return img_crop_list @staticmethod - def sorted_boxes(dt_boxes): + def sorted_boxes(dt_boxes: np.ndarray) -> List[np.ndarray]: """ Sort text boxes in order from top to bottom, left to right args: @@ -199,8 +202,14 @@ def sorted_boxes(dt_boxes): return _boxes def get_final_res( - self, dt_boxes, cls_res, rec_res, det_elapse, cls_elapse, rec_elapse - ): + self, + dt_boxes: Optional[List[np.ndarray]], + cls_res: Optional[List[List[Union[str, float]]]], + rec_res: Optional[List[Tuple[str, float]]], + det_elapse: float, + cls_elapse: float, + rec_elapse: float, + ) -> Tuple[Optional[List[List[Union[Any, str]]]], Optional[List[float]]]: if dt_boxes is None and rec_res is None and cls_res is not None: return cls_res, [cls_elapse] @@ -214,7 +223,7 @@ def get_final_res( return [box.tolist() for box in dt_boxes], [det_elapse] dt_boxes, rec_res = self.filter_result(dt_boxes, rec_res) - if len(dt_boxes) <= 0: + if not dt_boxes or not rec_res or len(dt_boxes) <= 0: return None, None ocr_res = [ @@ -222,7 +231,14 @@ def get_final_res( ], [det_elapse, cls_elapse, rec_elapse] return ocr_res - def filter_result(self, dt_boxes, rec_res): + def filter_result( + self, + dt_boxes: Optional[List[np.ndarray]], + rec_res: Optional[List[Tuple[str, float]]], + ) -> Tuple[Optional[List[np.ndarray]], Optional[List[Tuple[str, float]]]]: + if dt_boxes is None or rec_res is None: + return None, None + filter_boxes, filter_rec_res = [], [] for box, rec_reuslt in zip(dt_boxes, rec_res): text, score = rec_reuslt @@ -246,10 +262,10 @@ def main(): use_cls=use_cls, use_rec=use_rec, ) - print(result) + logger.info(result) if args.print_cost: - print(elapse_list) + logger.info(elapse_list) if args.vis_res: vis = VisRes() @@ -260,7 +276,7 @@ def main(): boxes, *_ = list(zip(*result)) vis_img = vis(args.img_path, boxes) cv2.imwrite(str(save_path), vis_img) - print(f"The vis result has saved in {save_path}") + logger.info("The vis result has saved in %s", save_path) elif use_det and use_rec: font_path = Path(args.vis_font_path) @@ -270,7 +286,7 @@ def main(): boxes, txts, scores = list(zip(*result)) vis_img = vis(args.img_path, boxes, txts, scores, font_path=font_path) cv2.imwrite(str(save_path), vis_img) - print(f"The vis result has saved in {save_path}") + logger.info("The vis result has saved in %s", save_path) if __name__ == "__main__": diff --git a/python/rapidocr_paddle/utils/load_image.py b/python/rapidocr_paddle/utils/load_image.py index 056605303..f34b549fe 100644 --- a/python/rapidocr_paddle/utils/load_image.py +++ b/python/rapidocr_paddle/utils/load_image.py @@ -3,7 +3,7 @@ # @Contact: liekkaskono@163.com from io import BytesIO from pathlib import Path -from typing import Union +from typing import Any, Union import cv2 import numpy as np @@ -55,7 +55,7 @@ def img_to_ndarray(self, img: Image.Image) -> np.ndarray: return np.array(img) return np.array(img) - def convert_img(self, img: np.ndarray, origin_img_type): + def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray: if img.ndim == 2: return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) diff --git a/python/rapidocr_paddle/utils/parse_parameters.py b/python/rapidocr_paddle/utils/parse_parameters.py index dcf457960..e6aa36ab6 100644 --- a/python/rapidocr_paddle/utils/parse_parameters.py +++ b/python/rapidocr_paddle/utils/parse_parameters.py @@ -3,7 +3,7 @@ # @Contact: liekkaskono@163.com import argparse from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import numpy as np from PIL import Image @@ -12,7 +12,7 @@ InputType = Union[str, np.ndarray, bytes, Path, Image.Image] -def update_model_path(config): +def update_model_path(config: Dict[str, Any]) -> Dict[str, Any]: key = "model_path" config["Det"][key] = str(root_dir / config["Det"][key]) config["Rec"][key] = str(root_dir / config["Rec"][key]) @@ -121,7 +121,12 @@ def __call__(self, config, **kwargs): global_dict, det_dict, cls_dict, rec_dict = self.parse_kwargs(**kwargs) new_config = { "Global": self.update_global_params(config["Global"], global_dict), - "Det": self.update_params(config["Det"], det_dict, "det_", None), + "Det": self.update_params( + config["Det"], + det_dict, + "det_", + ["det_model_path", "det_use_cuda"], + ), "Cls": self.update_params( config["Cls"], cls_dict, diff --git a/python/rapidocr_paddle/utils/vis_res.py b/python/rapidocr_paddle/utils/vis_res.py index 405beabad..bd18031f1 100644 --- a/python/rapidocr_paddle/utils/vis_res.py +++ b/python/rapidocr_paddle/utils/vis_res.py @@ -29,7 +29,7 @@ def __call__( scores: Optional[Tuple[float]] = None, font_path: Optional[str] = None, ) -> np.ndarray: - if txts is None and scores is None: + if txts is None: return self.draw_dt_boxes(img_content, dt_boxes) return self.draw_ocr_box_txt(img_content, dt_boxes, txts, scores, font_path) @@ -52,7 +52,7 @@ def draw_ocr_box_txt( self, img_content: InputType, dt_boxes: np.ndarray, - txts: Optional[Union[List[str], Tuple[str]]] = None, + txts: Union[List[str], Tuple[str]], scores: Optional[Tuple[float]] = None, font_path: Optional[str] = None, ) -> np.ndarray: