Skip to content

Commit

Permalink
sync disable orientation predictors
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Sep 20, 2024
1 parent aa9e96d commit 3c95f79
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 42 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ model = ocr_predictor(
# Additional parameters - meta information
detect_orientation=False, # set to `True` if the orientation of the pages should be detected (default: False)
detect_language=False, # set to `True` if the language of the pages should be detected (default: False)
# Orientation specific parameters in combination with `assume_straight_pages=False` and/or `straighten_pages=True`
disable_crop_orientation=False, # set to `True` if the crop orientation classification should be disabled (default: False)
disable_page_orientation=False, # set to `True` if the general page orientation classification should be disabled (default: False)
# DocumentBuilder specific parameters
resolve_lines=True, # whether words should be automatically grouped into lines (default: True)
resolve_blocks=False, # whether lines should be automatically grouped into blocks (default: False)
Expand Down
12 changes: 8 additions & 4 deletions onnxtr/models/classification/predictor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, List, Union
from typing import Any, List, Optional, Union

import numpy as np
from scipy.special import softmax
Expand All @@ -29,10 +29,10 @@ class OrientationPredictor(NestedObject):

def __init__(
self,
pre_processor: PreProcessor,
model: Any,
pre_processor: Optional[PreProcessor],
model: Optional[Any],
) -> None:
self.pre_processor = pre_processor
self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
self.model = model

def __call__(
Expand All @@ -43,6 +43,10 @@ def __call__(
if any(input.ndim != 3 for input in inputs):
raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")

if self.model is None or self.pre_processor is None:
# predictor is disabled
return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)]

processed_batches = self.pre_processor(inputs)
predicted_batches = [self.model(batch) for batch in processed_batches]

Expand Down
11 changes: 10 additions & 1 deletion onnxtr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,17 @@


def _orientation_predictor(
arch: Any, model_type: str, load_in_8_bit: bool = False, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any
arch: Any,
model_type: str,
load_in_8_bit: bool = False,
engine_cfg: Optional[EngineConfig] = None,
disabled: bool = False,
**kwargs: Any,
) -> OrientationPredictor:
if disabled:
# Case where the orientation predictor is disabled
return OrientationPredictor(None, None)

if isinstance(arch, str):
if arch not in ORIENTATION_ARCHS:
raise ValueError(f"unknown architecture '{arch}'")
Expand Down
30 changes: 21 additions & 9 deletions onnxtr/models/predictor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,19 @@ def __init__(
) -> None:
self.assume_straight_pages = assume_straight_pages
self.straighten_pages = straighten_pages
self._page_orientation_disabled = kwargs.pop("disable_page_orientation", False)
self._crop_orientation_disabled = kwargs.pop("disable_crop_orientation", False)
self.crop_orientation_predictor = (
None
if assume_straight_pages
else crop_orientation_predictor(load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg)
else crop_orientation_predictor(
load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg, disabled=self._crop_orientation_disabled
)
)
self.page_orientation_predictor = (
page_orientation_predictor(load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg)
page_orientation_predictor(
load_in_8_bit=load_in_8_bit, engine_cfg=clf_engine_cfg, disabled=self._crop_orientation_disabled
)
if detect_orientation or straighten_pages or not assume_straight_pages
else None
)
Expand Down Expand Up @@ -123,13 +129,18 @@ def _generate_crops(
loc_preds: List[np.ndarray],
channels_last: bool,
assume_straight_pages: bool = False,
assume_horizontal: bool = False,
) -> List[List[np.ndarray]]:
extraction_fn = extract_crops if assume_straight_pages else extract_rcrops

crops = [
extraction_fn(page, _boxes[:, :4], channels_last=channels_last) # type: ignore[operator]
for page, _boxes in zip(pages, loc_preds)
]
if assume_straight_pages:
crops = [
extract_crops(page, _boxes[:, :4], channels_last=channels_last)
for page, _boxes in zip(pages, loc_preds)
]
else:
crops = [
extract_rcrops(page, _boxes[:, :4], channels_last=channels_last, assume_horizontal=assume_horizontal)
for page, _boxes in zip(pages, loc_preds)
]
return crops

@staticmethod
Expand All @@ -138,8 +149,9 @@ def _prepare_crops(
loc_preds: List[np.ndarray],
channels_last: bool,
assume_straight_pages: bool = False,
assume_horizontal: bool = False,
) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages)
crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal)

