Skip to content

Commit

Permalink
refactor(rapidocr_paddle): Refactor the entire code
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Jun 28, 2024
1 parent 8cae2e2 commit c01cc08
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 388 deletions.
18 changes: 8 additions & 10 deletions python/rapidocr_paddle/ch_ppocr_v2_cls/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class ClsPostProcess:
"""Convert between text-label and text-index"""
from typing import List, Tuple

import numpy as np


def __init__(self, label_list):
super(ClsPostProcess, self).__init__()
class ClsPostProcess:
def __init__(self, label_list: List[str]):
self.label_list = label_list

def __call__(self, preds, label=None):
def __call__(self, preds: np.ndarray) -> List[Tuple[str, float]]:
pred_idxs = preds.argmax(axis=1)
decode_out = [
(self.label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs)
]
if label is None:
return decode_out

label = [(self.label_list[idx], 1.0) for idx in label]
return decode_out, label
return decode_out
109 changes: 37 additions & 72 deletions python/rapidocr_paddle/ch_ppocr_v3_det/text_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,21 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
import argparse
import time
from typing import Any, Dict, Optional, Tuple

import cv2
import numpy as np

from rapidocr_paddle.utils import PaddleInferSession, read_yaml
from rapidocr_paddle.utils import PaddleInferSession

from .utils import DBPostProcess, create_operators, transform
from .utils import DBPostProcess, DetPreProcess


class TextDetector:
def __init__(self, config):
pre_process_list = {
"DetResizeForTest": {
"limit_side_len": config.get("limit_side_len", 736),
"limit_type": config.get("limit_type", "min"),
},
"NormalizeImage": {
"std": [0.229, 0.224, 0.225],
"mean": [0.485, 0.456, 0.406],
"scale": "1./255.",
"order": "hwc",
},
"ToCHWImage": None,
"KeepKeys": {"keep_keys": ["image", "shape"]},
}
self.preprocess_op = create_operators(pre_process_list)
def __init__(self, config: Dict[str, Any]):
limit_side_len = config.get("limit_side_len", 736)
limit_type = config.get("limit_type", "min")
self.preprocess_op = DetPreProcess(limit_side_len, limit_type)

post_process = {
"thresh": config.get("thresh", 0.3),
Expand All @@ -55,31 +42,41 @@ def __init__(self, config):

self.infer = PaddleInferSession(config)

def __call__(self, img):
def __call__(self, img: np.ndarray) -> Tuple[Optional[np.ndarray], float]:
start_time = time.perf_counter()

if img is None:
raise ValueError("img is None")

ori_im_shape = img.shape[:2]

data = {"image": img}
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
ori_img_shape = img.shape[0], img.shape[1]
prepro_img = self.preprocess_op(img)
if prepro_img is None:
return None, 0

img = np.expand_dims(img, axis=0).astype(np.float32)
shape_list = np.expand_dims(shape_list, axis=0)
preds = self.infer(prepro_img)[0]
dt_boxes, dt_boxes_scores = self.postprocess_op(preds, ori_img_shape)
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_img_shape)
elapse = time.perf_counter() - start_time
return dt_boxes, elapse

starttime = time.time()
preds = self.infer(img)[0]
post_result = self.postprocess_op(preds, shape_list)
def filter_tag_det_res(
self, dt_boxes: np.ndarray, image_shape: Tuple[int, int]
) -> np.ndarray:
img_height, img_width = image_shape
dt_boxes_new = []
for box in dt_boxes:
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)

dt_boxes = post_result[0]["points"]
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im_shape)
elapse = time.time() - starttime
return dt_boxes, elapse
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 3 or rect_height <= 3:
continue

dt_boxes_new.append(box)
return np.array(dt_boxes_new)

def order_points_clockwise(self, pts):
def order_points_clockwise(self, pts: np.ndarray) -> np.ndarray:
"""
reference from:
https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
Expand All @@ -104,42 +101,10 @@ def order_points_clockwise(self, pts):
rect = np.array([tl, tr, br, bl], dtype="float32")
return rect

def clip_det_res(self, points, img_height, img_width):
def clip_det_res(
    self, points: np.ndarray, img_height: int, img_width: int
) -> np.ndarray:
    """Clamp box corner coordinates to the image bounds, in place.

    Column 0 of ``points`` is limited to ``[0, img_width - 1]`` and
    column 1 to ``[0, img_height - 1]``; the fractional part of each
    clamped value is then discarded.

    Args:
        points: 2-D array indexed as ``points[row, 0]`` / ``points[row, 1]``.
            It is modified in place.
        img_height: Upper bound (exclusive) for column 1.
        img_width: Upper bound (exclusive) for column 0.

    Returns:
        The same ``points`` array object that was passed in.
    """
    # Whole-column operations replace the per-row Python loop. Values are
    # already >= 0 after clipping (for a non-empty image), so truncation
    # matches the int() conversion applied to each clamped scalar before.
    points[:, 0] = np.trunc(np.clip(points[:, 0], 0, img_width - 1))
    points[:, 1] = np.trunc(np.clip(points[:, 1], 0, img_height - 1))
    return points

def filter_tag_det_res(self, dt_boxes, image_shape):
img_height, img_width = image_shape[:2]
dt_boxes_new = []
for box in dt_boxes:
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 3 or rect_height <= 3:
continue
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, default="config.yaml")
parser.add_argument("--image_path", type=str, default=None)
args = parser.parse_args()

config = read_yaml(args.config_path)

text_detector = TextDetector(config)

img = cv2.imread(args.image_path)
dt_boxes, elapse = text_detector(img)

from utils import draw_text_det_res

src_im = draw_text_det_res(dt_boxes, args.image_path)
cv2.imwrite("det_results.jpg", src_im)
print("The det_results.jpg has been saved in the current directory.")
Loading

0 comments on commit c01cc08

Please sign in to comment.