From d60fe5b41525670eb8aba58ce4ab689c6367780e Mon Sep 17 00:00:00 2001 From: SWHL Date: Fri, 17 May 2024 16:34:16 +0800 Subject: [PATCH 1/8] chore(rapidocr_onnxruntime): Optim tips about OrtInferSession --- python/rapidocr_onnxruntime/utils.py | 119 ++++++++++++++++++++++----- 1 file changed, 99 insertions(+), 20 deletions(-) diff --git a/python/rapidocr_onnxruntime/utils.py b/python/rapidocr_onnxruntime/utils.py index aad76b693..5ca74a6b8 100644 --- a/python/rapidocr_onnxruntime/utils.py +++ b/python/rapidocr_onnxruntime/utils.py @@ -2,6 +2,7 @@ # @Author: SWHL # @Contact: liekkaskono@163.com import argparse +import logging import math import os import platform @@ -34,6 +35,8 @@ class OrtInferSession: def __init__(self, config): + self.logger = get_logger("OrtInferSession") + sess_opt = SessionOptions() sess_opt.log_severity_level = 4 sess_opt.enable_cpu_mem_arena = False @@ -69,25 +72,17 @@ def _get_ep_list(self) -> List[Tuple[str, str]]: } EP_list = [(CPU_EP, cpu_provider_opts)] - self.use_cuda = ( - self.cfg_use_cuda and get_device() == "GPU" and CUDA_EP in had_providers - ) cuda_provider_opts = { "device_id": 0, "arena_extend_strategy": "kNextPowerOfTwo", "cudnn_conv_algo_search": "EXHAUSTIVE", "do_copy_in_default_stream": True, } + self.use_cuda = self._check_cuda_condition(had_providers) if self.use_cuda: EP_list.insert(0, (CUDA_EP, cuda_provider_opts)) - # check windows 10 or above - self.use_directml = ( - self.cfg_use_dml - and platform.system() == "Windows" - and int(platform.release().split(".")[0]) >= 10 - and DIRECTML_EP in had_providers - ) + self.use_directml = self._check_dml_condition(had_providers) if self.use_directml: print( "Windows 10 or above detected, try to use DirectML as primary provider" @@ -98,21 +93,90 @@ def _get_ep_list(self) -> List[Tuple[str, str]]: EP_list.insert(0, (DIRECTML_EP, directml_options)) return EP_list + def _check_cuda_condition(self, had_providers: List[str]) -> bool: + if not self.cfg_use_cuda: + return False + + cur_device = get_device() + if cur_device != "GPU" or CUDA_EP not in had_providers: + self.logger.warning( + "%s is not in available providers (%s)", CUDA_EP, had_providers + ) + self.logger.info("If you want to use GPU acceleration, you must do:") + self.logger.info( + "First, uninstall all onnxruntime pakcages in current environment." + ) + self.logger.info( + "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`." + ) + self.logger.info( + "\tNote the onnxruntime-gpu version must match your cuda and cudnn version." + ) + self.logger.info( + "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html" + ) + self.logger.info( + "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']", + CUDA_EP, + ) + return False + + return True + + def _check_dml_condition(self, had_providers: List[str]) -> bool: + if not self.cfg_use_dml: + return False + + cur_os = platform.system() + if cur_os != "Windows": + self.logger.warning( + "DirectML is only supported in Windows OS. The current OS is %s", cur_os + ) + return False + + cur_window_version = int(platform.release().split(".")[0]) + if cur_window_version < 10: + self.logger.warning( + "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s", + cur_window_version, + ) + return False + + if DIRECTML_EP not in had_providers: + self.logger.warning( + "%s is not in available providers (%s). ", DIRECTML_EP, had_providers + ) + self.logger.info("If you want to use DirectML acceleration, you must do:") + self.logger.info( + "First, uninstall all onnxruntime pakcages in current environment." + ) + self.logger.info( + "Second, install onnxruntime-directml by `pip install onnxruntime-directml`" + ) + self.logger.info( + "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']", + DIRECTML_EP, + ) + return False + + return True + def _verify_providers(self) -> None: session_providers = self.session.get_providers() + first_provider = session_providers[0] - if self.use_cuda and session_providers[0] != CUDA_EP: - warnings.warn( - f"{CUDA_EP} is not avaiable for current env, the inference part is automatically shifted to be executed under {CPU_EP}.\n" - "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, " - "you can check their relations from the offical web site: " - "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html", - RuntimeWarning, + if self.use_cuda and first_provider != CUDA_EP: + self.logger.warning( + "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.", + CUDA_EP, + first_provider, ) - if self.use_directml and session_providers[0] != DIRECTML_EP: - warnings.warn( - "DirectML is not available for the current environment, the inference part is automatically shifted to be executed under other EP.\n" + if self.use_directml and first_provider != DIRECTML_EP: + self.logger.warning( + "%s is not available for the current environment, the inference part is automatically shifted to be executed under %s.", + DIRECTML_EP, + first_provider, ) def __call__(self, input_content: np.ndarray) -> np.ndarray: @@ -607,3 +671,18 @@ def get_char_size(font, char_str: str) -> float: raise ValueError( "The Pillow ImageFont instance has not getsize or getlength func." ) + + +def get_logger(name: str): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + + fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s" + format_str = logging.Formatter(fmt) + + sh = logging.StreamHandler() + sh.setLevel(logging.DEBUG) + + logger.addHandler(sh) + sh.setFormatter(format_str) + return logger From de44841377de5f6a05a4e3d99f5cee78dc930c82 Mon Sep 17 00:00:00 2001 From: SWHL Date: Sat, 18 May 2024 15:12:10 +0800 Subject: [PATCH 2/8] refactor(rapidocr_onnxruntime): Decoupling utils.py files --- python/rapidocr_onnxruntime/main.py | 10 +- python/rapidocr_onnxruntime/utils.py | 688 ------------------ python/rapidocr_onnxruntime/utils/__init__.py | 19 + .../utils/infer_engine.py | 220 ++++++ .../rapidocr_onnxruntime/utils/load_image.py | 123 ++++ python/rapidocr_onnxruntime/utils/logger.py | 19 + .../utils/parse_parameters.py | 215 ++++++ python/rapidocr_onnxruntime/utils/vis_res.py | 143 ++++ 8 files changed, 745 insertions(+), 692 deletions(-) delete mode 100644 python/rapidocr_onnxruntime/utils.py create mode 100644 python/rapidocr_onnxruntime/utils/__init__.py create mode 100644 python/rapidocr_onnxruntime/utils/infer_engine.py create mode 100644 python/rapidocr_onnxruntime/utils/load_image.py create mode 100644 python/rapidocr_onnxruntime/utils/logger.py create mode 100644 python/rapidocr_onnxruntime/utils/parse_parameters.py create mode 100644 python/rapidocr_onnxruntime/utils/vis_res.py diff --git a/python/rapidocr_onnxruntime/main.py b/python/rapidocr_onnxruntime/main.py index 410639fc4..ecd9344a2 100644 --- a/python/rapidocr_onnxruntime/main.py +++ b/python/rapidocr_onnxruntime/main.py @@ -15,6 +15,7 @@ LoadImage, UpdateParameters, VisRes, + get_logger, init_args, read_yaml, update_model_path, @@ -22,6 +23,7 @@ root_dir = Path(__file__).resolve().parent DEFAULT_CFG_PATH = root_dir / "config.yaml" +logger = get_logger("RapidOCR") class RapidOCR: @@ -246,10 +248,10 @@ def main(): use_cls=use_cls, use_rec=use_rec, ) - print(result) + logger.info(result) if args.print_cost: - print(elapse_list) + logger.info(elapse_list) if args.vis_res: vis = VisRes() @@ -260,7 +262,7 @@ def main(): boxes, *_ = list(zip(*result)) vis_img = vis(args.img_path, boxes) cv2.imwrite(str(save_path), vis_img) - print(f"The vis result has saved in {save_path}") + logger.info("The vis result has saved in %s", save_path) elif use_det and use_rec: font_path = Path(args.vis_font_path) @@ -270,7 +272,7 @@ def main(): boxes, txts, scores = list(zip(*result)) vis_img = vis(args.img_path, boxes, txts, scores, font_path=font_path) cv2.imwrite(str(save_path), vis_img) - print(f"The vis result has saved in {save_path}") + logger.info("The vis result has saved in %s", save_path) if __name__ == "__main__": diff --git a/python/rapidocr_onnxruntime/utils.py b/python/rapidocr_onnxruntime/utils.py deleted file mode 100644 index 5ca74a6b8..000000000 --- a/python/rapidocr_onnxruntime/utils.py +++ /dev/null @@ -1,688 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: SWHL -# @Contact: liekkaskono@163.com -import argparse -import logging -import math -import os -import platform -import random -import traceback -import warnings -from io import BytesIO -from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union - -import cv2 -import numpy as np -import yaml -from onnxruntime import ( - GraphOptimizationLevel, - InferenceSession, - SessionOptions, - get_available_providers, - get_device, -) -from PIL import Image, ImageDraw, ImageFont, UnidentifiedImageError - -root_dir = Path(__file__).resolve().parent -InputType = Union[str, np.ndarray, bytes, Path, Image.Image] - -CPU_EP = "CPUExecutionProvider" -CUDA_EP = "CUDAExecutionProvider" -DIRECTML_EP = "DmlExecutionProvider" - - -class OrtInferSession: - def __init__(self, config): - self.logger = get_logger("OrtInferSession") - - sess_opt = SessionOptions() - sess_opt.log_severity_level = 4 - sess_opt.enable_cpu_mem_arena = False - sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL - - cpu_nums = os.cpu_count() - intra_op_num_threads = config.get("intra_op_num_threads", -1) - if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums: - sess_opt.intra_op_num_threads = intra_op_num_threads - - inter_op_num_threads = config.get("inter_op_num_threads", -1) - if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums: - sess_opt.inter_op_num_threads = inter_op_num_threads - - model_path = config.get("model_path", None) - self._verify_model(model_path) - - self.cfg_use_cuda = config.get("use_cuda", None) - self.cfg_use_dml = config.get("use_dml", None) - EP_list = self._get_ep_list() - - self.session = InferenceSession( - model_path, sess_options=sess_opt, providers=EP_list - ) - - self._verify_providers() - - def _get_ep_list(self) -> List[Tuple[str, str]]: - had_providers: List[str] = get_available_providers() - - cpu_provider_opts = { - "arena_extend_strategy": "kSameAsRequested", - } - EP_list = [(CPU_EP, cpu_provider_opts)] - - cuda_provider_opts = { - "device_id": 0, - "arena_extend_strategy": "kNextPowerOfTwo", - "cudnn_conv_algo_search": "EXHAUSTIVE", - "do_copy_in_default_stream": True, - } - self.use_cuda = self._check_cuda_condition(had_providers) - if self.use_cuda: - EP_list.insert(0, (CUDA_EP, cuda_provider_opts)) - - self.use_directml = self._check_dml_condition(had_providers) - if self.use_directml: - print( - "Windows 10 or above detected, try to use DirectML as primary provider" - ) - directml_options = ( - cuda_provider_opts if self.use_cuda else cpu_provider_opts - ) - EP_list.insert(0, (DIRECTML_EP, directml_options)) - return EP_list - - def _check_cuda_condition(self, had_providers: List[str]) -> bool: - if not self.cfg_use_cuda: - return False - - cur_device = get_device() - if cur_device != "GPU" or CUDA_EP not in had_providers: - self.logger.warning( - "%s is not in available providers (%s)", CUDA_EP, had_providers - ) - self.logger.info("If you want to use GPU acceleration, you must do:") - self.logger.info( - "First, uninstall all onnxruntime pakcages in current environment." - ) - self.logger.info( - "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`." - ) - self.logger.info( - "\tNote the onnxruntime-gpu version must match your cuda and cudnn version." - ) - self.logger.info( - "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html" - ) - self.logger.info( - "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']", - CUDA_EP, - ) - return False - - return True - - def _check_dml_condition(self, had_providers: List[str]) -> bool: - if not self.cfg_use_dml: - return False - - cur_os = platform.system() - if cur_os != "Windows": - self.logger.warning( - "DirectML is only supported in Windows OS. The current OS is %s", cur_os - ) - return False - - cur_window_version = int(platform.release().split(".")[0]) - if cur_window_version < 10: - self.logger.warning( - "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s", - cur_window_version, - ) - return False - - if DIRECTML_EP not in had_providers: - self.logger.warning( - "%s is not in available providers (%s). ", DIRECTML_EP, had_providers - ) - self.logger.info("If you want to use DirectML acceleration, you must do:") - self.logger.info( - "First, uninstall all onnxruntime pakcages in current environment." - ) - self.logger.info( - "Second, install onnxruntime-directml by `pip install onnxruntime-directml`" - ) - self.logger.info( - "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']", - DIRECTML_EP, - ) - return False - - return True - - def _verify_providers(self) -> None: - session_providers = self.session.get_providers() - first_provider = session_providers[0] - - if self.use_cuda and first_provider != CUDA_EP: - self.logger.warning( - "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.", - CUDA_EP, - first_provider, - ) - - if self.use_directml and first_provider != DIRECTML_EP: - self.logger.warning( - "%s is not available for the current environment, the inference part is automatically shifted to be executed under %s.", - DIRECTML_EP, - first_provider, - ) - - def __call__(self, input_content: np.ndarray) -> np.ndarray: - input_dict = dict(zip(self.get_input_names(), [input_content])) - try: - return self.session.run(self.get_output_names(), input_dict) - except Exception as e: - error_info = traceback.format_exc() - raise ONNXRuntimeError(error_info) from e - - def get_input_names( - self, - ): - return [v.name for v in self.session.get_inputs()] - - def get_output_names( - self, - ): - return [v.name for v in self.session.get_outputs()] - - def get_character_list(self, key: str = "character"): - meta_dict = self.session.get_modelmeta().custom_metadata_map - return meta_dict[key].splitlines() - - def have_key(self, key: str = "character") -> bool: - meta_dict = self.session.get_modelmeta().custom_metadata_map - if key in meta_dict.keys(): - return True - return False - - @staticmethod - def _verify_model(model_path: Union[str, Path, None]): - if model_path is None: - raise ValueError("model_path is None!") - - model_path = Path(model_path) - - if not model_path.exists(): - raise FileNotFoundError(f"{model_path} does not exists.") - - if not model_path.is_file(): - raise FileExistsError(f"{model_path} is not a file.") - - -class ONNXRuntimeError(Exception): - pass - - -class LoadImage: - def __init__( - self, - ): - pass - - def __call__(self, img: InputType) -> np.ndarray: - if not isinstance(img, InputType.__args__): - raise LoadImageError( - f"The img type {type(img)} does not in {InputType.__args__}" - ) - - origin_img_type = type(img) - img = self.load_img(img) - img = self.convert_img(img, origin_img_type) - return img - - def load_img(self, img: InputType) -> np.ndarray: - if isinstance(img, (str, Path)): - self.verify_exist(img) - try: - img = self.img_to_ndarray(Image.open(img)) - except UnidentifiedImageError as e: - raise LoadImageError(f"cannot identify image file {img}") from e - return img - - if isinstance(img, bytes): - img = self.img_to_ndarray(Image.open(BytesIO(img))) - return img - - if isinstance(img, np.ndarray): - return img - - if isinstance(img, Image.Image): - return self.img_to_ndarray(img) - - raise LoadImageError(f"{type(img)} is not supported!") - - def img_to_ndarray(self, img: Image.Image) -> np.ndarray: - if img.mode == "1": - img = img.convert("L") - return np.array(img) - return np.array(img) - - def convert_img(self, img: np.ndarray, origin_img_type): - if img.ndim == 2: - return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - if img.ndim == 3: - channel = img.shape[2] - if channel == 1: - return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - if channel == 2: - return self.cvt_two_to_three(img) - - if channel == 3: - if issubclass(origin_img_type, (str, Path, bytes, Image.Image)): - return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - return img - - if channel == 4: - return self.cvt_four_to_three(img) - - raise LoadImageError( - f"The channel({channel}) of the img is not in [1, 2, 3, 4]" - ) - - raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]") - - @staticmethod - def cvt_two_to_three(img: np.ndarray) -> np.ndarray: - """gray + alpha → BGR""" - img_gray = img[..., 0] - img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR) - - img_alpha = img[..., 1] - not_a = cv2.bitwise_not(img_alpha) - not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) - - new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha) - new_img = cv2.add(new_img, not_a) - return new_img - - @staticmethod - def cvt_four_to_three(img: np.ndarray) -> np.ndarray: - """RGBA → BGR""" - r, g, b, a = cv2.split(img) - new_img = cv2.merge((b, g, r)) - - not_a = cv2.bitwise_not(a) - not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) - - new_img = cv2.bitwise_and(new_img, new_img, mask=a) - - mean_color = np.mean(new_img) - if mean_color <= 0.0: - new_img = cv2.add(new_img, not_a) - else: - new_img = cv2.bitwise_not(new_img) - return new_img - - @staticmethod - def verify_exist(file_path: Union[str, Path]): - if not Path(file_path).exists(): - raise LoadImageError(f"{file_path} does not exist.") - - -class LoadImageError(Exception): - pass - - -def read_yaml(yaml_path): - with open(yaml_path, "rb") as f: - data = yaml.load(f, Loader=yaml.Loader) - return data - - -def update_model_path(config): - key = "model_path" - config["Det"][key] = str(root_dir / config["Det"][key]) - config["Rec"][key] = str(root_dir / config["Rec"][key]) - config["Cls"][key] = str(root_dir / config["Cls"][key]) - return config - - -def init_args(): - parser = argparse.ArgumentParser() - parser.add_argument("-img", "--img_path", type=str, default=None, required=True) - parser.add_argument("-p", "--print_cost", action="store_true", default=False) - - global_group = parser.add_argument_group(title="Global") - global_group.add_argument("--text_score", type=float, default=0.5) - - global_group.add_argument("--no_det", action="store_true", default=False) - global_group.add_argument("--no_cls", action="store_true", default=False) - global_group.add_argument("--no_rec", action="store_true", default=False) - - global_group.add_argument("--print_verbose", action="store_true", default=False) - global_group.add_argument("--min_height", type=int, default=30) - global_group.add_argument("--width_height_ratio", type=int, default=8) - - global_group.add_argument("--intra_op_num_threads", type=int, default=-1) - global_group.add_argument("--inter_op_num_threads", type=int, default=-1) - - det_group = parser.add_argument_group(title="Det") - det_group.add_argument( - "--det_use_cuda", - action="store_true", - default=False, - help="Whether to use cuda. The prerequisite is: first uninstall all onnxruntime packages and install only the onnxruntime-gpu library.", - ) - det_group.add_argument( - "--det_use_dml", - action="store_true", - default=False, - help="Whether to use DirectML. The prerequisite is: the operating system is Windows 10+. First uninstall all onnxruntime packages and install only the onnxruntime-directml library.", - ) - det_group.add_argument("--det_model_path", type=str, default=None) - det_group.add_argument("--det_limit_side_len", type=float, default=736) - det_group.add_argument( - "--det_limit_type", type=str, default="min", choices=["max", "min"] - ) - det_group.add_argument("--det_thresh", type=float, default=0.3) - det_group.add_argument("--det_box_thresh", type=float, default=0.5) - det_group.add_argument("--det_unclip_ratio", type=float, default=1.6) - det_group.add_argument( - "--det_donot_use_dilation", action="store_true", default=False - ) - det_group.add_argument( - "--det_score_mode", type=str, default="fast", choices=["slow", "fast"] - ) - - cls_group = parser.add_argument_group(title="Cls") - cls_group.add_argument( - "--cls_use_cuda", - action="store_true", - default=False, - help="Whether to use cuda. The prerequisite is: first uninstall all onnxruntime packages and install only the onnxruntime-gpu library.", - ) - cls_group.add_argument( - "--cls_use_dml", - action="store_true", - default=False, - help="Whether to use DirectML. The prerequisite is: the operating system is Windows 10+. First uninstall all onnxruntime packages and install only the onnxruntime-directml library.", - ) - cls_group.add_argument("--cls_model_path", type=str, default=None) - cls_group.add_argument("--cls_image_shape", type=list, default=[3, 48, 192]) - cls_group.add_argument("--cls_label_list", type=list, default=["0", "180"]) - cls_group.add_argument("--cls_batch_num", type=int, default=6) - cls_group.add_argument("--cls_thresh", type=float, default=0.9) - - rec_group = parser.add_argument_group(title="Rec") - rec_group.add_argument( - "--rec_use_cuda", - action="store_true", - default=False, - help="Whether to use cuda. The prerequisite is: first uninstall all onnxruntime packages and install only the onnxruntime-gpu library.", - ) - rec_group.add_argument( - "--rec_use_dml", - action="store_true", - default=False, - help="Whether to use DirectML. The prerequisite is: the operating system is Windows 10+. First uninstall all onnxruntime packages and install only the onnxruntime-directml library.", - ) - rec_group.add_argument("--rec_model_path", type=str, default=None) - rec_group.add_argument("--rec_keys_path", type=str, default=None) - rec_group.add_argument("--rec_img_shape", type=list, default=[3, 48, 320]) - rec_group.add_argument("--rec_batch_num", type=int, default=6) - - vis_group = parser.add_argument_group(title="Visual Result") - vis_group.add_argument("-vis", "--vis_res", action="store_true", default=False) - vis_group.add_argument( - "--vis_font_path", - type=str, - default=None, - help="When -vis is True, the font_path must have value.", - ) - vis_group.add_argument( - "--vis_save_path", - type=str, - default=".", - help="The directory of saving the vis image.", - ) - - args = parser.parse_args() - return args - - -class UpdateParameters: - def __init__(self) -> None: - pass - - def parse_kwargs(self, **kwargs): - global_dict, det_dict, cls_dict, rec_dict = {}, {}, {}, {} - for k, v in kwargs.items(): - if k.startswith("det"): - k = k.split("det_")[1] - if k == "donot_use_dilation": - k = "use_dilation" - v = not v - - det_dict[k] = v - elif k.startswith("cls"): - cls_dict[k] = v - elif k.startswith("rec"): - rec_dict[k] = v - else: - global_dict[k] = v - return global_dict, det_dict, cls_dict, rec_dict - - def __call__(self, config, **kwargs): - global_dict, det_dict, cls_dict, rec_dict = self.parse_kwargs(**kwargs) - new_config = { - "Global": self.update_global_params(config["Global"], global_dict), - "Det": self.update_params(config["Det"], det_dict, "det_", None), - "Cls": self.update_params( - config["Cls"], - cls_dict, - "cls_", - ["cls_label_list", "cls_model_path", "cls_use_cuda"], - ), - "Rec": self.update_params( - config["Rec"], rec_dict, "rec_", ["rec_model_path", "rec_use_cuda"] - ), - } - - update_params = ["intra_op_num_threads", "inter_op_num_threads"] - new_config = self.update_global_to_module( - config, update_params, src="Global", dsts=["Det", "Cls", "Rec"] - ) - return new_config - - def update_global_to_module( - self, config, params: List[str], src: str, dsts: List[str] - ): - for dst in dsts: - for param in params: - config[dst].update({param: config[src][param]}) - return config - - def update_global_params(self, config, global_dict): - if global_dict: - config.update(global_dict) - return config - - def update_params( - self, - config, - param_dict: Dict[str, str], - prefix: str, - need_remove_prefix: Optional[List[str]] = None, - ): - if not param_dict: - return config - - filter_dict = self.remove_prefix(param_dict, prefix, need_remove_prefix) - model_path = filter_dict.get("model_path", None) - if not model_path: - filter_dict["model_path"] = str(root_dir / config["model_path"]) - - config.update(filter_dict) - return config - - @staticmethod - def remove_prefix( - config: Dict[str, str], - prefix: str, - need_remove_prefix: Optional[List[str]] = None, - ) -> Dict[str, str]: - if not need_remove_prefix: - return config - - new_rec_dict = {} - for k, v in config.items(): - if k in need_remove_prefix: - k = k.split(prefix)[1] - new_rec_dict[k] = v - return new_rec_dict - - -class VisRes: - def __init__(self, text_score: float = 0.5): - self.text_score = text_score - self.load_img = LoadImage() - - def __call__( - self, - img_content: InputType, - dt_boxes: np.ndarray, - txts: Optional[Union[List[str], Tuple[str]]] = None, - scores: Optional[Tuple[float]] = None, - font_path: Optional[str] = None, - ) -> np.ndarray: - if txts is None and scores is None: - return self.draw_dt_boxes(img_content, dt_boxes) - return self.draw_ocr_box_txt(img_content, dt_boxes, txts, scores, font_path) - - def draw_dt_boxes(self, img_content: InputType, dt_boxes: np.ndarray) -> np.ndarray: - img = self.load_img(img_content) - - for idx, box in enumerate(dt_boxes): - color = self.get_random_color() - - points = np.array(box) - cv2.polylines(img, np.int32([points]), 1, color=color, thickness=1) - - start_point = round(points[0][0]), round(points[0][1]) - cv2.putText( - img, f"{idx}", start_point, cv2.FONT_HERSHEY_SIMPLEX, 1, color, 3 - ) - return img - - def draw_ocr_box_txt( - self, - img_content: InputType, - dt_boxes: np.ndarray, - txts: Optional[Union[List[str], Tuple[str]]] = None, - scores: Optional[Tuple[float]] = None, - font_path: Optional[str] = None, - ) -> np.ndarray: - font_path = self.get_font_path(font_path) - - image = Image.fromarray(self.load_img(img_content)) - h, w = image.height, image.width - if image.mode == "L": - image = image.convert("RGB") - - img_left = image.copy() - img_right = Image.new("RGB", (w, h), (255, 255, 255)) - - random.seed(0) - draw_left = ImageDraw.Draw(img_left) - draw_right = ImageDraw.Draw(img_right) - for idx, (box, txt) in enumerate(zip(dt_boxes, txts)): - if scores is not None and float(scores[idx]) < self.text_score: - continue - - color = self.get_random_color() - - box_list = np.array(box).reshape(8).tolist() - draw_left.polygon(box_list, fill=color) - draw_right.polygon(box_list, outline=color) - - box_height = self.get_box_height(box) - box_width = self.get_box_width(box) - if box_height > 2 * box_width: - font_size = max(int(box_width * 0.9), 10) - font = ImageFont.truetype(font_path, font_size, encoding="utf-8") - cur_y = box[0][1] - - for c in txt: - draw_right.text( - (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font - ) - cur_y += self.get_char_size(font, c) - else: - font_size = max(int(box_height * 0.8), 10) - font = ImageFont.truetype(font_path, font_size, encoding="utf-8") - draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) - - img_left = Image.blend(image, img_left, 0.5) - img_show = Image.new("RGB", (w * 2, h), (255, 255, 255)) - img_show.paste(img_left, (0, 0, w, h)) - img_show.paste(img_right, (w, 0, w * 2, h)) - return np.array(img_show) - - @staticmethod - def get_font_path(font_path: Optional[Union[str, Path]] = None) -> str: - if font_path is None or not Path(font_path).exists(): - raise FileNotFoundError( - f"The {font_path} does not exists! \n" - f"You could download the file in the https://drive.google.com/file/d/1evWVX38EFNwTq_n5gTFgnlv8tdaNcyIA/view?usp=sharing" - ) - return str(font_path) - - @staticmethod - def get_random_color() -> Tuple[int, int, int]: - return ( - random.randint(0, 255), - random.randint(0, 255), - random.randint(0, 255), - ) - - @staticmethod - def get_box_height(box: List[List[float]]) -> float: - return math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2) - - @staticmethod - def get_box_width(box: List[List[float]]) -> float: - return math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2) - - @staticmethod - def get_char_size(font, char_str: str) -> float: - # compatible with Pillow v9 and v10. - if hasattr(font, "getsize"): - get_size_func = getattr(font, "getsize") - return get_size_func(char_str)[1] - - if hasattr(font, "getlength"): - get_size_func = getattr(font, "getlength") - return get_size_func(char_str) - - raise ValueError( - "The Pillow ImageFont instance has not getsize or getlength func." - ) - - -def get_logger(name: str): - logger = logging.getLogger(name) - logger.setLevel(logging.DEBUG) - - fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s" - format_str = logging.Formatter(fmt) - - sh = logging.StreamHandler() - sh.setLevel(logging.DEBUG) - - logger.addHandler(sh) - sh.setFormatter(format_str) - return logger diff --git a/python/rapidocr_onnxruntime/utils/__init__.py b/python/rapidocr_onnxruntime/utils/__init__.py new file mode 100644 index 000000000..cf5ca6994 --- /dev/null +++ b/python/rapidocr_onnxruntime/utils/__init__.py @@ -0,0 +1,19 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from pathlib import Path +from typing import Dict, Union + +import yaml + +from .infer_engine import OrtInferSession +from .load_image import LoadImage, LoadImageError +from .logger import get_logger +from .parse_parameters import UpdateParameters, init_args, update_model_path +from .vis_res import VisRes + + +def read_yaml(yaml_path: Union[str, Path]) -> Dict[str, Dict]: + with open(yaml_path, "rb") as f: + data = yaml.load(f, Loader=yaml.Loader) + return data diff --git a/python/rapidocr_onnxruntime/utils/infer_engine.py b/python/rapidocr_onnxruntime/utils/infer_engine.py new file mode 100644 index 000000000..13b3e14d4 --- /dev/null +++ b/python/rapidocr_onnxruntime/utils/infer_engine.py @@ -0,0 +1,220 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import os +import platform +import traceback +from pathlib import Path +from typing import List, Tuple, Union + +import numpy as np +from onnxruntime import ( + GraphOptimizationLevel, + InferenceSession, + SessionOptions, + get_available_providers, + get_device, +) + +from .logger import get_logger + +CPU_EP = "CPUExecutionProvider" +CUDA_EP = "CUDAExecutionProvider" +DIRECTML_EP = "DmlExecutionProvider" + + +class OrtInferSession: + def __init__(self, config): + self.logger = get_logger("OrtInferSession") + + model_path = config.get("model_path", None) + self._verify_model(model_path) + + self.cfg_use_cuda = config.get("use_cuda", None) + self.cfg_use_dml = config.get("use_dml", None) + + self.had_providers: List[str] = get_available_providers() + EP_list = self._get_ep_list() + + sess_opt = self._init_sess_opts(config) + self.session = InferenceSession( + model_path, sess_options=sess_opt, providers=EP_list + ) + self._verify_providers() + + @staticmethod + def _init_sess_opts(config): + sess_opt = SessionOptions() + sess_opt.log_severity_level = 4 + sess_opt.enable_cpu_mem_arena = False + sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL + + cpu_nums = os.cpu_count() + intra_op_num_threads = config.get("intra_op_num_threads", -1) + if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums: + sess_opt.intra_op_num_threads = intra_op_num_threads + + inter_op_num_threads = config.get("inter_op_num_threads", -1) + if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums: + sess_opt.inter_op_num_threads = inter_op_num_threads + + return sess_opt + + def _get_ep_list(self) -> List[Tuple[str, str]]: + cpu_provider_opts = { + "arena_extend_strategy": "kSameAsRequested", + } + EP_list = [(CPU_EP, cpu_provider_opts)] + + cuda_provider_opts = { + "device_id": 0, + "arena_extend_strategy": "kNextPowerOfTwo", + "cudnn_conv_algo_search": "EXHAUSTIVE", + "do_copy_in_default_stream": True, + } + self.use_cuda = self._check_cuda() + if self.use_cuda: + EP_list.insert(0, (CUDA_EP, cuda_provider_opts)) + + self.use_directml = self._check_dml() + if self.use_directml: + self.logger.info( + "Windows 10 or above detected, try to use DirectML as primary provider" + ) + directml_options = ( + cuda_provider_opts if self.use_cuda else cpu_provider_opts + ) + EP_list.insert(0, (DIRECTML_EP, directml_options)) + return EP_list + + def _check_cuda(self) -> bool: + if not self.cfg_use_cuda: + return False + + cur_device = get_device() + if cur_device == "GPU" and CUDA_EP in self.had_providers: + return True + + self.logger.warning( + "%s is not in available providers (%s), default use of %s inference.", + CUDA_EP, + self.had_providers, + self.had_providers[0], + ) + self.logger.info("If you want to use GPU acceleration, you must do:") + self.logger.info( + "First, uninstall all onnxruntime pakcages in current environment." + ) + self.logger.info( + "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`." + ) + self.logger.info( + "\tNote the onnxruntime-gpu version must match your cuda and cudnn version." + ) + self.logger.info( + "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html" + ) + self.logger.info( + "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']", + CUDA_EP, + ) + return False + + def _check_dml(self) -> bool: + if not self.cfg_use_dml: + return False + + cur_os = platform.system() + if cur_os != "Windows": + self.logger.warning( + "DirectML is only supported in Windows OS. The current OS is %s", cur_os + ) + return False + + cur_window_version = int(platform.release().split(".")[0]) + if cur_window_version < 10: + self.logger.warning( + "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s", + cur_window_version, + ) + return False + + if DIRECTML_EP in self.had_providers: + return True + + self.logger.warning( + "%s is not in available providers (%s), default use of %s inference.", + DIRECTML_EP, + self.had_providers, + self.had_providers[0], + ) + self.logger.info("If you want to use DirectML acceleration, you must do:") + self.logger.info( + "First, uninstall all onnxruntime pakcages in current environment." + ) + self.logger.info( + "Second, install onnxruntime-directml by `pip install onnxruntime-directml`" + ) + self.logger.info( + "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']", + DIRECTML_EP, + ) + return False + + def _verify_providers(self) -> None: + session_providers = self.session.get_providers() + first_provider = session_providers[0] + + if self.use_cuda and first_provider != CUDA_EP: + self.logger.warning( + "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.", + CUDA_EP, + first_provider, + ) + + if self.use_directml and first_provider != DIRECTML_EP: + self.logger.warning( + "%s is not available for current env, the inference part is automatically shifted to be executed under %s.", + DIRECTML_EP, + first_provider, + ) + + def __call__(self, input_content: np.ndarray) -> np.ndarray: + input_dict = dict(zip(self.get_input_names(), [input_content])) + try: + return self.session.run(self.get_output_names(), input_dict) + except Exception as e: + error_info = traceback.format_exc() + raise ONNXRuntimeError(error_info) from e + + def get_input_names(self): + return [v.name for v in self.session.get_inputs()] + + def get_output_names(self): + return [v.name for v in self.session.get_outputs()] + + def get_character_list(self, key: str = "character"): + meta_dict = self.session.get_modelmeta().custom_metadata_map + return meta_dict[key].splitlines() + + def have_key(self, key: str = "character") -> bool: + meta_dict = self.session.get_modelmeta().custom_metadata_map + if key in meta_dict.keys(): + return True + return False + + @staticmethod + def _verify_model(model_path: Union[str, Path, None]): + if model_path is None: + raise ValueError("model_path is None!") + + model_path = Path(model_path) + if not model_path.exists(): + raise FileNotFoundError(f"{model_path} does not exists.") + + if not model_path.is_file(): + raise FileExistsError(f"{model_path} is not a file.") + + +class ONNXRuntimeError(Exception): + pass diff --git a/python/rapidocr_onnxruntime/utils/load_image.py b/python/rapidocr_onnxruntime/utils/load_image.py new file mode 100644 index 000000000..056605303 --- /dev/null +++ b/python/rapidocr_onnxruntime/utils/load_image.py @@ -0,0 +1,123 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from io import BytesIO +from pathlib import Path +from typing import Union + +import cv2 +import numpy as np +from PIL import Image, UnidentifiedImageError + +root_dir = Path(__file__).resolve().parent +InputType = Union[str, np.ndarray, bytes, Path, Image.Image] + + +class LoadImage: + def __init__(self): + pass + + def __call__(self, img: InputType) -> np.ndarray: + if not isinstance(img, InputType.__args__): + raise LoadImageError( + f"The img type {type(img)} does not in {InputType.__args__}" + ) + + origin_img_type = type(img) + img = self.load_img(img) + img = self.convert_img(img, origin_img_type) + return img + + def load_img(self, img: InputType) -> np.ndarray: + if isinstance(img, (str, Path)): + self.verify_exist(img) + try: + img = self.img_to_ndarray(Image.open(img)) + except UnidentifiedImageError as e: + raise LoadImageError(f"cannot identify image file {img}") from e + return img + + if isinstance(img, bytes): + img = self.img_to_ndarray(Image.open(BytesIO(img))) + return img + + if isinstance(img, np.ndarray): + return img + + if isinstance(img, Image.Image): + return self.img_to_ndarray(img) + + raise LoadImageError(f"{type(img)} is not supported!") + + def img_to_ndarray(self, img: Image.Image) -> np.ndarray: + if img.mode == "1": + img = img.convert("L") + return np.array(img) + return np.array(img) + + def convert_img(self, img: np.ndarray, origin_img_type): + if img.ndim == 2: + return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if img.ndim == 3: + channel = img.shape[2] + if channel == 1: + return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if channel == 2: + return self.cvt_two_to_three(img) + + if channel == 3: + if issubclass(origin_img_type, (str, Path, bytes, Image.Image)): + return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + if channel == 4: + return self.cvt_four_to_three(img) + + raise LoadImageError( + f"The channel({channel}) of the img is not in [1, 2, 3, 4]" + ) + + raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]") + + @staticmethod + def cvt_two_to_three(img: np.ndarray) -> np.ndarray: + """gray + alpha → BGR""" + img_gray = img[..., 0] + img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR) + + img_alpha = img[..., 1] + not_a = cv2.bitwise_not(img_alpha) + not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) + + new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha) + new_img = cv2.add(new_img, not_a) + return new_img + + @staticmethod + def cvt_four_to_three(img: np.ndarray) -> np.ndarray: + """RGBA → BGR""" + r, g, b, a = cv2.split(img) + new_img = cv2.merge((b, g, r)) + + not_a = cv2.bitwise_not(a) + not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) + + new_img = cv2.bitwise_and(new_img, new_img, mask=a) + + mean_color = np.mean(new_img) + if mean_color <= 0.0: + new_img = cv2.add(new_img, not_a) + else: + new_img = cv2.bitwise_not(new_img) + return new_img + + @staticmethod + def verify_exist(file_path: Union[str, Path]): + if not Path(file_path).exists(): + raise LoadImageError(f"{file_path} does not exist.") + + +class LoadImageError(Exception): + pass diff --git a/python/rapidocr_onnxruntime/utils/logger.py b/python/rapidocr_onnxruntime/utils/logger.py new file mode 100644 index 000000000..ffd1cd04d --- /dev/null +++ b/python/rapidocr_onnxruntime/utils/logger.py @@ -0,0 +1,19 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import logging + + +def get_logger(name: str): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + + fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s" + format_str = logging.Formatter(fmt) + + sh = logging.StreamHandler() + sh.setLevel(logging.DEBUG) + + logger.addHandler(sh) + sh.setFormatter(format_str) + return logger diff --git a/python/rapidocr_onnxruntime/utils/parse_parameters.py b/python/rapidocr_onnxruntime/utils/parse_parameters.py new file mode 100644 index 000000000..395866cdf --- /dev/null +++ b/python/rapidocr_onnxruntime/utils/parse_parameters.py @@ -0,0 +1,215 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import argparse +from pathlib import Path +from typing import Dict, List, Optional, Union + +import numpy as np +from PIL import Image + +root_dir = Path(__file__).resolve().parent.parent +InputType = Union[str, np.ndarray, bytes, Path, Image.Image] + + +def update_model_path(config): + key = "model_path" + config["Det"][key] = str(root_dir / config["Det"][key]) + config["Rec"][key] = str(root_dir / config["Rec"][key]) + config["Cls"][key] = str(root_dir / config["Cls"][key]) + return config + + +def init_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-img", "--img_path", type=str, default=None, required=True) + parser.add_argument("-p", "--print_cost", action="store_true", default=False) + + global_group = parser.add_argument_group(title="Global") + global_group.add_argument("--text_score", type=float, default=0.5) + + global_group.add_argument("--no_det", action="store_true", default=False) + global_group.add_argument("--no_cls", action="store_true", default=False) + global_group.add_argument("--no_rec", action="store_true", default=False) + + global_group.add_argument("--print_verbose", action="store_true", default=False) + global_group.add_argument("--min_height", type=int, default=30) + global_group.add_argument("--width_height_ratio", type=int, default=8) + + global_group.add_argument("--intra_op_num_threads", type=int, default=-1) + global_group.add_argument("--inter_op_num_threads", type=int, default=-1) + + det_group = parser.add_argument_group(title="Det") + det_group.add_argument( + "--det_use_cuda", + action="store_true", + default=False, + help="Whether to use cuda. The prerequisite is: first uninstall all onnxruntime packages and install only the onnxruntime-gpu library.", + ) + det_group.add_argument( + "--det_use_dml", + action="store_true", + default=False, + help="Whether to use DirectML. The prerequisite is: the operating system is Windows 10+. First uninstall all onnxruntime packages and install only the onnxruntime-directml library.", + ) + det_group.add_argument("--det_model_path", type=str, default=None) + det_group.add_argument("--det_limit_side_len", type=float, default=736) + det_group.add_argument( + "--det_limit_type", type=str, default="min", choices=["max", "min"] + ) + det_group.add_argument("--det_thresh", type=float, default=0.3) + det_group.add_argument("--det_box_thresh", type=float, default=0.5) + det_group.add_argument("--det_unclip_ratio", type=float, default=1.6) + det_group.add_argument( + "--det_donot_use_dilation", action="store_true", default=False + ) + det_group.add_argument( + "--det_score_mode", type=str, default="fast", choices=["slow", "fast"] + ) + + cls_group = parser.add_argument_group(title="Cls") + cls_group.add_argument( + "--cls_use_cuda", + action="store_true", + default=False, + help="Whether to use cuda. The prerequisite is: first uninstall all onnxruntime packages and install only the onnxruntime-gpu library.", + ) + cls_group.add_argument( + "--cls_use_dml", + action="store_true", + default=False, + help="Whether to use DirectML. The prerequisite is: the operating system is Windows 10+. First uninstall all onnxruntime packages and install only the onnxruntime-directml library.", + ) + cls_group.add_argument("--cls_model_path", type=str, default=None) + cls_group.add_argument("--cls_image_shape", type=list, default=[3, 48, 192]) + cls_group.add_argument("--cls_label_list", type=list, default=["0", "180"]) + cls_group.add_argument("--cls_batch_num", type=int, default=6) + cls_group.add_argument("--cls_thresh", type=float, default=0.9) + + rec_group = parser.add_argument_group(title="Rec") + rec_group.add_argument( + "--rec_use_cuda", + action="store_true", + default=False, + help="Whether to use cuda. The prerequisite is: first uninstall all onnxruntime packages and install only the onnxruntime-gpu library.", + ) + rec_group.add_argument( + "--rec_use_dml", + action="store_true", + default=False, + help="Whether to use DirectML. The prerequisite is: the operating system is Windows 10+. First uninstall all onnxruntime packages and install only the onnxruntime-directml library.", + ) + rec_group.add_argument("--rec_model_path", type=str, default=None) + rec_group.add_argument("--rec_keys_path", type=str, default=None) + rec_group.add_argument("--rec_img_shape", type=list, default=[3, 48, 320]) + rec_group.add_argument("--rec_batch_num", type=int, default=6) + + vis_group = parser.add_argument_group(title="Visual Result") + vis_group.add_argument("-vis", "--vis_res", action="store_true", default=False) + vis_group.add_argument( + "--vis_font_path", + type=str, + default=None, + help="When -vis is True, the font_path must have value.", + ) + vis_group.add_argument( + "--vis_save_path", + type=str, + default=".", + help="The directory of saving the vis image.", + ) + + args = parser.parse_args() + return args + + +class UpdateParameters: + def __init__(self) -> None: + pass + + def parse_kwargs(self, **kwargs): + global_dict, det_dict, cls_dict, rec_dict = {}, {}, {}, {} + for k, v in kwargs.items(): + if k.startswith("det"): + k = k.split("det_")[1] + if k == "donot_use_dilation": + k = "use_dilation" + v = not v + + det_dict[k] = v + elif k.startswith("cls"): + cls_dict[k] = v + elif k.startswith("rec"): + rec_dict[k] = v + else: + global_dict[k] = v + return global_dict, det_dict, cls_dict, rec_dict + + def __call__(self, config, **kwargs): + global_dict, det_dict, cls_dict, rec_dict = self.parse_kwargs(**kwargs) + new_config = { + "Global": self.update_global_params(config["Global"], global_dict), + "Det": self.update_params(config["Det"], det_dict, "det_", None), + "Cls": self.update_params( + config["Cls"], + cls_dict, + "cls_", + ["cls_label_list", "cls_model_path", "cls_use_cuda"], + ), + "Rec": self.update_params( + config["Rec"], rec_dict, "rec_", ["rec_model_path", "rec_use_cuda"] + ), + } + + update_params = ["intra_op_num_threads", "inter_op_num_threads"] + new_config = self.update_global_to_module( + config, update_params, src="Global", dsts=["Det", "Cls", "Rec"] + ) + return new_config + + def update_global_to_module( + self, config, params: List[str], src: str, dsts: List[str] + ): + for dst in dsts: + for param in params: + config[dst].update({param: config[src][param]}) + return config + + def update_global_params(self, config, global_dict): + if global_dict: + config.update(global_dict) + return config + + def update_params( + self, + config, + param_dict: Dict[str, str], + prefix: str, + need_remove_prefix: Optional[List[str]] = None, + ): + if not param_dict: + return config + + filter_dict = self.remove_prefix(param_dict, prefix, need_remove_prefix) + model_path = filter_dict.get("model_path", None) + if not model_path: + filter_dict["model_path"] = str(root_dir / config["model_path"]) + + config.update(filter_dict) + return config + + @staticmethod + def remove_prefix( + config: Dict[str, str], + prefix: str, + need_remove_prefix: Optional[List[str]] = None, + ) -> Dict[str, str]: + if not need_remove_prefix: + return config + + new_rec_dict = {} + for k, v in config.items(): + if k in need_remove_prefix: + k = k.split(prefix)[1] + new_rec_dict[k] = v + return new_rec_dict diff --git a/python/rapidocr_onnxruntime/utils/vis_res.py b/python/rapidocr_onnxruntime/utils/vis_res.py new file mode 100644 index 000000000..405beabad --- /dev/null +++ b/python/rapidocr_onnxruntime/utils/vis_res.py @@ -0,0 +1,143 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import math +import random +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import cv2 +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +from .load_image import LoadImage + +root_dir = Path(__file__).resolve().parent +InputType = Union[str, np.ndarray, bytes, Path, Image.Image] + + +class VisRes: + def __init__(self, text_score: float = 0.5): + self.text_score = text_score + self.load_img = LoadImage() + + def __call__( + self, + img_content: InputType, + dt_boxes: np.ndarray, + txts: Optional[Union[List[str], Tuple[str]]] = None, + scores: Optional[Tuple[float]] = None, + font_path: Optional[str] = None, + ) -> np.ndarray: + if txts is None and scores is None: + return self.draw_dt_boxes(img_content, dt_boxes) + return self.draw_ocr_box_txt(img_content, dt_boxes, txts, scores, font_path) + + def draw_dt_boxes(self, img_content: InputType, dt_boxes: np.ndarray) -> np.ndarray: + img = self.load_img(img_content) + + for idx, box in enumerate(dt_boxes): + color = self.get_random_color() + + points = np.array(box) + cv2.polylines(img, np.int32([points]), 1, color=color, thickness=1) + + start_point = round(points[0][0]), round(points[0][1]) + cv2.putText( + img, f"{idx}", start_point, cv2.FONT_HERSHEY_SIMPLEX, 1, color, 3 + ) + return img + + def draw_ocr_box_txt( + self, + img_content: InputType, + dt_boxes: np.ndarray, + txts: Optional[Union[List[str], Tuple[str]]] = None, + scores: Optional[Tuple[float]] = None, + font_path: Optional[str] = None, + ) -> np.ndarray: + font_path = self.get_font_path(font_path) + + image = Image.fromarray(self.load_img(img_content)) + h, w = image.height, image.width + if image.mode == "L": + image = image.convert("RGB") + + img_left = image.copy() + img_right = Image.new("RGB", (w, h), (255, 255, 255)) + + random.seed(0) + draw_left = ImageDraw.Draw(img_left) + draw_right = ImageDraw.Draw(img_right) + for idx, (box, txt) in enumerate(zip(dt_boxes, txts)): + if scores is not None and float(scores[idx]) < self.text_score: + continue + + color = self.get_random_color() + + box_list = np.array(box).reshape(8).tolist() + draw_left.polygon(box_list, fill=color) + draw_right.polygon(box_list, outline=color) + + box_height = self.get_box_height(box) + box_width = self.get_box_width(box) + if box_height > 2 * box_width: + font_size = max(int(box_width * 0.9), 10) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + cur_y = box[0][1] + + for c in txt: + draw_right.text( + (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font + ) + cur_y += self.get_char_size(font, c) + else: + font_size = max(int(box_height * 0.8), 10) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) + + img_left = Image.blend(image, img_left, 0.5) + img_show = Image.new("RGB", (w * 2, h), (255, 255, 255)) + img_show.paste(img_left, (0, 0, w, h)) + img_show.paste(img_right, (w, 0, w * 2, h)) + return np.array(img_show) + + @staticmethod + def get_font_path(font_path: Optional[Union[str, Path]] = None) -> str: + if font_path is None or not Path(font_path).exists(): + raise FileNotFoundError( + f"The {font_path} does not exists! \n" + f"You could download the file in the https://drive.google.com/file/d/1evWVX38EFNwTq_n5gTFgnlv8tdaNcyIA/view?usp=sharing" + ) + return str(font_path) + + @staticmethod + def get_random_color() -> Tuple[int, int, int]: + return ( + random.randint(0, 255), + random.randint(0, 255), + random.randint(0, 255), + ) + + @staticmethod + def get_box_height(box: List[List[float]]) -> float: + return math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2) + + @staticmethod + def get_box_width(box: List[List[float]]) -> float: + return math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2) + + @staticmethod + def get_char_size(font, char_str: str) -> float: + # compatible with Pillow v9 and v10. + if hasattr(font, "getsize"): + get_size_func = getattr(font, "getsize") + return get_size_func(char_str)[1] + + if hasattr(font, "getlength"): + get_size_func = getattr(font, "getlength") + return get_size_func(char_str) + + raise ValueError( + "The Pillow ImageFont instance has not getsize or getlength func." + ) From 7528fa96607db440128fed1ce97d9acce63c9a34 Mon Sep 17 00:00:00 2001 From: SWHL Date: Sat, 18 May 2024 15:27:13 +0800 Subject: [PATCH 3/8] chore(rapidocr_onnxruntime): Optim tips of infer_engine --- python/rapidocr_onnxruntime/utils/infer_engine.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/rapidocr_onnxruntime/utils/infer_engine.py b/python/rapidocr_onnxruntime/utils/infer_engine.py index 13b3e14d4..51ee123a2 100644 --- a/python/rapidocr_onnxruntime/utils/infer_engine.py +++ b/python/rapidocr_onnxruntime/utils/infer_engine.py @@ -96,7 +96,7 @@ def _check_cuda(self) -> bool: return True self.logger.warning( - "%s is not in available providers (%s), default use of %s inference.", + "%s is not in available providers (%s). Use %s inference by default.", CUDA_EP, self.had_providers, self.had_providers[0], @@ -127,15 +127,18 @@ def _check_dml(self) -> bool: cur_os = platform.system() if cur_os != "Windows": self.logger.warning( - "DirectML is only supported in Windows OS. The current OS is %s", cur_os + "DirectML is only supported in Windows OS. The current OS is %s. Use %s inference by default.", + cur_os, + self.had_providers[0], ) return False cur_window_version = int(platform.release().split(".")[0]) if cur_window_version < 10: self.logger.warning( - "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s", + "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s. Use %s inference by default.", cur_window_version, + self.had_providers[0], ) return False @@ -143,7 +146,7 @@ def _check_dml(self) -> bool: return True self.logger.warning( - "%s is not in available providers (%s), default use of %s inference.", + "%s is not in available providers (%s). Use %s inference by default.", DIRECTML_EP, self.had_providers, self.had_providers[0], From b9c90ee837d362f2788b2f9f7f21a704f6340ab5 Mon Sep 17 00:00:00 2001 From: Dolen <87627379+HiDolen@users.noreply.github.com> Date: Sat, 18 May 2024 15:29:35 +0800 Subject: [PATCH 4/8] fix(python): Modify the way CTCLabelDecode calculates confidence (#179) --- python/rapidocr_onnxruntime/ch_ppocr_v3_rec/utils.py | 2 +- python/rapidocr_openvino/ch_ppocr_v3_rec/utils.py | 2 +- python/rapidocr_paddle/ch_ppocr_v3_rec/utils.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/utils.py b/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/utils.py index 8173787ba..6e63a87bd 100644 --- a/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/utils.py +++ b/python/rapidocr_onnxruntime/ch_ppocr_v3_rec/utils.py @@ -69,5 +69,5 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False): else: conf_list.append(1) text = "".join(char_list) - result_list.append((text, np.mean(conf_list + [1e-50]))) + result_list.append((text, np.mean(conf_list if any(conf_list) else [0]))) return result_list diff --git a/python/rapidocr_openvino/ch_ppocr_v3_rec/utils.py b/python/rapidocr_openvino/ch_ppocr_v3_rec/utils.py index 9587af350..f26c36951 100644 --- a/python/rapidocr_openvino/ch_ppocr_v3_rec/utils.py +++ b/python/rapidocr_openvino/ch_ppocr_v3_rec/utils.py @@ -67,5 +67,5 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False): else: conf_list.append(1) text = "".join(char_list) - result_list.append((text, np.mean(conf_list + [1e-50]))) + result_list.append((text, np.mean(conf_list if any(conf_list) else [0]))) return result_list diff --git a/python/rapidocr_paddle/ch_ppocr_v3_rec/utils.py b/python/rapidocr_paddle/ch_ppocr_v3_rec/utils.py index 1cde51931..b590f5d70 100644 --- a/python/rapidocr_paddle/ch_ppocr_v3_rec/utils.py +++ b/python/rapidocr_paddle/ch_ppocr_v3_rec/utils.py @@ -71,5 +71,5 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False): else: conf_list.append(1) text = "".join(char_list) - result_list.append((text, np.mean(conf_list + [1e-50]))) + result_list.append((text, np.mean(conf_list if any(conf_list) else [0]))) return result_list From a6521e19104bab2c63140e0148900f5512860b0e Mon Sep 17 00:00:00 2001 From: SWHL Date: Sat, 18 May 2024 16:01:21 +0800 Subject: [PATCH 5/8] fix(python): Update unit testing --- python/tests/test_ort.py | 12 ++++++------ python/tests/test_paddle.py | 12 ++++++------ python/tests/test_vino.py | 12 ++++++------ 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/python/tests/test_ort.py b/python/tests/test_ort.py index e943d19eb..03f8e0206 100644 --- a/python/tests/test_ort.py +++ b/python/tests/test_ort.py @@ -88,7 +88,7 @@ def test_only_rec(): def test_det_rec(): result, _ = engine(img_path, use_det=True, use_cls=False, use_rec=True) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 18 def test_cls_rec(): @@ -104,7 +104,7 @@ def test_det_cls_rec(): result, _ = engine(img) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 18 def test_empty(): @@ -124,20 +124,20 @@ def test_zeros(): def test_input_str(): result, _ = engine(str(img_path)) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 18 def test_input_bytes(): with open(img_path, "rb") as f: result, _ = engine(f.read()) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 18 def test_input_path(): result, _ = engine(img_path) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 18 def test_input_parameters(): @@ -190,7 +190,7 @@ def test_input_three_ndim_one_channel(): result, _ = engine(img) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 17 def test_det(): diff --git a/python/tests/test_paddle.py b/python/tests/test_paddle.py index 4be8ba39f..b07fb88ed 100644 --- a/python/tests/test_paddle.py +++ b/python/tests/test_paddle.py @@ -88,7 +88,7 @@ def test_only_rec(): def test_det_rec(): result, _ = engine(img_path, use_det=True, use_cls=False, use_rec=True) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 18 def test_cls_rec(): @@ -104,7 +104,7 @@ def test_det_cls_rec(): result, _ = engine(img) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 18 def test_empty(): @@ -124,20 +124,20 @@ def test_zeros(): def test_input_str(): result, _ = engine(str(img_path)) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 18 def test_input_bytes(): with open(img_path, "rb") as f: result, _ = engine(f.read()) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 18 def test_input_path(): result, _ = engine(img_path) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 18 def test_input_parameters(): @@ -190,7 +190,7 @@ def test_input_three_ndim_one_channel(): result, _ = engine(img) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 17 def test_det(): diff --git a/python/tests/test_vino.py b/python/tests/test_vino.py index 11b9f4299..506ab0e51 100644 --- a/python/tests/test_vino.py +++ b/python/tests/test_vino.py @@ -88,7 +88,7 @@ def test_only_rec(): def test_det_rec(): result, _ = engine(img_path, use_det=True, use_cls=False, use_rec=True) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 18 def test_cls_rec(): @@ -104,7 +104,7 @@ def test_det_cls_rec(): result, _ = engine(img) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 18 def test_empty(): @@ -124,20 +124,20 @@ def test_zeros(): def test_input_str(): result, _ = engine(str(img_path)) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 18 def test_input_bytes(): with open(img_path, "rb") as f: result, _ = engine(f.read()) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 18 def test_input_path(): result, _ = engine(img_path) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 18 def test_input_parameters(): @@ -190,7 +190,7 @@ def test_input_three_ndim_one_channel(): result, _ = engine(img) assert result[0][1] == "正品促销" - assert len(result) == 16 + assert len(result) == 17 def test_det(): From 160914d68e1a836091c6e5bd3d91c206c238b8f0 Mon Sep 17 00:00:00 2001 From: SWHL Date: Sat, 18 May 2024 16:15:18 +0800 Subject: [PATCH 6/8] refactor(rapidocr_openvino): Decoupling utils.py files --- python/rapidocr_openvino/utils.py | 458 ------------------ python/rapidocr_openvino/utils/__init__.py | 19 + .../rapidocr_openvino/utils/infer_engine.py | 45 ++ python/rapidocr_openvino/utils/load_image.py | 123 +++++ python/rapidocr_openvino/utils/logger.py | 19 + .../utils/parse_parameters.py | 178 +++++++ python/rapidocr_openvino/utils/vis_res.py | 143 ++++++ 7 files changed, 527 insertions(+), 458 deletions(-) delete mode 100644 python/rapidocr_openvino/utils.py create mode 100644 python/rapidocr_openvino/utils/__init__.py create mode 100644 python/rapidocr_openvino/utils/infer_engine.py create mode 100644 python/rapidocr_openvino/utils/load_image.py create mode 100644 python/rapidocr_openvino/utils/logger.py create mode 100644 python/rapidocr_openvino/utils/parse_parameters.py create mode 100644 python/rapidocr_openvino/utils/vis_res.py diff --git a/python/rapidocr_openvino/utils.py b/python/rapidocr_openvino/utils.py deleted file mode 100644 index 22b563745..000000000 --- a/python/rapidocr_openvino/utils.py +++ /dev/null @@ -1,458 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: SWHL -# @Contact: liekkaskono@163.com -import argparse -import math -import os -import random -from io import BytesIO -from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union - -import cv2 -import numpy as np -import yaml -from openvino.runtime import Core -from PIL import Image, ImageDraw, ImageFont, UnidentifiedImageError - -root_dir = Path(__file__).resolve().parent -InputType = Union[str, np.ndarray, bytes, Path, Image.Image] - - -class OpenVINOInferSession: - def __init__(self, config): - core = Core() - - self._verify_model(config["model_path"]) - model_onnx = core.read_model(config["model_path"]) - - cpu_nums = os.cpu_count() - infer_num_threads = config.get("inference_num_threads", -1) - if infer_num_threads != -1 and 1 <= infer_num_threads <= cpu_nums: - core.set_property("CPU", {"INFERENCE_NUM_THREADS": str(infer_num_threads)}) - - compile_model = core.compile_model(model=model_onnx, device_name="CPU") - self.session = compile_model.create_infer_request() - - def __call__(self, input_content: np.ndarray) -> np.ndarray: - self.session.infer(inputs=[input_content]) - return self.session.get_output_tensor().data - - @staticmethod - def _verify_model(model_path): - model_path = Path(model_path) - if not model_path.exists(): - raise FileNotFoundError(f"{model_path} does not exists.") - if not model_path.is_file(): - raise FileExistsError(f"{model_path} is not a file.") - - -class LoadImage: - def __init__( - self, - ): - pass - - def __call__(self, img: InputType) -> np.ndarray: - if not isinstance(img, InputType.__args__): - raise LoadImageError( - f"The img type {type(img)} does not in {InputType.__args__}" - ) - - origin_img_type = type(img) - img = self.load_img(img) - img = self.convert_img(img, origin_img_type) - return img - - def load_img(self, img: InputType) -> np.ndarray: - if isinstance(img, (str, Path)): - self.verify_exist(img) - try: - img = self.img_to_ndarray(Image.open(img)) - except UnidentifiedImageError as e: - raise LoadImageError(f"cannot identify image file {img}") from e - return img - - if isinstance(img, bytes): - img = self.img_to_ndarray(Image.open(BytesIO(img))) - return img - - if isinstance(img, np.ndarray): - return img - - if isinstance(img, Image.Image): - return self.img_to_ndarray(img) - - raise LoadImageError(f"{type(img)} is not supported!") - - def img_to_ndarray(self, img: Image.Image) -> np.ndarray: - if img.mode == "1": - img = img.convert("L") - return np.array(img) - return np.array(img) - - def convert_img(self, img: np.ndarray, origin_img_type): - if img.ndim == 2: - return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - if img.ndim == 3: - channel = img.shape[2] - if channel == 1: - return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - if channel == 2: - return self.cvt_two_to_three(img) - - if channel == 3: - if issubclass(origin_img_type, (str, Path, bytes, Image.Image)): - return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - return img - - if channel == 4: - return self.cvt_four_to_three(img) - - raise LoadImageError( - f"The channel({channel}) of the img is not in [1, 2, 3, 4]" - ) - - raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]") - - @staticmethod - def cvt_two_to_three(img: np.ndarray) -> np.ndarray: - """gray + alpha → BGR""" - img_gray = img[..., 0] - img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR) - - img_alpha = img[..., 1] - not_a = cv2.bitwise_not(img_alpha) - not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) - - new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha) - new_img = cv2.add(new_img, not_a) - return new_img - - @staticmethod - def cvt_four_to_three(img: np.ndarray) -> np.ndarray: - """RGBA → BGR""" - r, g, b, a = cv2.split(img) - new_img = cv2.merge((b, g, r)) - - not_a = cv2.bitwise_not(a) - not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) - - new_img = cv2.bitwise_and(new_img, new_img, mask=a) - - mean_color = np.mean(new_img) - if mean_color <= 0.0: - new_img = cv2.add(new_img, not_a) - else: - new_img = cv2.bitwise_not(new_img) - return new_img - - @staticmethod - def verify_exist(file_path: Union[str, Path]): - if not Path(file_path).exists(): - raise LoadImageError(f"{file_path} does not exist.") - - -class LoadImageError(Exception): - pass - - -def read_yaml(yaml_path): - with open(yaml_path, "rb") as f: - data = yaml.load(f, Loader=yaml.Loader) - return data - - -def update_model_path(config): - key = "model_path" - config["Det"][key] = str(root_dir / config["Det"][key]) - config["Rec"][key] = str(root_dir / config["Rec"][key]) - config["Cls"][key] = str(root_dir / config["Cls"][key]) - return config - - -def init_args(): - parser = argparse.ArgumentParser() - parser.add_argument("-img", "--img_path", type=str, default=None, required=True) - parser.add_argument("-p", "--print_cost", action="store_true", default=False) - - global_group = parser.add_argument_group(title="Global") - global_group.add_argument("--text_score", type=float, default=0.5) - - global_group.add_argument("--no_det", action="store_true", default=False) - global_group.add_argument("--no_cls", action="store_true", default=False) - global_group.add_argument("--no_rec", action="store_true", default=False) - - global_group.add_argument("--print_verbose", action="store_true", default=False) - global_group.add_argument("--min_height", type=int, default=30) - global_group.add_argument("--width_height_ratio", type=int, default=8) - - global_group.add_argument("--inference_num_threads", type=int, default=-1) - - det_group = parser.add_argument_group(title="Det") - det_group.add_argument("--det_model_path", type=str, default=None) - det_group.add_argument("--det_limit_side_len", type=float, default=736) - det_group.add_argument( - "--det_limit_type", type=str, default="min", choices=["max", "min"] - ) - det_group.add_argument("--det_thresh", type=float, default=0.3) - det_group.add_argument("--det_box_thresh", type=float, default=0.5) - det_group.add_argument("--det_unclip_ratio", type=float, default=1.6) - det_group.add_argument( - "--det_donot_use_dilation", action="store_true", default=False - ) - det_group.add_argument( - "--det_score_mode", type=str, default="fast", choices=["slow", "fast"] - ) - - cls_group = parser.add_argument_group(title="Cls") - cls_group.add_argument("--cls_model_path", type=str, default=None) - cls_group.add_argument("--cls_image_shape", type=list, default=[3, 48, 192]) - cls_group.add_argument("--cls_label_list", type=list, default=["0", "180"]) - cls_group.add_argument("--cls_batch_num", type=int, default=6) - cls_group.add_argument("--cls_thresh", type=float, default=0.9) - - rec_group = parser.add_argument_group(title="Rec") - rec_group.add_argument("--rec_model_path", type=str, default=None) - rec_group.add_argument("--rec_keys_path", type=str, default=None) - rec_group.add_argument("--rec_img_shape", type=list, default=[3, 48, 320]) - rec_group.add_argument("--rec_batch_num", type=int, default=6) - - vis_group = parser.add_argument_group(title="Visual Result") - vis_group.add_argument("-vis", "--vis_res", action="store_true", default=False) - vis_group.add_argument( - "--vis_font_path", - type=str, - default=None, - help="When -vis is True, the font_path must have value.", - ) - vis_group.add_argument( - "--vis_save_path", - type=str, - default=".", - help="The directory of saving the vis image.", - ) - - args = parser.parse_args() - return args - - -class UpdateParameters: - def __init__(self) -> None: - pass - - def parse_kwargs(self, **kwargs): - global_dict, det_dict, cls_dict, rec_dict = {}, {}, {}, {} - for k, v in kwargs.items(): - if k.startswith("det"): - k = k.split("det_")[1] - if k == "donot_use_dilation": - k = "use_dilation" - v = not v - - det_dict[k] = v - elif k.startswith("cls"): - cls_dict[k] = v - elif k.startswith("rec"): - rec_dict[k] = v - else: - global_dict[k] = v - return global_dict, det_dict, cls_dict, rec_dict - - def __call__(self, config, **kwargs): - global_dict, det_dict, cls_dict, rec_dict = self.parse_kwargs(**kwargs) - new_config = { - "Global": self.update_global_params(config["Global"], global_dict), - "Det": self.update_params(config["Det"], det_dict, "det_", None), - "Cls": self.update_params( - config["Cls"], - cls_dict, - "cls_", - ["cls_label_list", "cls_model_path", "cls_use_cuda"], - ), - "Rec": self.update_params( - config["Rec"], rec_dict, "rec_", ["rec_model_path", "rec_use_cuda"] - ), - } - - update_params = ["inference_num_threads"] - new_config = self.update_global_to_module( - config, update_params, src="Global", dsts=["Det", "Cls", "Rec"] - ) - return new_config - - def update_global_to_module( - self, config, params: List[str], src: str, dsts: List[str] - ): - for dst in dsts: - for param in params: - config[dst].update({param: config[src][param]}) - return config - - def update_global_params(self, config, global_dict): - if global_dict: - config.update(global_dict) - return config - - def update_params( - self, - config, - param_dict: Dict[str, str], - prefix: str, - need_remove_prefix: Optional[List[str]] = None, - ): - if not param_dict: - return config - - filter_dict = self.remove_prefix(param_dict, prefix, need_remove_prefix) - model_path = filter_dict.get("model_path", None) - if not model_path: - filter_dict["model_path"] = str(root_dir / config["model_path"]) - - config.update(filter_dict) - return config - - @staticmethod - def remove_prefix( - config: Dict[str, str], - prefix: str, - need_remove_prefix: Optional[List[str]] = None, - ) -> Dict[str, str]: - if not need_remove_prefix: - return config - - new_rec_dict = {} - for k, v in config.items(): - if k in need_remove_prefix: - k = k.split(prefix)[1] - new_rec_dict[k] = v - return new_rec_dict - - -class VisRes: - def __init__(self, text_score: float = 0.5): - self.text_score = text_score - self.load_img = LoadImage() - - def __call__( - self, - img_content: InputType, - dt_boxes: np.ndarray, - txts: Optional[Union[List[str], Tuple[str]]] = None, - scores: Optional[Tuple[float]] = None, - font_path: Optional[str] = None, - ) -> np.ndarray: - if txts is None and scores is None: - return self.draw_dt_boxes(img_content, dt_boxes) - return self.draw_ocr_box_txt(img_content, dt_boxes, txts, scores, font_path) - - def draw_dt_boxes(self, img_content: InputType, dt_boxes: np.ndarray) -> np.ndarray: - img = self.load_img(img_content) - - for idx, box in enumerate(dt_boxes): - color = self.get_random_color() - - points = np.array(box) - cv2.polylines(img, np.int32([points]), 1, color=color, thickness=1) - - start_point = round(points[0][0]), round(points[0][1]) - cv2.putText( - img, f"{idx}", start_point, cv2.FONT_HERSHEY_SIMPLEX, 1, color, 3 - ) - return img - - def draw_ocr_box_txt( - self, - img_content: InputType, - dt_boxes: np.ndarray, - txts: Optional[Union[List[str], Tuple[str]]] = None, - scores: Optional[Tuple[float]] = None, - font_path: Optional[str] = None, - ) -> np.ndarray: - font_path = self.get_font_path(font_path) - - image = Image.fromarray(self.load_img(img_content)) - h, w = image.height, image.width - if image.mode == "L": - image = image.convert("RGB") - - img_left = image.copy() - img_right = Image.new("RGB", (w, h), (255, 255, 255)) - - random.seed(0) - draw_left = ImageDraw.Draw(img_left) - draw_right = ImageDraw.Draw(img_right) - for idx, (box, txt) in enumerate(zip(dt_boxes, txts)): - if scores is not None and float(scores[idx]) < self.text_score: - continue - - color = self.get_random_color() - - box_list = np.array(box).reshape(8).tolist() - draw_left.polygon(box_list, fill=color) - draw_right.polygon(box_list, outline=color) - - box_height = self.get_box_height(box) - box_width = self.get_box_width(box) - if box_height > 2 * box_width: - font_size = max(int(box_width * 0.9), 10) - font = ImageFont.truetype(font_path, font_size, encoding="utf-8") - cur_y = box[0][1] - - for c in txt: - draw_right.text( - (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font - ) - cur_y += self.get_char_size(font, c) - else: - font_size = max(int(box_height * 0.8), 10) - font = ImageFont.truetype(font_path, font_size, encoding="utf-8") - draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) - - img_left = Image.blend(image, img_left, 0.5) - img_show = Image.new("RGB", (w * 2, h), (255, 255, 255)) - img_show.paste(img_left, (0, 0, w, h)) - img_show.paste(img_right, (w, 0, w * 2, h)) - return np.array(img_show) - - @staticmethod - def get_font_path(font_path: Optional[Union[str, Path]] = None) -> str: - if font_path is None or not Path(font_path).exists(): - raise FileNotFoundError( - f"The {font_path} does not exists! \n" - f"You could download the file in the https://drive.google.com/file/d/1evWVX38EFNwTq_n5gTFgnlv8tdaNcyIA/view?usp=sharing" - ) - return str(font_path) - - @staticmethod - def get_random_color() -> Tuple[int, int, int]: - return ( - random.randint(0, 255), - random.randint(0, 255), - random.randint(0, 255), - ) - - @staticmethod - def get_box_height(box: List[List[float]]) -> float: - return math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2) - - @staticmethod - def get_box_width(box: List[List[float]]) -> float: - return math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2) - - @staticmethod - def get_char_size(font, char_str: str) -> float: - # compatible with Pillow v9 and v10. - if hasattr(font, "getsize"): - get_size_func = getattr(font, "getsize") - return get_size_func(char_str)[1] - - if hasattr(font, "getlength"): - get_size_func = getattr(font, "getlength") - return get_size_func(char_str) - - raise ValueError( - "The Pillow ImageFont instance has not getsize or getlength func." - ) diff --git a/python/rapidocr_openvino/utils/__init__.py b/python/rapidocr_openvino/utils/__init__.py new file mode 100644 index 000000000..f3a7c471b --- /dev/null +++ b/python/rapidocr_openvino/utils/__init__.py @@ -0,0 +1,19 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from pathlib import Path +from typing import Dict, Union + +import yaml + +from .infer_engine import OpenVINOInferSession +from .load_image import LoadImage, LoadImageError +from .logger import get_logger +from .parse_parameters import UpdateParameters, init_args, update_model_path +from .vis_res import VisRes + + +def read_yaml(yaml_path: Union[str, Path]) -> Dict[str, Dict]: + with open(yaml_path, "rb") as f: + data = yaml.load(f, Loader=yaml.Loader) + return data diff --git a/python/rapidocr_openvino/utils/infer_engine.py b/python/rapidocr_openvino/utils/infer_engine.py new file mode 100644 index 000000000..fbd6da114 --- /dev/null +++ b/python/rapidocr_openvino/utils/infer_engine.py @@ -0,0 +1,45 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import os +import traceback +from pathlib import Path + +import numpy as np +from openvino.runtime import Core + + +class OpenVINOInferSession: + def __init__(self, config): + core = Core() + + self._verify_model(config["model_path"]) + model_onnx = core.read_model(config["model_path"]) + + cpu_nums = os.cpu_count() + infer_num_threads = config.get("inference_num_threads", -1) + if infer_num_threads != -1 and 1 <= infer_num_threads <= cpu_nums: + core.set_property("CPU", {"INFERENCE_NUM_THREADS": str(infer_num_threads)}) + + compile_model = core.compile_model(model=model_onnx, device_name="CPU") + self.session = compile_model.create_infer_request() + + def __call__(self, input_content: np.ndarray) -> np.ndarray: + try: + self.session.infer(inputs=[input_content]) + return self.session.get_output_tensor().data + except Exception as e: + error_info = traceback.format_exc() + raise OpenVIONError(error_info) from e + + @staticmethod + def _verify_model(model_path): + model_path = Path(model_path) + if not model_path.exists(): + raise FileNotFoundError(f"{model_path} does not exists.") + if not model_path.is_file(): + raise FileExistsError(f"{model_path} is not a file.") + + +class OpenVIONError(Exception): + pass diff --git a/python/rapidocr_openvino/utils/load_image.py b/python/rapidocr_openvino/utils/load_image.py new file mode 100644 index 000000000..056605303 --- /dev/null +++ b/python/rapidocr_openvino/utils/load_image.py @@ -0,0 +1,123 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from io import BytesIO +from pathlib import Path +from typing import Union + +import cv2 +import numpy as np +from PIL import Image, UnidentifiedImageError + +root_dir = Path(__file__).resolve().parent +InputType = Union[str, np.ndarray, bytes, Path, Image.Image] + + +class LoadImage: + def __init__(self): + pass + + def __call__(self, img: InputType) -> np.ndarray: + if not isinstance(img, InputType.__args__): + raise LoadImageError( + f"The img type {type(img)} does not in {InputType.__args__}" + ) + + origin_img_type = type(img) + img = self.load_img(img) + img = self.convert_img(img, origin_img_type) + return img + + def load_img(self, img: InputType) -> np.ndarray: + if isinstance(img, (str, Path)): + self.verify_exist(img) + try: + img = self.img_to_ndarray(Image.open(img)) + except UnidentifiedImageError as e: + raise LoadImageError(f"cannot identify image file {img}") from e + return img + + if isinstance(img, bytes): + img = self.img_to_ndarray(Image.open(BytesIO(img))) + return img + + if isinstance(img, np.ndarray): + return img + + if isinstance(img, Image.Image): + return self.img_to_ndarray(img) + + raise LoadImageError(f"{type(img)} is not supported!") + + def img_to_ndarray(self, img: Image.Image) -> np.ndarray: + if img.mode == "1": + img = img.convert("L") + return np.array(img) + return np.array(img) + + def convert_img(self, img: np.ndarray, origin_img_type): + if img.ndim == 2: + return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if img.ndim == 3: + channel = img.shape[2] + if channel == 1: + return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if channel == 2: + return self.cvt_two_to_three(img) + + if channel == 3: + if issubclass(origin_img_type, (str, Path, bytes, Image.Image)): + return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + if channel == 4: + return self.cvt_four_to_three(img) + + raise LoadImageError( + f"The channel({channel}) of the img is not in [1, 2, 3, 4]" + ) + + raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]") + + @staticmethod + def cvt_two_to_three(img: np.ndarray) -> np.ndarray: + """gray + alpha → BGR""" + img_gray = img[..., 0] + img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR) + + img_alpha = img[..., 1] + not_a = cv2.bitwise_not(img_alpha) + not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) + + new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha) + new_img = cv2.add(new_img, not_a) + return new_img + + @staticmethod + def cvt_four_to_three(img: np.ndarray) -> np.ndarray: + """RGBA → BGR""" + r, g, b, a = cv2.split(img) + new_img = cv2.merge((b, g, r)) + + not_a = cv2.bitwise_not(a) + not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) + + new_img = cv2.bitwise_and(new_img, new_img, mask=a) + + mean_color = np.mean(new_img) + if mean_color <= 0.0: + new_img = cv2.add(new_img, not_a) + else: + new_img = cv2.bitwise_not(new_img) + return new_img + + @staticmethod + def verify_exist(file_path: Union[str, Path]): + if not Path(file_path).exists(): + raise LoadImageError(f"{file_path} does not exist.") + + +class LoadImageError(Exception): + pass diff --git a/python/rapidocr_openvino/utils/logger.py b/python/rapidocr_openvino/utils/logger.py new file mode 100644 index 000000000..ffd1cd04d --- /dev/null +++ b/python/rapidocr_openvino/utils/logger.py @@ -0,0 +1,19 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import logging + + +def get_logger(name: str): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + + fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s" + format_str = logging.Formatter(fmt) + + sh = logging.StreamHandler() + sh.setLevel(logging.DEBUG) + + logger.addHandler(sh) + sh.setFormatter(format_str) + return logger diff --git a/python/rapidocr_openvino/utils/parse_parameters.py b/python/rapidocr_openvino/utils/parse_parameters.py new file mode 100644 index 000000000..f229fd2e9 --- /dev/null +++ b/python/rapidocr_openvino/utils/parse_parameters.py @@ -0,0 +1,178 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import argparse +from pathlib import Path +from typing import Dict, List, Optional, Union + +import numpy as np +from PIL import Image + +root_dir = Path(__file__).resolve().parent.parent +InputType = Union[str, np.ndarray, bytes, Path, Image.Image] + + +def update_model_path(config): + key = "model_path" + config["Det"][key] = str(root_dir / config["Det"][key]) + config["Rec"][key] = str(root_dir / config["Rec"][key]) + config["Cls"][key] = str(root_dir / config["Cls"][key]) + return config + + +def init_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-img", "--img_path", type=str, default=None, required=True) + parser.add_argument("-p", "--print_cost", action="store_true", default=False) + + global_group = parser.add_argument_group(title="Global") + global_group.add_argument("--text_score", type=float, default=0.5) + + global_group.add_argument("--no_det", action="store_true", default=False) + global_group.add_argument("--no_cls", action="store_true", default=False) + global_group.add_argument("--no_rec", action="store_true", default=False) + + global_group.add_argument("--print_verbose", action="store_true", default=False) + global_group.add_argument("--min_height", type=int, default=30) + global_group.add_argument("--width_height_ratio", type=int, default=8) + + global_group.add_argument("--inference_num_threads", type=int, default=-1) + + det_group = parser.add_argument_group(title="Det") + det_group.add_argument("--det_model_path", type=str, default=None) + det_group.add_argument("--det_limit_side_len", type=float, default=736) + det_group.add_argument( + "--det_limit_type", type=str, default="min", choices=["max", "min"] + ) + det_group.add_argument("--det_thresh", type=float, default=0.3) + det_group.add_argument("--det_box_thresh", type=float, default=0.5) + det_group.add_argument("--det_unclip_ratio", type=float, default=1.6) + det_group.add_argument( + "--det_donot_use_dilation", action="store_true", default=False + ) + det_group.add_argument( + "--det_score_mode", type=str, default="fast", choices=["slow", "fast"] + ) + + cls_group = parser.add_argument_group(title="Cls") + cls_group.add_argument("--cls_model_path", type=str, default=None) + cls_group.add_argument("--cls_image_shape", type=list, default=[3, 48, 192]) + cls_group.add_argument("--cls_label_list", type=list, default=["0", "180"]) + cls_group.add_argument("--cls_batch_num", type=int, default=6) + cls_group.add_argument("--cls_thresh", type=float, default=0.9) + + rec_group = parser.add_argument_group(title="Rec") + rec_group.add_argument("--rec_model_path", type=str, default=None) + rec_group.add_argument("--rec_keys_path", type=str, default=None) + rec_group.add_argument("--rec_img_shape", type=list, default=[3, 48, 320]) + rec_group.add_argument("--rec_batch_num", type=int, default=6) + + vis_group = parser.add_argument_group(title="Visual Result") + vis_group.add_argument("-vis", "--vis_res", action="store_true", default=False) + vis_group.add_argument( + "--vis_font_path", + type=str, + default=None, + help="When -vis is True, the font_path must have value.", + ) + vis_group.add_argument( + "--vis_save_path", + type=str, + default=".", + help="The directory of saving the vis image.", + ) + + args = parser.parse_args() + return args + + +class UpdateParameters: + def __init__(self) -> None: + pass + + def parse_kwargs(self, **kwargs): + global_dict, det_dict, cls_dict, rec_dict = {}, {}, {}, {} + for k, v in kwargs.items(): + if k.startswith("det"): + k = k.split("det_")[1] + if k == "donot_use_dilation": + k = "use_dilation" + v = not v + + det_dict[k] = v + elif k.startswith("cls"): + cls_dict[k] = v + elif k.startswith("rec"): + rec_dict[k] = v + else: + global_dict[k] = v + return global_dict, det_dict, cls_dict, rec_dict + + def __call__(self, config, **kwargs): + global_dict, det_dict, cls_dict, rec_dict = self.parse_kwargs(**kwargs) + new_config = { + "Global": self.update_global_params(config["Global"], global_dict), + "Det": self.update_params(config["Det"], det_dict, "det_", None), + "Cls": self.update_params( + config["Cls"], + cls_dict, + "cls_", + ["cls_label_list", "cls_model_path", "cls_use_cuda"], + ), + "Rec": self.update_params( + config["Rec"], rec_dict, "rec_", ["rec_model_path", "rec_use_cuda"] + ), + } + + update_params = ["inference_num_threads"] + new_config = self.update_global_to_module( + config, update_params, src="Global", dsts=["Det", "Cls", "Rec"] + ) + return new_config + + def update_global_to_module( + self, config, params: List[str], src: str, dsts: List[str] + ): + for dst in dsts: + for param in params: + config[dst].update({param: config[src][param]}) + return config + + def update_global_params(self, config, global_dict): + if global_dict: + config.update(global_dict) + return config + + def update_params( + self, + config, + param_dict: Dict[str, str], + prefix: str, + need_remove_prefix: Optional[List[str]] = None, + ): + if not param_dict: + return config + + filter_dict = self.remove_prefix(param_dict, prefix, need_remove_prefix) + model_path = filter_dict.get("model_path", None) + if not model_path: + filter_dict["model_path"] = str(root_dir / config["model_path"]) + + config.update(filter_dict) + return config + + @staticmethod + def remove_prefix( + config: Dict[str, str], + prefix: str, + need_remove_prefix: Optional[List[str]] = None, + ) -> Dict[str, str]: + if not need_remove_prefix: + return config + + new_rec_dict = {} + for k, v in config.items(): + if k in need_remove_prefix: + k = k.split(prefix)[1] + new_rec_dict[k] = v + return new_rec_dict diff --git a/python/rapidocr_openvino/utils/vis_res.py b/python/rapidocr_openvino/utils/vis_res.py new file mode 100644 index 000000000..405beabad --- /dev/null +++ b/python/rapidocr_openvino/utils/vis_res.py @@ -0,0 +1,143 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import math +import random +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import cv2 +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +from .load_image import LoadImage + +root_dir = Path(__file__).resolve().parent +InputType = Union[str, np.ndarray, bytes, Path, Image.Image] + + +class VisRes: + def __init__(self, text_score: float = 0.5): + self.text_score = text_score + self.load_img = LoadImage() + + def __call__( + self, + img_content: InputType, + dt_boxes: np.ndarray, + txts: Optional[Union[List[str], Tuple[str]]] = None, + scores: Optional[Tuple[float]] = None, + font_path: Optional[str] = None, + ) -> np.ndarray: + if txts is None and scores is None: + return self.draw_dt_boxes(img_content, dt_boxes) + return self.draw_ocr_box_txt(img_content, dt_boxes, txts, scores, font_path) + + def draw_dt_boxes(self, img_content: InputType, dt_boxes: np.ndarray) -> np.ndarray: + img = self.load_img(img_content) + + for idx, box in enumerate(dt_boxes): + color = self.get_random_color() + + points = np.array(box) + cv2.polylines(img, np.int32([points]), 1, color=color, thickness=1) + + start_point = round(points[0][0]), round(points[0][1]) + cv2.putText( + img, f"{idx}", start_point, cv2.FONT_HERSHEY_SIMPLEX, 1, color, 3 + ) + return img + + def draw_ocr_box_txt( + self, + img_content: InputType, + dt_boxes: np.ndarray, + txts: Optional[Union[List[str], Tuple[str]]] = None, + scores: Optional[Tuple[float]] = None, + font_path: Optional[str] = None, + ) -> np.ndarray: + font_path = self.get_font_path(font_path) + + image = Image.fromarray(self.load_img(img_content)) + h, w = image.height, image.width + if image.mode == "L": + image = image.convert("RGB") + + img_left = image.copy() + img_right = Image.new("RGB", (w, h), (255, 255, 255)) + + random.seed(0) + draw_left = ImageDraw.Draw(img_left) + draw_right = ImageDraw.Draw(img_right) + for idx, (box, txt) in enumerate(zip(dt_boxes, txts)): + if scores is not None and float(scores[idx]) < self.text_score: + continue + + color = self.get_random_color() + + box_list = np.array(box).reshape(8).tolist() + draw_left.polygon(box_list, fill=color) + draw_right.polygon(box_list, outline=color) + + box_height = self.get_box_height(box) + box_width = self.get_box_width(box) + if box_height > 2 * box_width: + font_size = max(int(box_width * 0.9), 10) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + cur_y = box[0][1] + + for c in txt: + draw_right.text( + (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font + ) + cur_y += self.get_char_size(font, c) + else: + font_size = max(int(box_height * 0.8), 10) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) + + img_left = Image.blend(image, img_left, 0.5) + img_show = Image.new("RGB", (w * 2, h), (255, 255, 255)) + img_show.paste(img_left, (0, 0, w, h)) + img_show.paste(img_right, (w, 0, w * 2, h)) + return np.array(img_show) + + @staticmethod + def get_font_path(font_path: Optional[Union[str, Path]] = None) -> str: + if font_path is None or not Path(font_path).exists(): + raise FileNotFoundError( + f"The {font_path} does not exists! \n" + f"You could download the file in the https://drive.google.com/file/d/1evWVX38EFNwTq_n5gTFgnlv8tdaNcyIA/view?usp=sharing" + ) + return str(font_path) + + @staticmethod + def get_random_color() -> Tuple[int, int, int]: + return ( + random.randint(0, 255), + random.randint(0, 255), + random.randint(0, 255), + ) + + @staticmethod + def get_box_height(box: List[List[float]]) -> float: + return math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2) + + @staticmethod + def get_box_width(box: List[List[float]]) -> float: + return math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2) + + @staticmethod + def get_char_size(font, char_str: str) -> float: + # compatible with Pillow v9 and v10. + if hasattr(font, "getsize"): + get_size_func = getattr(font, "getsize") + return get_size_func(char_str)[1] + + if hasattr(font, "getlength"): + get_size_func = getattr(font, "getlength") + return get_size_func(char_str) + + raise ValueError( + "The Pillow ImageFont instance has not getsize or getlength func." + ) From ef9202445af461e7853a42aece05ec4312f965a5 Mon Sep 17 00:00:00 2001 From: SWHL Date: Sat, 18 May 2024 16:31:23 +0800 Subject: [PATCH 7/8] refactor(rapidocr_paddle): Decoupling utils.py files --- python/rapidocr_paddle/utils.py | 549 ------------------ python/rapidocr_paddle/utils/__init__.py | 19 + python/rapidocr_paddle/utils/infer_engine.py | 121 ++++ python/rapidocr_paddle/utils/load_image.py | 123 ++++ python/rapidocr_paddle/utils/logger.py | 19 + .../rapidocr_paddle/utils/parse_parameters.py | 187 ++++++ python/rapidocr_paddle/utils/vis_res.py | 143 +++++ 7 files changed, 612 insertions(+), 549 deletions(-) delete mode 100644 python/rapidocr_paddle/utils.py create mode 100644 python/rapidocr_paddle/utils/__init__.py create mode 100644 python/rapidocr_paddle/utils/infer_engine.py create mode 100644 python/rapidocr_paddle/utils/load_image.py create mode 100644 python/rapidocr_paddle/utils/logger.py create mode 100644 python/rapidocr_paddle/utils/parse_parameters.py create mode 100644 python/rapidocr_paddle/utils/vis_res.py diff --git a/python/rapidocr_paddle/utils.py b/python/rapidocr_paddle/utils.py deleted file mode 100644 index b8135be5b..000000000 --- a/python/rapidocr_paddle/utils.py +++ /dev/null @@ -1,549 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: SWHL -# @Contact: liekkaskono@163.com -import argparse -import math -import os -import platform -import random -import warnings -from io import BytesIO -from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union - -import cv2 -import numpy as np -import paddle -import yaml -from paddle import inference -from PIL import Image, ImageDraw, ImageFont, UnidentifiedImageError - -root_dir = Path(__file__).resolve().parent -InputType = Union[str, np.ndarray, bytes, Path, Image.Image] - - -class PaddleInferSession: - def __init__(self, config, mode: Optional[str] = None) -> None: - self.mode = mode - - model_dir = Path(config["model_path"]) - pdmodel_path = model_dir / "inference.pdmodel" - pdiparams_path = model_dir / "inference.pdiparams" - - self._verify_model(pdmodel_path) - self._verify_model(pdiparams_path) - - infer_opts = inference.Config(str(pdmodel_path), str(pdiparams_path)) - - if config["use_cuda"]: - gpu_id = self.get_infer_gpuid() - if gpu_id is None: - warnings.warn( - "GPU is not found in current device by nvidia-smi. Please check your device or ignore it if run on jetson." - ) - infer_opts.enable_use_gpu(config["gpu_mem"], config["gpu_id"]) - else: - infer_opts.disable_gpu() - - cpu_nums = os.cpu_count() - infer_num_threads = config.get("cpu_math_library_num_threads", -1) - if infer_num_threads != -1 and 1 <= infer_num_threads <= cpu_nums: - infer_opts.set_cpu_math_library_num_threads(infer_num_threads) - - # enable memory optim - infer_opts.enable_memory_optim() - infer_opts.disable_glog_info() - infer_opts.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") - infer_opts.delete_pass("matmul_transpose_reshape_fuse_pass") - infer_opts.switch_use_feed_fetch_ops(False) - infer_opts.switch_ir_optim(True) - - self.predictor = inference.create_predictor(infer_opts) - - def __call__(self, img: np.ndarray): - input_tensor = self.get_input_tensors() - output_tensors = self.get_output_tensors() - - input_tensor.copy_from_cpu(img) - self.predictor.run() - - outputs = [] - for output_tensor in output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - - self.predictor.try_shrink_memory() - return outputs - - @staticmethod - def _verify_model(model_path): - model_path = Path(model_path) - if not model_path.exists(): - raise FileNotFoundError(f"{model_path} does not exists.") - if not model_path.is_file(): - raise FileExistsError(f"{model_path} is not a file.") - - def get_input_tensors( - self, - ): - input_names = self.predictor.get_input_names() - for name in input_names: - input_tensor = self.predictor.get_input_handle(name) - return input_tensor - - def get_output_tensors( - self, - ): - output_names = self.predictor.get_output_names() - if self.mode == "rec": - output_name = "softmax_0.tmp_0" - if output_name in output_names: - return [self.predictor.get_output_handle(output_name)] - - output_tensors = [] - for output_name in output_names: - output_tensor = self.predictor.get_output_handle(output_name) - output_tensors.append(output_tensor) - return output_tensors - - @staticmethod - def get_infer_gpuid(): - sysstr = platform.system() - if sysstr == "Windows": - return 0 - - if not paddle.device.is_compiled_with_rocm: - cmd = "env | grep CUDA_VISIBLE_DEVICES" - else: - cmd = "env | grep HIP_VISIBLE_DEVICES" - env_cuda = os.popen(cmd).readlines() - - if len(env_cuda) == 0: - return 0 - - gpu_id = env_cuda[0].strip().split("=")[1] - return int(gpu_id[0]) - - -class PaddleInferError(Exception): - pass - - -class LoadImage: - def __init__( - self, - ): - pass - - def __call__(self, img: InputType) -> np.ndarray: - if not isinstance(img, InputType.__args__): - raise LoadImageError( - f"The img type {type(img)} does not in {InputType.__args__}" - ) - - origin_img_type = type(img) - img = self.load_img(img) - img = self.convert_img(img, origin_img_type) - return img - - def load_img(self, img: InputType) -> np.ndarray: - if isinstance(img, (str, Path)): - self.verify_exist(img) - try: - img = self.img_to_ndarray(Image.open(img)) - except UnidentifiedImageError as e: - raise LoadImageError(f"cannot identify image file {img}") from e - return img - - if isinstance(img, bytes): - img = self.img_to_ndarray(Image.open(BytesIO(img))) - return img - - if isinstance(img, np.ndarray): - return img - - if isinstance(img, Image.Image): - return self.img_to_ndarray(img) - - raise LoadImageError(f"{type(img)} is not supported!") - - def img_to_ndarray(self, img: Image.Image) -> np.ndarray: - if img.mode == "1": - img = img.convert("L") - return np.array(img) - return np.array(img) - - def convert_img(self, img: np.ndarray, origin_img_type): - if img.ndim == 2: - return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - if img.ndim == 3: - channel = img.shape[2] - if channel == 1: - return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - if channel == 2: - return self.cvt_two_to_three(img) - - if channel == 3: - if issubclass(origin_img_type, (str, Path, bytes, Image.Image)): - return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - return img - - if channel == 4: - return self.cvt_four_to_three(img) - - raise LoadImageError( - f"The channel({channel}) of the img is not in [1, 2, 3, 4]" - ) - - raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]") - - @staticmethod - def cvt_two_to_three(img: np.ndarray) -> np.ndarray: - """gray + alpha → BGR""" - img_gray = img[..., 0] - img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR) - - img_alpha = img[..., 1] - not_a = cv2.bitwise_not(img_alpha) - not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) - - new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha) - new_img = cv2.add(new_img, not_a) - return new_img - - @staticmethod - def cvt_four_to_three(img: np.ndarray) -> np.ndarray: - """RGBA → BGR""" - r, g, b, a = cv2.split(img) - new_img = cv2.merge((b, g, r)) - - not_a = cv2.bitwise_not(a) - not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) - - new_img = cv2.bitwise_and(new_img, new_img, mask=a) - - mean_color = np.mean(new_img) - if mean_color <= 0.0: - new_img = cv2.add(new_img, not_a) - else: - new_img = cv2.bitwise_not(new_img) - return new_img - - @staticmethod - def verify_exist(file_path: Union[str, Path]): - if not Path(file_path).exists(): - raise LoadImageError(f"{file_path} does not exist.") - - -class LoadImageError(Exception): - pass - - -def read_yaml(yaml_path): - with open(yaml_path, "rb") as f: - data = yaml.load(f, Loader=yaml.Loader) - return data - - -def update_model_path(config): - key = "model_path" - config["Det"][key] = str(root_dir / config["Det"][key]) - config["Rec"][key] = str(root_dir / config["Rec"][key]) - config["Cls"][key] = str(root_dir / config["Cls"][key]) - return config - - -def init_args(): - parser = argparse.ArgumentParser() - parser.add_argument("-img", "--img_path", type=str, default=None, required=True) - parser.add_argument("-p", "--print_cost", action="store_true", default=False) - - global_group = parser.add_argument_group(title="Global") - global_group.add_argument("--text_score", type=float, default=0.5) - - global_group.add_argument("--no_det", action="store_true", default=False) - global_group.add_argument("--no_cls", action="store_true", default=False) - global_group.add_argument("--no_rec", action="store_true", default=False) - - global_group.add_argument("--print_verbose", action="store_true", default=False) - global_group.add_argument("--min_height", type=int, default=30) - global_group.add_argument("--width_height_ratio", type=int, default=8) - - global_group.add_argument("--cpu_math_library_num_threads", type=int, default=-1) - - det_group = parser.add_argument_group(title="Det") - det_group.add_argument("--det_use_cuda", action="store_true", default=False) - det_group.add_argument("--det_gpu_id", type=int, default=0) - det_group.add_argument("--det_gpu_mem", type=int, default=500) - det_group.add_argument("--det_model_path", type=str, default=None) - det_group.add_argument("--det_limit_side_len", type=float, default=736) - det_group.add_argument( - "--det_limit_type", type=str, default="min", choices=["max", "min"] - ) - det_group.add_argument("--det_thresh", type=float, default=0.3) - det_group.add_argument("--det_box_thresh", type=float, default=0.5) - det_group.add_argument("--det_unclip_ratio", type=float, default=1.6) - det_group.add_argument( - "--det_donot_use_dilation", action="store_true", default=False - ) - det_group.add_argument( - "--det_score_mode", type=str, default="fast", choices=["slow", "fast"] - ) - - cls_group = parser.add_argument_group(title="Cls") - cls_group.add_argument("--cls_use_cuda", action="store_true", default=False) - cls_group.add_argument("--cls_gpu_id", type=int, default=0) - cls_group.add_argument("--cls_gpu_mem", type=int, default=500) - cls_group.add_argument("--cls_model_path", type=str, default=None) - cls_group.add_argument("--cls_image_shape", type=list, default=[3, 48, 192]) - cls_group.add_argument("--cls_label_list", type=list, default=["0", "180"]) - cls_group.add_argument("--cls_batch_num", type=int, default=6) - cls_group.add_argument("--cls_thresh", type=float, default=0.9) - - rec_group = parser.add_argument_group(title="Rec") - rec_group.add_argument("--rec_use_cuda", action="store_true", default=False) - rec_group.add_argument("--rec_gpu_id", type=int, default=0) - rec_group.add_argument("--rec_gpu_mem", type=int, default=500) - rec_group.add_argument("--rec_model_path", type=str, default=None) - rec_group.add_argument("--rec_keys_path", type=str, default=None) - rec_group.add_argument("--rec_img_shape", type=list, default=[3, 48, 320]) - rec_group.add_argument("--rec_batch_num", type=int, default=6) - - vis_group = parser.add_argument_group(title="Visual Result") - vis_group.add_argument("-vis", "--vis_res", action="store_true", default=False) - vis_group.add_argument( - "--vis_font_path", - type=str, - default=None, - help="When -vis is True, the font_path must have value.", - ) - vis_group.add_argument( - "--vis_save_path", - type=str, - default=".", - help="The directory of saving the vis image.", - ) - - args = parser.parse_args() - return args - - -class UpdateParameters: - def __init__(self) -> None: - pass - - def parse_kwargs(self, **kwargs): - global_dict, det_dict, cls_dict, rec_dict = {}, {}, {}, {} - for k, v in kwargs.items(): - if k.startswith("det"): - k = k.split("det_")[1] - if k == "donot_use_dilation": - k = "use_dilation" - v = not v - - det_dict[k] = v - elif k.startswith("cls"): - cls_dict[k] = v - elif k.startswith("rec"): - rec_dict[k] = v - else: - global_dict[k] = v - return global_dict, det_dict, cls_dict, rec_dict - - def __call__(self, config, **kwargs): - global_dict, det_dict, cls_dict, rec_dict = self.parse_kwargs(**kwargs) - new_config = { - "Global": self.update_global_params(config["Global"], global_dict), - "Det": self.update_params(config["Det"], det_dict, "det_", None), - "Cls": self.update_params( - config["Cls"], - cls_dict, - "cls_", - ["cls_label_list", "cls_model_path", "cls_use_cuda"], - ), - "Rec": self.update_params( - config["Rec"], rec_dict, "rec_", ["rec_model_path", "rec_use_cuda"] - ), - } - - update_params = ["cpu_math_library_num_threads"] - new_config = self.update_global_to_module( - config, update_params, src="Global", dsts=["Det", "Cls", "Rec"] - ) - return new_config - - def update_global_to_module( - self, config, params: List[str], src: str, dsts: List[str] - ): - for dst in dsts: - for param in params: - config[dst].update({param: config[src][param]}) - return config - - def update_global_params(self, config, global_dict): - if global_dict: - config.update(global_dict) - return config - - def update_params( - self, - config, - param_dict: Dict[str, str], - prefix: str, - need_remove_prefix: Optional[List[str]] = None, - ): - if not param_dict: - return config - - filter_dict = self.remove_prefix(param_dict, prefix, need_remove_prefix) - model_path = filter_dict.get("model_path", None) - if not model_path: - filter_dict["model_path"] = str(root_dir / config["model_path"]) - - config.update(filter_dict) - return config - - @staticmethod - def remove_prefix( - config: Dict[str, str], - prefix: str, - need_remove_prefix: Optional[List[str]] = None, - ) -> Dict[str, str]: - if not need_remove_prefix: - return config - - new_rec_dict = {} - for k, v in config.items(): - if k in need_remove_prefix: - k = k.split(prefix)[1] - new_rec_dict[k] = v - return new_rec_dict - - -class VisRes: - def __init__(self, text_score: float = 0.5): - self.text_score = text_score - self.load_img = LoadImage() - - def __call__( - self, - img_content: InputType, - dt_boxes: np.ndarray, - txts: Optional[Union[List[str], Tuple[str]]] = None, - scores: Optional[Tuple[float]] = None, - font_path: Optional[str] = None, - ) -> np.ndarray: - if txts is None and scores is None: - return self.draw_dt_boxes(img_content, dt_boxes) - return self.draw_ocr_box_txt(img_content, dt_boxes, txts, scores, font_path) - - def draw_dt_boxes(self, img_content: InputType, dt_boxes: np.ndarray) -> np.ndarray: - img = self.load_img(img_content) - - for idx, box in enumerate(dt_boxes): - color = self.get_random_color() - - points = np.array(box) - cv2.polylines(img, np.int32([points]), 1, color=color, thickness=1) - - start_point = round(points[0][0]), round(points[0][1]) - cv2.putText( - img, f"{idx}", start_point, cv2.FONT_HERSHEY_SIMPLEX, 1, color, 3 - ) - return img - - def draw_ocr_box_txt( - self, - img_content: InputType, - dt_boxes: np.ndarray, - txts: Optional[Union[List[str], Tuple[str]]] = None, - scores: Optional[Tuple[float]] = None, - font_path: Optional[str] = None, - ) -> np.ndarray: - font_path = self.get_font_path(font_path) - - image = Image.fromarray(self.load_img(img_content)) - h, w = image.height, image.width - if image.mode == "L": - image = image.convert("RGB") - - img_left = image.copy() - img_right = Image.new("RGB", (w, h), (255, 255, 255)) - - random.seed(0) - draw_left = ImageDraw.Draw(img_left) - draw_right = ImageDraw.Draw(img_right) - for idx, (box, txt) in enumerate(zip(dt_boxes, txts)): - if scores is not None and float(scores[idx]) < self.text_score: - continue - - color = self.get_random_color() - - box_list = np.array(box).reshape(8).tolist() - draw_left.polygon(box_list, fill=color) - draw_right.polygon(box_list, outline=color) - - box_height = self.get_box_height(box) - box_width = self.get_box_width(box) - if box_height > 2 * box_width: - font_size = max(int(box_width * 0.9), 10) - font = ImageFont.truetype(font_path, font_size, encoding="utf-8") - cur_y = box[0][1] - - for c in txt: - draw_right.text( - (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font - ) - cur_y += self.get_char_size(font, c) - else: - font_size = max(int(box_height * 0.8), 10) - font = ImageFont.truetype(font_path, font_size, encoding="utf-8") - draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) - - img_left = Image.blend(image, img_left, 0.5) - img_show = Image.new("RGB", (w * 2, h), (255, 255, 255)) - img_show.paste(img_left, (0, 0, w, h)) - img_show.paste(img_right, (w, 0, w * 2, h)) - return np.array(img_show) - - @staticmethod - def get_font_path(font_path: Optional[Union[str, Path]] = None) -> str: - if font_path is None or not Path(font_path).exists(): - raise FileNotFoundError( - f"The {font_path} does not exists! \n" - f"You could download the file in the https://drive.google.com/file/d/1evWVX38EFNwTq_n5gTFgnlv8tdaNcyIA/view?usp=sharing" - ) - return str(font_path) - - @staticmethod - def get_random_color() -> Tuple[int, int, int]: - return ( - random.randint(0, 255), - random.randint(0, 255), - random.randint(0, 255), - ) - - @staticmethod - def get_box_height(box: List[List[float]]) -> float: - return math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2) - - @staticmethod - def get_box_width(box: List[List[float]]) -> float: - return math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2) - - @staticmethod - def get_char_size(font, char_str: str) -> float: - # compatible with Pillow v9 and v10. - if hasattr(font, "getsize"): - get_size_func = getattr(font, "getsize") - return get_size_func(char_str)[1] - - if hasattr(font, "getlength"): - get_size_func = getattr(font, "getlength") - return get_size_func(char_str) - - raise ValueError( - "The Pillow ImageFont instance has not getsize or getlength func." - ) diff --git a/python/rapidocr_paddle/utils/__init__.py b/python/rapidocr_paddle/utils/__init__.py new file mode 100644 index 000000000..84031f485 --- /dev/null +++ b/python/rapidocr_paddle/utils/__init__.py @@ -0,0 +1,19 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from pathlib import Path +from typing import Dict, Union + +import yaml + +from .infer_engine import PaddleInferSession +from .load_image import LoadImage, LoadImageError +from .logger import get_logger +from .parse_parameters import UpdateParameters, init_args, update_model_path +from .vis_res import VisRes + + +def read_yaml(yaml_path: Union[str, Path]) -> Dict[str, Dict]: + with open(yaml_path, "rb") as f: + data = yaml.load(f, Loader=yaml.Loader) + return data diff --git a/python/rapidocr_paddle/utils/infer_engine.py b/python/rapidocr_paddle/utils/infer_engine.py new file mode 100644 index 000000000..0a4de3c6a --- /dev/null +++ b/python/rapidocr_paddle/utils/infer_engine.py @@ -0,0 +1,121 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import os +import platform +from pathlib import Path +from typing import Optional + +import numpy as np +import paddle +from paddle import inference + +from .logger import get_logger + + +class PaddleInferSession: + def __init__(self, config, mode: Optional[str] = None) -> None: + self.logger = get_logger("PaddleInferSession") + self.mode = mode + + model_dir = Path(config["model_path"]) + pdmodel_path = model_dir / "inference.pdmodel" + pdiparams_path = model_dir / "inference.pdiparams" + + self._verify_model(pdmodel_path) + self._verify_model(pdiparams_path) + + infer_opts = inference.Config(str(pdmodel_path), str(pdiparams_path)) + + if config["use_cuda"]: + gpu_id = self.get_infer_gpuid() + if gpu_id is None: + self.logger.warning( + "GPU is not found in current device by nvidia-smi. Please check your device or ignore it if run on jetson." + ) + infer_opts.enable_use_gpu(config["gpu_mem"], config["gpu_id"]) + else: + infer_opts.disable_gpu() + + cpu_nums = os.cpu_count() + infer_num_threads = config.get("cpu_math_library_num_threads", -1) + if infer_num_threads != -1 and 1 <= infer_num_threads <= cpu_nums: + infer_opts.set_cpu_math_library_num_threads(infer_num_threads) + + # enable memory optim + infer_opts.enable_memory_optim() + infer_opts.disable_glog_info() + infer_opts.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") + infer_opts.delete_pass("matmul_transpose_reshape_fuse_pass") + infer_opts.switch_use_feed_fetch_ops(False) + infer_opts.switch_ir_optim(True) + + self.predictor = inference.create_predictor(infer_opts) + + def __call__(self, img: np.ndarray): + input_tensor = self.get_input_tensors() + output_tensors = self.get_output_tensors() + + input_tensor.copy_from_cpu(img) + self.predictor.run() + + outputs = [] + for output_tensor in output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + + self.predictor.try_shrink_memory() + return outputs + + @staticmethod + def _verify_model(model_path): + model_path = Path(model_path) + if not model_path.exists(): + raise FileNotFoundError(f"{model_path} does not exists.") + if not model_path.is_file(): + raise FileExistsError(f"{model_path} is not a file.") + + def get_input_tensors( + self, + ): + input_names = self.predictor.get_input_names() + for name in input_names: + input_tensor = self.predictor.get_input_handle(name) + return input_tensor + + def get_output_tensors( + self, + ): + output_names = self.predictor.get_output_names() + if self.mode == "rec": + output_name = "softmax_0.tmp_0" + if output_name in output_names: + return [self.predictor.get_output_handle(output_name)] + + output_tensors = [] + for output_name in output_names: + output_tensor = self.predictor.get_output_handle(output_name) + output_tensors.append(output_tensor) + return output_tensors + + @staticmethod + def get_infer_gpuid(): + sysstr = platform.system() + if sysstr == "Windows": + return 0 + + if not paddle.device.is_compiled_with_rocm: + cmd = "env | grep CUDA_VISIBLE_DEVICES" + else: + cmd = "env | grep HIP_VISIBLE_DEVICES" + env_cuda = os.popen(cmd).readlines() + + if len(env_cuda) == 0: + return 0 + + gpu_id = env_cuda[0].strip().split("=")[1] + return int(gpu_id[0]) + + +class PaddleInferError(Exception): + pass diff --git a/python/rapidocr_paddle/utils/load_image.py b/python/rapidocr_paddle/utils/load_image.py new file mode 100644 index 000000000..056605303 --- /dev/null +++ b/python/rapidocr_paddle/utils/load_image.py @@ -0,0 +1,123 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from io import BytesIO +from pathlib import Path +from typing import Union + +import cv2 +import numpy as np +from PIL import Image, UnidentifiedImageError + +root_dir = Path(__file__).resolve().parent +InputType = Union[str, np.ndarray, bytes, Path, Image.Image] + + +class LoadImage: + def __init__(self): + pass + + def __call__(self, img: InputType) -> np.ndarray: + if not isinstance(img, InputType.__args__): + raise LoadImageError( + f"The img type {type(img)} does not in {InputType.__args__}" + ) + + origin_img_type = type(img) + img = self.load_img(img) + img = self.convert_img(img, origin_img_type) + return img + + def load_img(self, img: InputType) -> np.ndarray: + if isinstance(img, (str, Path)): + self.verify_exist(img) + try: + img = self.img_to_ndarray(Image.open(img)) + except UnidentifiedImageError as e: + raise LoadImageError(f"cannot identify image file {img}") from e + return img + + if isinstance(img, bytes): + img = self.img_to_ndarray(Image.open(BytesIO(img))) + return img + + if isinstance(img, np.ndarray): + return img + + if isinstance(img, Image.Image): + return self.img_to_ndarray(img) + + raise LoadImageError(f"{type(img)} is not supported!") + + def img_to_ndarray(self, img: Image.Image) -> np.ndarray: + if img.mode == "1": + img = img.convert("L") + return np.array(img) + return np.array(img) + + def convert_img(self, img: np.ndarray, origin_img_type): + if img.ndim == 2: + return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if img.ndim == 3: + channel = img.shape[2] + if channel == 1: + return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if channel == 2: + return self.cvt_two_to_three(img) + + if channel == 3: + if issubclass(origin_img_type, (str, Path, bytes, Image.Image)): + return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + if channel == 4: + return self.cvt_four_to_three(img) + + raise LoadImageError( + f"The channel({channel}) of the img is not in [1, 2, 3, 4]" + ) + + raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]") + + @staticmethod + def cvt_two_to_three(img: np.ndarray) -> np.ndarray: + """gray + alpha → BGR""" + img_gray = img[..., 0] + img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR) + + img_alpha = img[..., 1] + not_a = cv2.bitwise_not(img_alpha) + not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) + + new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha) + new_img = cv2.add(new_img, not_a) + return new_img + + @staticmethod + def cvt_four_to_three(img: np.ndarray) -> np.ndarray: + """RGBA → BGR""" + r, g, b, a = cv2.split(img) + new_img = cv2.merge((b, g, r)) + + not_a = cv2.bitwise_not(a) + not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) + + new_img = cv2.bitwise_and(new_img, new_img, mask=a) + + mean_color = np.mean(new_img) + if mean_color <= 0.0: + new_img = cv2.add(new_img, not_a) + else: + new_img = cv2.bitwise_not(new_img) + return new_img + + @staticmethod + def verify_exist(file_path: Union[str, Path]): + if not Path(file_path).exists(): + raise LoadImageError(f"{file_path} does not exist.") + + +class LoadImageError(Exception): + pass diff --git a/python/rapidocr_paddle/utils/logger.py b/python/rapidocr_paddle/utils/logger.py new file mode 100644 index 000000000..ffd1cd04d --- /dev/null +++ b/python/rapidocr_paddle/utils/logger.py @@ -0,0 +1,19 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import logging + + +def get_logger(name: str): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + + fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s" + format_str = logging.Formatter(fmt) + + sh = logging.StreamHandler() + sh.setLevel(logging.DEBUG) + + logger.addHandler(sh) + sh.setFormatter(format_str) + return logger diff --git a/python/rapidocr_paddle/utils/parse_parameters.py b/python/rapidocr_paddle/utils/parse_parameters.py new file mode 100644 index 000000000..dcf457960 --- /dev/null +++ b/python/rapidocr_paddle/utils/parse_parameters.py @@ -0,0 +1,187 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import argparse +from pathlib import Path +from typing import Dict, List, Optional, Union + +import numpy as np +from PIL import Image + +root_dir = Path(__file__).resolve().parent.parent +InputType = Union[str, np.ndarray, bytes, Path, Image.Image] + + +def update_model_path(config): + key = "model_path" + config["Det"][key] = str(root_dir / config["Det"][key]) + config["Rec"][key] = str(root_dir / config["Rec"][key]) + config["Cls"][key] = str(root_dir / config["Cls"][key]) + return config + + +def init_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-img", "--img_path", type=str, default=None, required=True) + parser.add_argument("-p", "--print_cost", action="store_true", default=False) + + global_group = parser.add_argument_group(title="Global") + global_group.add_argument("--text_score", type=float, default=0.5) + + global_group.add_argument("--no_det", action="store_true", default=False) + global_group.add_argument("--no_cls", action="store_true", default=False) + global_group.add_argument("--no_rec", action="store_true", default=False) + + global_group.add_argument("--print_verbose", action="store_true", default=False) + global_group.add_argument("--min_height", type=int, default=30) + global_group.add_argument("--width_height_ratio", type=int, default=8) + + global_group.add_argument("--cpu_math_library_num_threads", type=int, default=-1) + + det_group = parser.add_argument_group(title="Det") + det_group.add_argument("--det_use_cuda", action="store_true", default=False) + det_group.add_argument("--det_gpu_id", type=int, default=0) + det_group.add_argument("--det_gpu_mem", type=int, default=500) + det_group.add_argument("--det_model_path", type=str, default=None) + det_group.add_argument("--det_limit_side_len", type=float, default=736) + det_group.add_argument( + "--det_limit_type", type=str, default="min", choices=["max", "min"] + ) + det_group.add_argument("--det_thresh", type=float, default=0.3) + det_group.add_argument("--det_box_thresh", type=float, default=0.5) + det_group.add_argument("--det_unclip_ratio", type=float, default=1.6) + det_group.add_argument( + "--det_donot_use_dilation", action="store_true", default=False + ) + det_group.add_argument( + "--det_score_mode", type=str, default="fast", choices=["slow", "fast"] + ) + + cls_group = parser.add_argument_group(title="Cls") + cls_group.add_argument("--cls_use_cuda", action="store_true", default=False) + cls_group.add_argument("--cls_gpu_id", type=int, default=0) + cls_group.add_argument("--cls_gpu_mem", type=int, default=500) + cls_group.add_argument("--cls_model_path", type=str, default=None) + cls_group.add_argument("--cls_image_shape", type=list, default=[3, 48, 192]) + cls_group.add_argument("--cls_label_list", type=list, default=["0", "180"]) + cls_group.add_argument("--cls_batch_num", type=int, default=6) + cls_group.add_argument("--cls_thresh", type=float, default=0.9) + + rec_group = parser.add_argument_group(title="Rec") + rec_group.add_argument("--rec_use_cuda", action="store_true", default=False) + rec_group.add_argument("--rec_gpu_id", type=int, default=0) + rec_group.add_argument("--rec_gpu_mem", type=int, default=500) + rec_group.add_argument("--rec_model_path", type=str, default=None) + rec_group.add_argument("--rec_keys_path", type=str, default=None) + rec_group.add_argument("--rec_img_shape", type=list, default=[3, 48, 320]) + rec_group.add_argument("--rec_batch_num", type=int, default=6) + + vis_group = parser.add_argument_group(title="Visual Result") + vis_group.add_argument("-vis", "--vis_res", action="store_true", default=False) + vis_group.add_argument( + "--vis_font_path", + type=str, + default=None, + help="When -vis is True, the font_path must have value.", + ) + vis_group.add_argument( + "--vis_save_path", + type=str, + default=".", + help="The directory of saving the vis image.", + ) + + args = parser.parse_args() + return args + + +class UpdateParameters: + def __init__(self) -> None: + pass + + def parse_kwargs(self, **kwargs): + global_dict, det_dict, cls_dict, rec_dict = {}, {}, {}, {} + for k, v in kwargs.items(): + if k.startswith("det"): + k = k.split("det_")[1] + if k == "donot_use_dilation": + k = "use_dilation" + v = not v + + det_dict[k] = v + elif k.startswith("cls"): + cls_dict[k] = v + elif k.startswith("rec"): + rec_dict[k] = v + else: + global_dict[k] = v + return global_dict, det_dict, cls_dict, rec_dict + + def __call__(self, config, **kwargs): + global_dict, det_dict, cls_dict, rec_dict = self.parse_kwargs(**kwargs) + new_config = { + "Global": self.update_global_params(config["Global"], global_dict), + "Det": self.update_params(config["Det"], det_dict, "det_", None), + "Cls": self.update_params( + config["Cls"], + cls_dict, + "cls_", + ["cls_label_list", "cls_model_path", "cls_use_cuda"], + ), + "Rec": self.update_params( + config["Rec"], rec_dict, "rec_", ["rec_model_path", "rec_use_cuda"] + ), + } + + update_params = ["cpu_math_library_num_threads"] + new_config = self.update_global_to_module( + config, update_params, src="Global", dsts=["Det", "Cls", "Rec"] + ) + return new_config + + def update_global_to_module( + self, config, params: List[str], src: str, dsts: List[str] + ): + for dst in dsts: + for param in params: + config[dst].update({param: config[src][param]}) + return config + + def update_global_params(self, config, global_dict): + if global_dict: + config.update(global_dict) + return config + + def update_params( + self, + config, + param_dict: Dict[str, str], + prefix: str, + need_remove_prefix: Optional[List[str]] = None, + ): + if not param_dict: + return config + + filter_dict = self.remove_prefix(param_dict, prefix, need_remove_prefix) + model_path = filter_dict.get("model_path", None) + if not model_path: + filter_dict["model_path"] = str(root_dir / config["model_path"]) + + config.update(filter_dict) + return config + + @staticmethod + def remove_prefix( + config: Dict[str, str], + prefix: str, + need_remove_prefix: Optional[List[str]] = None, + ) -> Dict[str, str]: + if not need_remove_prefix: + return config + + new_rec_dict = {} + for k, v in config.items(): + if k in need_remove_prefix: + k = k.split(prefix)[1] + new_rec_dict[k] = v + return new_rec_dict diff --git a/python/rapidocr_paddle/utils/vis_res.py b/python/rapidocr_paddle/utils/vis_res.py new file mode 100644 index 000000000..405beabad --- /dev/null +++ b/python/rapidocr_paddle/utils/vis_res.py @@ -0,0 +1,143 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import math +import random +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import cv2 +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +from .load_image import LoadImage + +root_dir = Path(__file__).resolve().parent +InputType = Union[str, np.ndarray, bytes, Path, Image.Image] + + +class VisRes: + def __init__(self, text_score: float = 0.5): + self.text_score = text_score + self.load_img = LoadImage() + + def __call__( + self, + img_content: InputType, + dt_boxes: np.ndarray, + txts: Optional[Union[List[str], Tuple[str]]] = None, + scores: Optional[Tuple[float]] = None, + font_path: Optional[str] = None, + ) -> np.ndarray: + if txts is None and scores is None: + return self.draw_dt_boxes(img_content, dt_boxes) + return self.draw_ocr_box_txt(img_content, dt_boxes, txts, scores, font_path) + + def draw_dt_boxes(self, img_content: InputType, dt_boxes: np.ndarray) -> np.ndarray: + img = self.load_img(img_content) + + for idx, box in enumerate(dt_boxes): + color = self.get_random_color() + + points = np.array(box) + cv2.polylines(img, np.int32([points]), 1, color=color, thickness=1) + + start_point = round(points[0][0]), round(points[0][1]) + cv2.putText( + img, f"{idx}", start_point, cv2.FONT_HERSHEY_SIMPLEX, 1, color, 3 + ) + return img + + def draw_ocr_box_txt( + self, + img_content: InputType, + dt_boxes: np.ndarray, + txts: Optional[Union[List[str], Tuple[str]]] = None, + scores: Optional[Tuple[float]] = None, + font_path: Optional[str] = None, + ) -> np.ndarray: + font_path = self.get_font_path(font_path) + + image = Image.fromarray(self.load_img(img_content)) + h, w = image.height, image.width + if image.mode == "L": + image = image.convert("RGB") + + img_left = image.copy() + img_right = Image.new("RGB", (w, h), (255, 255, 255)) + + random.seed(0) + draw_left = ImageDraw.Draw(img_left) + draw_right = ImageDraw.Draw(img_right) + for idx, (box, txt) in enumerate(zip(dt_boxes, txts)): + if scores is not None and float(scores[idx]) < self.text_score: + continue + + color = self.get_random_color() + + box_list = np.array(box).reshape(8).tolist() + draw_left.polygon(box_list, fill=color) + draw_right.polygon(box_list, outline=color) + + box_height = self.get_box_height(box) + box_width = self.get_box_width(box) + if box_height > 2 * box_width: + font_size = max(int(box_width * 0.9), 10) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + cur_y = box[0][1] + + for c in txt: + draw_right.text( + (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font + ) + cur_y += self.get_char_size(font, c) + else: + font_size = max(int(box_height * 0.8), 10) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) + + img_left = Image.blend(image, img_left, 0.5) + img_show = Image.new("RGB", (w * 2, h), (255, 255, 255)) + img_show.paste(img_left, (0, 0, w, h)) + img_show.paste(img_right, (w, 0, w * 2, h)) + return np.array(img_show) + + @staticmethod + def get_font_path(font_path: Optional[Union[str, Path]] = None) -> str: + if font_path is None or not Path(font_path).exists(): + raise FileNotFoundError( + f"The {font_path} does not exists! \n" + f"You could download the file in the https://drive.google.com/file/d/1evWVX38EFNwTq_n5gTFgnlv8tdaNcyIA/view?usp=sharing" + ) + return str(font_path) + + @staticmethod + def get_random_color() -> Tuple[int, int, int]: + return ( + random.randint(0, 255), + random.randint(0, 255), + random.randint(0, 255), + ) + + @staticmethod + def get_box_height(box: List[List[float]]) -> float: + return math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2) + + @staticmethod + def get_box_width(box: List[List[float]]) -> float: + return math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2) + + @staticmethod + def get_char_size(font, char_str: str) -> float: + # compatible with Pillow v9 and v10. + if hasattr(font, "getsize"): + get_size_func = getattr(font, "getsize") + return get_size_func(char_str)[1] + + if hasattr(font, "getlength"): + get_size_func = getattr(font, "getlength") + return get_size_func(char_str) + + raise ValueError( + "The Pillow ImageFont instance has not getsize or getlength func." + ) From 23b657e8260c6aef34baff183bda231f22d85ee6 Mon Sep 17 00:00:00 2001 From: SWHL Date: Sat, 18 May 2024 16:36:26 +0800 Subject: [PATCH 8/8] ci(rapidocr_openvino): Change the trigger mode to file change trigger --- .../gen_whl_to_pypi_rapidocr_vino.yml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/gen_whl_to_pypi_rapidocr_vino.yml b/.github/workflows/gen_whl_to_pypi_rapidocr_vino.yml index cb3419d05..db6957bc6 100644 --- a/.github/workflows/gen_whl_to_pypi_rapidocr_vino.yml +++ b/.github/workflows/gen_whl_to_pypi_rapidocr_vino.yml @@ -2,14 +2,14 @@ name: Push rapidocr_openvino to pypi on: push: - # branches: [ main ] - # paths: - # - 'python/rapidocr_openvino/**' - # - 'docs/doc_whl_rapidocr_vino.md' - # - 'python/setup_openvino.py' - # - '.github/workflows/gen_whl_to_pypi_rapidocr_vino.yml' - tags: - - v* + branches: [ main ] + paths: + - 'python/rapidocr_openvino/**' + - 'docs/doc_whl_rapidocr_vino.md' + - 'python/setup_openvino.py' + - '.github/workflows/gen_whl_to_pypi_rapidocr_vino.yml' + # tags: + # - v* env: RESOURCES_URL: https://github.com/RapidAI/RapidOCR/releases/download/v1.1.0/required_for_whl_v1.3.0.zip @@ -86,7 +86,7 @@ jobs: cd .. python -m pip install --upgrade pip - python setup_openvino.py bdist_wheel ${{ github.ref_name }} + python setup_openvino.py bdist_wheel "{{ github.event.head_commit.message }}" mv dist ../ - name: Publish distribution 📦 to PyPI