Skip to content

Commit

Permalink
refactor(rapidocr_openvino): Refactor the entire code
Browse files (browse the repository at this point in the history)
  • Loading branch information
SWHL committed Jun 28, 2024
1 parent 1de4db0 commit 8cae2e2
Show file tree
Hide file tree
Showing 10 changed files with 276 additions and 392 deletions.
13 changes: 7 additions & 6 deletions python/rapidocr_openvino/ch_ppocr_v2_cls/text_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import copy
import math
import time
from typing import List
from typing import Any, Dict, List, Tuple, Union

import cv2
import numpy as np
Expand All @@ -26,15 +26,17 @@


class TextClassifier:
def __init__(self, config):
def __init__(self, config: Dict[str, Any]):
self.cls_image_shape = config["cls_image_shape"]
self.cls_batch_num = config["cls_batch_num"]
self.cls_thresh = config["cls_thresh"]
self.postprocess_op = ClsPostProcess(config["label_list"])

self.infer = OpenVINOInferSession(config)

def __call__(self, img_list: List[np.ndarray]):
def __call__(
self, img_list: Union[np.ndarray, List[np.ndarray]]
) -> Tuple[List[np.ndarray], List[List[Union[str, float]]], float]:
if isinstance(img_list, np.ndarray):
img_list = [img_list]

Expand Down Expand Up @@ -65,16 +67,15 @@ def __call__(self, img_list: List[np.ndarray]):
cls_result = self.postprocess_op(prob_out)
elapse += time.time() - starttime

for rno in range(len(cls_result)):
label, score = cls_result[rno]
for rno, (label, score) in enumerate(cls_result):
cls_res[indices[beg_img_no + rno]] = [label, score]
if "180" in label and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1
)
return img_list, cls_res, elapse

def resize_norm_img(self, img):
def resize_norm_img(self, img: np.ndarray) -> np.ndarray:
img_c, img_h, img_w = self.cls_image_shape
h, w = img.shape[:2]
ratio = w / float(h)
Expand Down
18 changes: 8 additions & 10 deletions python/rapidocr_openvino/ch_ppocr_v2_cls/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class ClsPostProcess:
"""Convert between text-label and text-index"""
from typing import List, Tuple

import numpy as np


def __init__(self, label_list):
super(ClsPostProcess, self).__init__()
class ClsPostProcess:
def __init__(self, label_list: List[str]):
self.label_list = label_list

def __call__(self, preds, label=None):
def __call__(self, preds: np.ndarray) -> List[Tuple[str, float]]:
pred_idxs = preds.argmax(axis=1)
decode_out = [
(self.label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs)
]
if label is None:
return decode_out

label = [(self.label_list[idx], 1.0) for idx in label]
return decode_out, label
return decode_out
117 changes: 45 additions & 72 deletions python/rapidocr_openvino/ch_ppocr_v3_det/text_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- encoding: utf-8 -*-
import argparse
# @Author: SWHL
# @Contact: [email protected]
import time
from typing import Any, Dict, Optional, Tuple

import cv2
import numpy as np

from rapidocr_openvino.utils import OpenVINOInferSession, read_yaml
from rapidocr_openvino.utils import OpenVINOInferSession

from .utils import DBPostProcess, create_operators, transform
from .utils import DBPostProcess, DetPreProcess


class TextDetector:
def __init__(self, config):
pre_process_list = {
"DetResizeForTest": {
"limit_side_len": config.get("limit_side_len", 736),
"limit_type": config.get("limit_type", "min"),
},
"NormalizeImage": {
"std": [0.229, 0.224, 0.225],
"mean": [0.485, 0.456, 0.406],
"scale": "1./255.",
"order": "hwc",
},
"ToCHWImage": None,
"KeepKeys": {"keep_keys": ["image", "shape"]},
}
self.preprocess_op = create_operators(pre_process_list)
def __init__(self, config: Dict[str, Any]):
limit_side_len = config.get("limit_side_len", 736)
limit_type = config.get("limit_type", "min")
self.preprocess_op = DetPreProcess(limit_side_len, limit_type)

post_process = {
"thresh": config.get("thresh", 0.3),
Expand All @@ -53,29 +42,45 @@ def __init__(self, config):

self.infer = OpenVINOInferSession(config)

def __call__(self, img):
ori_im = img.copy()
data = {"image": img}
data = transform(data, self.preprocess_op)
img, shape_list = data
def __call__(self, img: np.ndarray) -> Tuple[Optional[np.ndarray], float]:
start_time = time.perf_counter()

if img is None:
return None, 0
raise ValueError("img is None")

img = np.expand_dims(img, axis=0).astype(np.float32)
shape_list = np.expand_dims(shape_list, axis=0)
ori_img_shape = img.shape[0], img.shape[1]
prepro_img = self.preprocess_op(img)
if prepro_img is None:
return None, 0

starttime = time.time()
preds = self.infer(img)
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]["points"]
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
elapse = time.time() - starttime
preds = self.infer(prepro_img)
dt_boxes, dt_boxes_scores = self.postprocess_op(preds, ori_img_shape)
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_img_shape)
elapse = time.perf_counter() - start_time
return dt_boxes, elapse

def order_points_clockwise(self, pts):
def filter_tag_det_res(
self, dt_boxes: np.ndarray, image_shape: Tuple[int, int]
) -> np.ndarray:
img_height, img_width = image_shape
dt_boxes_new = []
for box in dt_boxes:
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)

rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 3 or rect_height <= 3:
continue

dt_boxes_new.append(box)
return np.array(dt_boxes_new)

def order_points_clockwise(self, pts: np.ndarray) -> np.ndarray:
"""
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
# sort the points based on their x-coordinates
reference from:
https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
sort the points based on their x-coordinates
"""
xSorted = pts[np.argsort(pts[:, 0]), :]

Expand All @@ -96,42 +101,10 @@ def order_points_clockwise(self, pts):
rect = np.array([tl, tr, br, bl], dtype="float32")
return rect

def clip_det_res(self, points, img_height, img_width):
def clip_det_res(
self, points: np.ndarray, img_height: int, img_width: int
) -> np.ndarray:
for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points

def filter_tag_det_res(self, dt_boxes, image_shape):
img_height, img_width = image_shape[:2]
dt_boxes_new = []
for box in dt_boxes:
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 3 or rect_height <= 3:
continue
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, default="config.yaml")
parser.add_argument("--image_path", type=str, default=None)
args = parser.parse_args()

config = read_yaml(args.config_path)

text_detector = TextDetector(config)

img = cv2.imread(args.image_path)
dt_boxes, elapse = text_detector(img)

from utils import draw_text_det_res

src_im = draw_text_det_res(dt_boxes, args.image_path)
cv2.imwrite("det_results.jpg", src_im)
print("The det_results.jpg has been saved in the current directory.")
Loading

0 comments on commit 8cae2e2

Please sign in to comment.