# Avoid sending zero-sized crops
is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
Expand Down
2 changes: 2 additions & 0 deletions onnxtr/models/predictor/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __call__(
]
if self.detect_orientation:
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
print(general_pages_orientations, origin_pages_orientations)
orientations = [
{"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
]
Expand Down Expand Up @@ -119,6 +120,7 @@ def __call__(
loc_preds, # type: ignore[arg-type]
channels_last=True,
assume_straight_pages=self.assume_straight_pages,
assume_horizontal=self._page_orientation_disabled,
)
# Rectify crop orientation and get crop orientation predictions
crop_orientations: Any = []
Expand Down
95 changes: 76 additions & 19 deletions onnxtr/utils/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True


def extract_rcrops(
img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True
img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True, assume_horizontal: bool = False
) -> List[np.ndarray]:
"""Created cropped images from list of rotated bounding boxes
Expand All @@ -481,6 +481,7 @@ def extract_rcrops(
polys: bounding boxes of shape (N, 4, 2)
dtype: target data type of bounding boxes
channels_last: whether the channel dimensions is the last one instead of the last one
assume_horizontal: whether the boxes are assumed to be only horizontally oriented
Returns:
-------
Expand All @@ -498,22 +499,78 @@ def extract_rcrops(
_boxes[:, :, 0] *= width
_boxes[:, :, 1] *= height

src_pts = _boxes[:, :3].astype(np.float32)
# Preserve size
d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1)
d2 = np.linalg.norm(src_pts[:, 1] - src_pts[:, 2], axis=-1)
# (N, 3, 2)
dst_pts = np.zeros((_boxes.shape[0], 3, 2), dtype=dtype)
dst_pts[:, 1, 0] = dst_pts[:, 2, 0] = d1 - 1
dst_pts[:, 2, 1] = d2 - 1
# Use a warp transformation to extract the crop
crops = [
cv2.warpAffine(
img if channels_last else img.transpose(1, 2, 0),
# Transformation matrix
cv2.getAffineTransform(src_pts[idx], dst_pts[idx]),
(int(d1[idx]), int(d2[idx])),
)
for idx in range(_boxes.shape[0])
]
src_img = img if channels_last else img.transpose(1, 2, 0)

# Handle only horizontal oriented boxes
if assume_horizontal:
crops = []

for box in _boxes:
# Sort the points according to the x-axis
box_points = box[np.argsort(box[:, 0])]

# Divide the points into left and right
left_points = box_points[:2]
right_points = box_points[2:]

# Sort the left points according to the y-axis
left_points = left_points[np.argsort(left_points[:, 1])]
# Sort the right points according to the y-axis
right_points = right_points[np.argsort(right_points[:, 1])]
box_points = np.concatenate([left_points, right_points])

# Get the width and height of the rectangle that will contain the warped quadrilateral
# Designate the width and height based on maximum side of the quadrilateral
width_upper = np.linalg.norm(box_points[0] - box_points[2])
width_lower = np.linalg.norm(box_points[1] - box_points[3])
height_left = np.linalg.norm(box_points[0] - box_points[1])
height_right = np.linalg.norm(box_points[2] - box_points[3])

# Get the maximum width and height
rect_width = max(int(width_upper), int(width_lower))
rect_height = max(int(height_left), int(height_right))

dst_pts = np.array(
[
[0, 0], # top-left
# bottom-left
[0, rect_height - 1],
# top-right
[rect_width - 1, 0],
# bottom-right
[rect_width - 1, rect_height - 1],
],
dtype=dtype,
)

# Get the perspective transform matrix using the box points
affine_mat = cv2.getPerspectiveTransform(box_points.astype(np.float32), dst_pts)

# Perform the perspective warp to get the rectified crop
crop = cv2.warpPerspective(src_img, affine_mat, (rect_width, rect_height))

# Add the crop to the list of crops
crops.append(crop)

# Handle any oriented boxes
else:
src_pts = _boxes[:, :3].astype(np.float32)
# Preserve size
d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1)
d2 = np.linalg.norm(src_pts[:, 1] - src_pts[:, 2], axis=-1)
# (N, 3, 2)
dst_pts = np.zeros((_boxes.shape[0], 3, 2), dtype=dtype)
dst_pts[:, 1, 0] = dst_pts[:, 2, 0] = d1 - 1
dst_pts[:, 2, 1] = d2 - 1
# Use a warp transformation to extract the crop
crops = [
cv2.warpAffine(
src_img,
# Transformation matrix
cv2.getAffineTransform(src_pts[idx], dst_pts[idx]),
(int(d1[idx]), int(d2[idx])),
)
for idx in range(_boxes.shape[0])
]

return crops # type: ignore[return-value]
16 changes: 16 additions & 0 deletions tests/common/test_models_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ def test_crop_orientation_model(mock_text_box, quantized):
with pytest.raises(ValueError):
_ = classification.crop_orientation_predictor(detection.db_resnet34())

# Test with disabled predictor
classifier = classification.crop_orientation_predictor("mobilenet_v3_small_crop_orientation", disabled=True)
assert classifier([text_box_0, text_box_270, text_box_180, text_box_90]) == [
[0, 0, 0, 0],
[0, 0, 0, 0],
[1.0, 1.0, 1.0, 1.0],
]


@pytest.mark.parametrize("quantized", [False, True])
def test_page_orientation_model(mock_payslip, quantized):
Expand All @@ -109,3 +117,11 @@ def test_page_orientation_model(mock_payslip, quantized):

with pytest.raises(ValueError):
_ = classification.page_orientation_predictor(detection.db_resnet34())

# Test with disabled predictor
classifier = classification.crop_orientation_predictor("mobilenet_v3_small_page_orientation", disabled=True)
assert classifier([text_box_0, text_box_270, text_box_180, text_box_90]) == [
[0, 0, 0, 0],
[0, 0, 0, 0],
[1.0, 1.0, 1.0, 1.0],
]
23 changes: 18 additions & 5 deletions tests/common/test_models_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,18 @@ def __call__(self, loc_preds):


@pytest.mark.parametrize(
"assume_straight_pages, straighten_pages",
"assume_straight_pages, straighten_pages, disable_page_orientation, disable_crop_orientation",
[
[True, False],
[False, False],
[True, True],
[True, False, False, False],
[False, False, True, True],
[True, True, False, False],
[False, True, True, True],
[True, False, True, False],
],
)
def test_ocrpredictor(mock_pdf, assume_straight_pages, straighten_pages):
def test_ocrpredictor(
mock_pdf, assume_straight_pages, straighten_pages, disable_page_orientation, disable_crop_orientation
):
det_bsize = 4
det_predictor = DetectionPredictor(
PreProcessor(output_size=(1024, 1024), batch_size=det_bsize),
Expand All @@ -56,6 +60,15 @@ def test_ocrpredictor(mock_pdf, assume_straight_pages, straighten_pages):
detect_language=True,
resolve_lines=True,
resolve_blocks=True,
disable_page_orientation=disable_page_orientation,
disable_crop_orientation=disable_crop_orientation,
)

assert (
predictor._page_orientation_disabled if disable_page_orientation else not predictor._page_orientation_disabled
)
assert (
predictor._crop_orientation_disabled if disable_crop_orientation else not predictor._crop_orientation_disabled
)

if assume_straight_pages:
Expand Down
9 changes: 5 additions & 4 deletions tests/common/test_utils_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ def test_extract_crops(mock_pdf):
assert geometry.extract_crops(doc_img, np.zeros((0, 4))) == []


def test_extract_rcrops(mock_pdf):
@pytest.mark.parametrize("assume_horizontal", [True, False])
def test_extract_rcrops(mock_pdf, assume_horizontal):
doc_img = DocumentFile.from_pdf(mock_pdf)[0]
num_crops = 2
rel_boxes = np.array(
Expand All @@ -280,17 +281,17 @@ def test_extract_rcrops(mock_pdf):
abs_boxes = abs_boxes.astype(np.int64)

with pytest.raises(AssertionError):
geometry.extract_rcrops(doc_img, np.zeros((1, 8)))
geometry.extract_rcrops(doc_img, np.zeros((1, 8)), assume_horizontal=assume_horizontal)
for boxes in (rel_boxes, abs_boxes):
croped_imgs = geometry.extract_rcrops(doc_img, boxes)
croped_imgs = geometry.extract_rcrops(doc_img, boxes, assume_horizontal=assume_horizontal)
# Number of crops
assert len(croped_imgs) == num_crops
# Data type and shape
assert all(isinstance(crop, np.ndarray) for crop in croped_imgs)
assert all(crop.ndim == 3 for crop in croped_imgs)

# No box
assert geometry.extract_rcrops(doc_img, np.zeros((0, 4, 2))) == []
assert geometry.extract_rcrops(doc_img, np.zeros((0, 4, 2)), assume_horizontal=assume_horizontal) == []


@pytest.mark.parametrize(
Expand Down

0 comments on commit 3c95f79

Please sign in to comment.