From ff8cce473e59cf86a2eb62ef72afae9762c3a4ca Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Tue, 19 Nov 2024 10:51:37 +0100 Subject: [PATCH] [docs] Fix docstrings (#49) --- .pre-commit-config.yaml | 2 +- onnxtr/contrib/__init__.py | 1 + onnxtr/contrib/artefacts.py | 2 -- onnxtr/contrib/base.py | 9 ----- onnxtr/file_utils.py | 1 - onnxtr/io/elements.py | 16 +-------- onnxtr/io/html.py | 2 -- onnxtr/io/image.py | 2 -- onnxtr/io/pdf.py | 2 -- onnxtr/io/reader.py | 6 ---- onnxtr/models/_utils.py | 6 ---- onnxtr/models/builder.py | 13 ------- .../models/classification/models/mobilenet.py | 5 --- .../models/classification/predictor/base.py | 1 - onnxtr/models/classification/zoo.py | 4 --- onnxtr/models/detection/_utils/base.py | 3 -- onnxtr/models/detection/core.py | 7 +--- .../models/differentiable_binarization.py | 7 ---- onnxtr/models/detection/models/fast.py | 7 ---- onnxtr/models/detection/models/linknet.py | 7 ---- onnxtr/models/detection/postprocessor/base.py | 7 +--- onnxtr/models/detection/predictor/base.py | 1 - onnxtr/models/detection/zoo.py | 2 -- onnxtr/models/engine.py | 2 -- onnxtr/models/factory/hub.py | 6 +--- onnxtr/models/predictor/base.py | 2 -- onnxtr/models/predictor/predictor.py | 1 - onnxtr/models/preprocessor/base.py | 5 --- onnxtr/models/recognition/core.py | 1 - onnxtr/models/recognition/models/crnn.py | 12 ------- onnxtr/models/recognition/models/master.py | 6 ---- onnxtr/models/recognition/models/parseq.py | 4 --- onnxtr/models/recognition/models/sar.py | 4 --- onnxtr/models/recognition/models/vitstr.py | 6 ---- onnxtr/models/recognition/predictor/_utils.py | 2 -- onnxtr/models/recognition/predictor/base.py | 1 - onnxtr/models/recognition/utils.py | 4 --- onnxtr/models/recognition/zoo.py | 2 -- onnxtr/models/zoo.py | 2 -- onnxtr/transforms/base.py | 16 +++++++-- onnxtr/utils/data.py | 5 +-- onnxtr/utils/fonts.py | 2 -- onnxtr/utils/geometry.py | 34 ++----------------- onnxtr/utils/multithreading.py | 3 -- onnxtr/utils/reconstitution.py | 2 -- onnxtr/utils/visualization.py | 9 ----- pyproject.toml | 2 +- tests/common/test_transforms.py | 2 +- 48 files changed, 25 insertions(+), 223 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c6bdde4..5b283e9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: - id: no-commit-to-branch args: ['--branch', 'main'] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.7.4 hooks: - id: ruff args: [ --fix ] diff --git a/onnxtr/contrib/__init__.py b/onnxtr/contrib/__init__.py index e69de29..f721da3 100644 --- a/onnxtr/contrib/__init__.py +++ b/onnxtr/contrib/__init__.py @@ -0,0 +1 @@ +from .artefacts import ArtefactDetector \ No newline at end of file diff --git a/onnxtr/contrib/artefacts.py b/onnxtr/contrib/artefacts.py index dcde0ac..e70e377 100644 --- a/onnxtr/contrib/artefacts.py +++ b/onnxtr/contrib/artefacts.py @@ -34,7 +34,6 @@ class ArtefactDetector(_BasePredictor): >>> results = detector(doc) Args: - ---- arch: the architecture to use batch_size: the batch size to use model_path: the path to the model to use @@ -109,7 +108,6 @@ def show(self, **kwargs: Any) -> None: Display the results Args: - ---- **kwargs: additional keyword arguments to be passed to `plt.show` """ requires_package("matplotlib", "`.show()` requires matplotlib installed") diff --git a/onnxtr/contrib/base.py b/onnxtr/contrib/base.py index 8990bad..4f12b09 100644 --- a/onnxtr/contrib/base.py +++ b/onnxtr/contrib/base.py @@ -16,7 +16,6 @@ class _BasePredictor: Base class for all predictors Args: - ---- batch_size: the batch size to use url: the url to use to download a model if needed model_path: the path to the model to use @@ -35,13 +34,11 @@ def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = Non Download the model from the given url if needed Args: - ---- url: the url to use model_path: the path to the model to use **kwargs: additional arguments to be passed to `download_from_url` Returns: - ------- Any: the ONNX loaded model """ if not url and not model_path: @@ -54,11 +51,9 @@ def preprocess(self, img: np.ndarray) -> np.ndarray: Preprocess the input image Args: - ---- img: the input image to preprocess Returns: - ------- np.ndarray: the preprocessed image """ raise NotImplementedError @@ -68,12 +63,10 @@ def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarr Postprocess the model output Args: - ---- output: the model output to postprocess input_images: the input images used to generate the output Returns: - ------- Any: the postprocessed output """ raise NotImplementedError @@ -83,11 +76,9 @@ def __call__(self, inputs: List[np.ndarray]) -> Any: Call the model on the given inputs Args: - ---- inputs: the inputs to use Returns: - ------- Any: the postprocessed output """ self._inputs = inputs diff --git a/onnxtr/file_utils.py b/onnxtr/file_utils.py index 3905a6f..cb28bfd 100644 --- a/onnxtr/file_utils.py +++ b/onnxtr/file_utils.py @@ -19,7 +19,6 @@ def requires_package(name: str, extra_message: Optional[str] = None) -> None: # package requirement helper Args: - ---- name: name of the package extra_message: additional message to display if the package is not found """ diff --git a/onnxtr/io/elements.py b/onnxtr/io/elements.py index fe4f78a..6ff19e1 100644 --- a/onnxtr/io/elements.py +++ b/onnxtr/io/elements.py @@ -62,7 +62,6 @@ class Word(Element): """Implements a word element Args: - ---- value: the text string of the word confidence: the confidence associated with the text prediction geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to @@ -106,7 +105,6 @@ class Artefact(Element): """Implements a non-textual element Args: - ---- artefact_type: the type of artefact confidence: the confidence of the type prediction geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to @@ -139,7 +137,6 @@ class Line(Element): """Implements a line element as a collection of words Args: - ---- words: list of word elements geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to the page's size. If not specified, it will be resolved by default to the smallest bounding box enclosing @@ -186,7 +183,6 @@ class Block(Element): """Implements a block element as a collection of lines and artefacts Args: - ---- lines: list of line elements artefacts: list of artefacts geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to @@ -240,7 +236,6 @@ class Page(Element): """Implements a page element as a collection of blocks Args: - ---- page: image encoded as a numpy array in uint8 blocks: list of block elements page_idx: the index of the page in the input raw document @@ -295,11 +290,9 @@ def synthesize(self, **kwargs) -> np.ndarray: """Synthesize the page from the predictions Args: - ---- **kwargs: keyword arguments passed to the `synthesize_page` method Returns - ------- synthesized page """ return synthesize_page(self.export(), **kwargs) @@ -309,11 +302,9 @@ def export_as_xml(self, file_title: str = "OnnxTR - XML export (hOCR)") -> Tuple convention: https://github.com/kba/hocr-spec/blob/master/1.2/spec.md Args: - ---- file_title: the title of the XML file Returns: - ------- a tuple of the XML byte string, and its ElementTree """ p_idx = self.page_idx @@ -421,7 +412,6 @@ class Document(Element): """Implements a document element as a collection of pages Args: - ---- pages: list of page elements """ @@ -447,11 +437,9 @@ def synthesize(self, **kwargs) -> List[np.ndarray]: """Synthesize all pages from their predictions Args: - ---- **kwargs: keyword arguments passed to the `Page.synthesize` method - Returns - ------- + Returns: list of synthesized pages """ return [page.synthesize(**kwargs) for page in self.pages] @@ -460,11 +448,9 @@ def export_as_xml(self, **kwargs) -> List[Tuple[bytes, ET.ElementTree]]: """Export the document as XML (hOCR-format) Args: - ---- **kwargs: additional keyword arguments passed to the Page.export_as_xml method Returns: - ------- list of tuple of (bytes, ElementTree) """ return [page.export_as_xml(**kwargs) for page in self.pages] diff --git a/onnxtr/io/html.py b/onnxtr/io/html.py index e3d9269..4715d7c 100644 --- a/onnxtr/io/html.py +++ b/onnxtr/io/html.py @@ -15,12 +15,10 @@ def read_html(url: str, **kwargs: Any) -> bytes: >>> doc = read_html("https://www.yoursite.com") Args: - ---- url: URL of the target web page **kwargs: keyword arguments from `weasyprint.HTML` Returns: - ------- decoded PDF file as a bytes stream """ from weasyprint import HTML diff --git a/onnxtr/io/image.py b/onnxtr/io/image.py index 8c44e91..0c221c1 100644 --- a/onnxtr/io/image.py +++ b/onnxtr/io/image.py @@ -25,13 +25,11 @@ def read_img_as_numpy( >>> page = read_img_as_numpy("path/to/your/doc.jpg") Args: - ---- file: the path to the image file output_size: the expected output size of each page in format H x W rgb_output: whether the output ndarray channel order should be RGB instead of BGR. Returns: - ------- the page decoded as numpy ndarray of shape H x W x 3 """ if isinstance(file, (str, Path)): diff --git a/onnxtr/io/pdf.py b/onnxtr/io/pdf.py index d5e7b47..c2305ba 100644 --- a/onnxtr/io/pdf.py +++ b/onnxtr/io/pdf.py @@ -26,7 +26,6 @@ def read_pdf( >>> doc = read_pdf("path/to/your/doc.pdf") Args: - ---- file: the path to the PDF file scale: rendering scale (1 corresponds to 72dpi) rgb_mode: if True, the output will be RGB, otherwise BGR @@ -34,7 +33,6 @@ def read_pdf( **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render` Returns: - ------- the list of pages decoded as numpy ndarray of shape H x W x C """ # Rasterise pages to numpy ndarrays with pypdfium2 diff --git a/onnxtr/io/reader.py b/onnxtr/io/reader.py index 68290e1..85c0d84 100644 --- a/onnxtr/io/reader.py +++ b/onnxtr/io/reader.py @@ -29,12 +29,10 @@ def from_pdf(cls, file: AbstractFile, **kwargs) -> List[np.ndarray]: >>> doc = DocumentFile.from_pdf("path/to/your/doc.pdf") Args: - ---- file: the path to the PDF file or a binary stream **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render` Returns: - ------- the list of pages decoded as numpy ndarray of shape H x W x 3 """ return read_pdf(file, **kwargs) @@ -47,12 +45,10 @@ def from_url(cls, url: str, **kwargs) -> List[np.ndarray]: >>> doc = DocumentFile.from_url("https://www.yoursite.com") Args: - ---- url: the URL of the target web page **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render` Returns: - ------- the list of pages decoded as numpy ndarray of shape H x W x 3 """ requires_package( @@ -71,12 +67,10 @@ def from_images(cls, files: Union[Sequence[AbstractFile], AbstractFile], **kwarg >>> pages = DocumentFile.from_images(["path/to/your/page1.png", "path/to/your/page2.png"]) Args: - ---- files: the path to the image file or a binary stream, or a collection of those **kwargs: additional parameters to :meth:`onnxtr.io.image.read_img_as_numpy` Returns: - ------- the list of pages decoded as numpy ndarray of shape H x W x 3 """ if isinstance(files, (str, Path, bytes)): diff --git a/onnxtr/models/_utils.py b/onnxtr/models/_utils.py index f77e2ba..4be7d9d 100644 --- a/onnxtr/models/_utils.py +++ b/onnxtr/models/_utils.py @@ -20,11 +20,9 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float: """Get the maximum shape ratio of a contour. Args: - ---- contour: the contour from cv2.findContour Returns: - ------- the maximum shape ratio """ _, (w, h), _ = cv2.minAreaRect(contour) @@ -43,7 +41,6 @@ def estimate_orientation( lines of the document and the assumption that they should be horizontal. Args: - ---- img: the img or bitmap to analyze (H, W, C) general_page_orientation: the general orientation of the page (angle [0, 90, 180, 270 (-90)], confidence) estimated by a model @@ -53,7 +50,6 @@ def estimate_orientation( lower_area: the minimum area of a contour to be considered Returns: - ------- the estimated angle of the page (clockwise, negative for left side rotation, positive for right side rotation) """ assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported" @@ -162,11 +158,9 @@ def get_language(text: str) -> Tuple[str, float]: Get the language with the highest probability or no language if only a few words or a low probability Args: - ---- text (str): text Returns: - ------- The detected language in ISO 639 code and confidence score """ try: diff --git a/onnxtr/models/builder.py b/onnxtr/models/builder.py index b2c2451..6c63ccd 100644 --- a/onnxtr/models/builder.py +++ b/onnxtr/models/builder.py @@ -20,7 +20,6 @@ class DocumentBuilder(NestedObject): """Implements a document builder Args: - ---- resolve_lines: whether words should be automatically grouped into lines resolve_blocks: whether lines should be automatically grouped into blocks paragraph_break: relative length of the minimum space separating paragraphs @@ -45,11 +44,9 @@ def _sort_boxes(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """Sort bounding boxes from top to bottom, left to right Args: - ---- boxes: bounding boxes of shape (N, 4) or (N, 4, 2) (in case of rotated bbox) Returns: - ------- tuple: indices of ordered boxes of shape (N,), boxes If straight boxes are passed tpo the function, boxes are unchanged else: boxes returned are straight boxes fitted to the straightened rotated boxes @@ -69,12 +66,10 @@ def _resolve_sub_lines(self, boxes: np.ndarray, word_idcs: List[int]) -> List[Li """Split a line in sub_lines Args: - ---- boxes: bounding boxes of shape (N, 4) word_idcs: list of indexes for the words of the line Returns: - ------- A list of (sub-)lines computed from the original line (words) """ lines = [] @@ -109,11 +104,9 @@ def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]: """Order boxes to group them in lines Args: - ---- boxes: bounding boxes of shape (N, 4) or (N, 4, 2) in case of rotated bbox Returns: - ------- nested list of box indices """ # Sort boxes, and straighten the boxes if they are rotated @@ -157,12 +150,10 @@ def _resolve_blocks(boxes: np.ndarray, lines: List[List[int]]) -> List[List[List """Order lines to group them in blocks Args: - ---- boxes: bounding boxes of shape (N, 4) or (N, 4, 2) lines: list of lines, each line is a list of idx Returns: - ------- nested list of box indices """ # Resolve enclosing boxes of lines @@ -230,7 +221,6 @@ def _build_blocks( """Gather independent words in structured blocks Args: - ---- boxes: bounding boxes of all detected words of the page, of shape (N, 4) or (N, 4, 2) objectness_scores: objectness scores of all detected words of the page, of shape N word_preds: list of all detected words of the page, of shape N @@ -238,7 +228,6 @@ def _build_blocks( the general orientation (orientations + confidences) of the crops Returns: - ------- list of block elements """ if boxes.shape[0] != len(word_preds): @@ -307,7 +296,6 @@ def __call__( """Re-arrange detected words into structured blocks Args: - ---- pages: list of N elements, where each element represents the page image boxes: list of N elements, where each element represents the localization predictions, of shape (*, 4) or (*, 4, 2) for all words for a given page @@ -322,7 +310,6 @@ def __call__( where each element is a dictionary containing the language (language + confidence) Returns: - ------- document object """ if len(boxes) != len(text_preds) != len(crop_orientations) != len(objectness_scores) or len(boxes) != len( diff --git a/onnxtr/models/classification/models/mobilenet.py b/onnxtr/models/classification/models/mobilenet.py index d371c71..bf233bf 100644 --- a/onnxtr/models/classification/models/mobilenet.py +++ b/onnxtr/models/classification/models/mobilenet.py @@ -42,7 +42,6 @@ class MobileNetV3(Engine): """MobileNetV3 Onnx loader Args: - ---- model_path: path or url to onnx model file engine_cfg: configuration for the inference engine cfg: configuration dictionary @@ -97,14 +96,12 @@ def mobilenet_v3_small_crop_orientation( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the MobileNetV3 architecture Returns: - ------- MobileNetV3 """ return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs) @@ -127,14 +124,12 @@ def mobilenet_v3_small_page_orientation( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the MobileNetV3 architecture Returns: - ------- MobileNetV3 """ return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/classification/predictor/base.py b/onnxtr/models/classification/predictor/base.py index a0ffdbf..8e0b197 100644 --- a/onnxtr/models/classification/predictor/base.py +++ b/onnxtr/models/classification/predictor/base.py @@ -19,7 +19,6 @@ class OrientationPredictor(NestedObject): 4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise. Args: - ---- pre_processor: transform inputs for easier batched model inference model: core classification architecture (backbone + classification head) load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False diff --git a/onnxtr/models/classification/zoo.py b/onnxtr/models/classification/zoo.py index 40644f4..e4c2344 100644 --- a/onnxtr/models/classification/zoo.py +++ b/onnxtr/models/classification/zoo.py @@ -65,7 +65,6 @@ def crop_orientation_predictor( >>> out = model([input_crop]) Args: - ---- arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation') batch_size: number of samples the model processes in parallel load_in_8_bit: load the 8-bit quantized version of the model @@ -73,7 +72,6 @@ def crop_orientation_predictor( **kwargs: keyword arguments to be passed to the OrientationPredictor Returns: - ------- OrientationPredictor """ model_type = "crop" @@ -103,7 +101,6 @@ def page_orientation_predictor( >>> out = model([input_page]) Args: - ---- arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation') batch_size: number of samples the model processes in parallel load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False @@ -111,7 +108,6 @@ def page_orientation_predictor( **kwargs: keyword arguments to be passed to the OrientationPredictor Returns: - ------- OrientationPredictor """ model_type = "page" diff --git a/onnxtr/models/detection/_utils/base.py b/onnxtr/models/detection/_utils/base.py index fbcfeb5..e6d9a96 100644 --- a/onnxtr/models/detection/_utils/base.py +++ b/onnxtr/models/detection/_utils/base.py @@ -20,8 +20,6 @@ def _remove_padding( """Remove padding from the localization predictions Args: - ---- - pages: list of pages loc_preds: list of localization predictions preserve_aspect_ratio: whether the aspect ratio was preserved during padding @@ -29,7 +27,6 @@ def _remove_padding( assume_straight_pages: whether the pages are assumed to be straight Returns: - ------- list of unpaded localization predictions """ if preserve_aspect_ratio: diff --git a/onnxtr/models/detection/core.py b/onnxtr/models/detection/core.py index ff118db..3d95b20 100644 --- a/onnxtr/models/detection/core.py +++ b/onnxtr/models/detection/core.py @@ -17,7 +17,6 @@ class DetectionPostProcessor(NestedObject): """Abstract class to postprocess the raw output of the model Args: - ---- box_thresh (float): minimal objectness score to consider a box bin_thresh (float): threshold to apply to segmentation raw heatmap assume straight_pages (bool): if True, fit straight boxes only @@ -37,13 +36,11 @@ def box_score(pred: np.ndarray, points: np.ndarray, assume_straight_pages: bool """Compute the confidence score for a polygon : mean of the p values on the polygon Args: - ---- pred (np.ndarray): p map returned by the model points: coordinates of the polygon assume_straight_pages: if True, fit straight boxes only Returns: - ------- polygon objectness """ h, w = pred.shape[:2] @@ -75,13 +72,11 @@ def __call__( """Performs postprocessing for a list of model outputs Args: - ---- proba_map: probability map of shape (N, H, W, C) Returns: - ------- list of N class predictions (for each input sample), where each class predictions is a list of C tensors - of shape (*, 5) or (*, 6) + of shape (*, 5) or (*, 6) """ if proba_map.ndim != 4: raise AssertionError(f"arg `proba_map` is expected to be 4-dimensional, got {proba_map.ndim}.") diff --git a/onnxtr/models/detection/models/differentiable_binarization.py b/onnxtr/models/detection/models/differentiable_binarization.py index 1cfd82d..a515e4b 100644 --- a/onnxtr/models/detection/models/differentiable_binarization.py +++ b/onnxtr/models/detection/models/differentiable_binarization.py @@ -43,7 +43,6 @@ class DBNet(Engine): """DBNet Onnx loader Args: - ---- model_path: path or url to onnx model file engine_cfg: configuration for the inference engine bin_thresh: threshold for binarization of the output feature map @@ -120,14 +119,12 @@ def db_resnet34( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _dbnet("db_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs) @@ -149,14 +146,12 @@ def db_resnet50( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _dbnet("db_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs) @@ -178,14 +173,12 @@ def db_mobilenet_v3_large( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/detection/models/fast.py b/onnxtr/models/detection/models/fast.py index bbefeb0..b6518d1 100644 --- a/onnxtr/models/detection/models/fast.py +++ b/onnxtr/models/detection/models/fast.py @@ -41,7 +41,6 @@ class FAST(Engine): """FAST Onnx loader Args: - ---- model_path: path or url to onnx model file engine_cfg: configuration for the inference engine bin_thresh: threshold for binarization of the output feature map @@ -118,14 +117,12 @@ def fast_tiny( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _fast("fast_tiny", model_path, load_in_8_bit, engine_cfg, **kwargs) @@ -147,14 +144,12 @@ def fast_small( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _fast("fast_small", model_path, load_in_8_bit, engine_cfg, **kwargs) @@ -176,14 +171,12 @@ def fast_base( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _fast("fast_base", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/detection/models/linknet.py b/onnxtr/models/detection/models/linknet.py index b26b6a6..902151b 100644 --- a/onnxtr/models/detection/models/linknet.py +++ b/onnxtr/models/detection/models/linknet.py @@ -43,7 +43,6 @@ class LinkNet(Engine): """LinkNet Onnx loader Args: - ---- model_path: path or url to onnx model file engine_cfg: configuration for the inference engine bin_thresh: threshold for binarization of the output feature map @@ -120,14 +119,12 @@ def linknet_resnet18( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the LinkNet architecture Returns: - ------- text detection architecture """ return _linknet("linknet_resnet18", model_path, load_in_8_bit, engine_cfg, **kwargs) @@ -149,14 +146,12 @@ def linknet_resnet34( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the LinkNet architecture Returns: - ------- text detection architecture """ return _linknet("linknet_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs) @@ -178,14 +173,12 @@ def linknet_resnet50( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the LinkNet architecture Returns: - ------- text detection architecture """ return _linknet("linknet_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/detection/postprocessor/base.py b/onnxtr/models/detection/postprocessor/base.py index db89cfc..f9d4f0c 100644 --- a/onnxtr/models/detection/postprocessor/base.py +++ b/onnxtr/models/detection/postprocessor/base.py @@ -21,7 +21,6 @@ class GeneralDetectionPostProcessor(DetectionPostProcessor): """Implements a post processor for FAST model. Args: - ---- bin_thresh: threshold used to binzarized p_map at inference time box_thresh: minimal objectness score to consider a box assume_straight_pages: whether the inputs were expected to have horizontal text elements @@ -43,11 +42,9 @@ def polygon_to_box( """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon Args: - ---- points: The first parameter. Returns: - ------- a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle) """ if not self.assume_straight_pages: @@ -92,16 +89,14 @@ def bitmap_to_boxes( """Compute boxes from a bitmap/pred_map: find connected components then filter boxes Args: - ---- pred: Pred map from differentiable linknet output bitmap: Bitmap map computed from pred (binarized) angle_tol: Comparison tolerance of the angle with the median angle across the page ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop Returns: - ------- np tensor boxes for the bitmap, each box is a 6-element list - containing x, y, w, h, alpha, score for the box + containing x, y, w, h, alpha, score for the box """ height, width = bitmap.shape[:2] boxes: List[Union[np.ndarray, List[float]]] = [] diff --git a/onnxtr/models/detection/predictor/base.py b/onnxtr/models/detection/predictor/base.py index 73debce..bed64ee 100644 --- a/onnxtr/models/detection/predictor/base.py +++ b/onnxtr/models/detection/predictor/base.py @@ -18,7 +18,6 @@ class DetectionPredictor(NestedObject): """Implements an object able to localize text elements in a document Args: - ---- pre_processor: transform inputs for easier batched model inference model: core detection architecture """ diff --git a/onnxtr/models/detection/zoo.py b/onnxtr/models/detection/zoo.py index 38e7e56..89d688e 100644 --- a/onnxtr/models/detection/zoo.py +++ b/onnxtr/models/detection/zoo.py @@ -75,7 +75,6 @@ def detection_predictor( >>> out = model([input_page]) Args: - ---- arch: name of the architecture or model itself to use (e.g. 'db_resnet50') assume_straight_pages: If True, fit straight boxes to the page preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before @@ -87,7 +86,6 @@ def detection_predictor( **kwargs: optional keyword arguments passed to the architecture Returns: - ------- Detection predictor """ return _predictor( diff --git a/onnxtr/models/engine.py b/onnxtr/models/engine.py index 607cb8a..8536d82 100644 --- a/onnxtr/models/engine.py +++ b/onnxtr/models/engine.py @@ -25,7 +25,6 @@ class EngineConfig: """Implements a configuration class for the engine of a model Args: - ---- providers: list of providers to use for inference ref.: https://onnxruntime.ai/docs/execution-providers/ session_options: configuration for the inference session ref.: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions """ @@ -81,7 +80,6 @@ class Engine: """Implements an abstract class for the engine of a model Args: - ---- url: the url to use to download a model if needed engine_cfg: the configuration of the engine **kwargs: additional arguments to be passed to `download_from_url` diff --git a/onnxtr/models/factory/hub.py b/onnxtr/models/factory/hub.py index 94fd08b..a4bdfe8 100644 --- a/onnxtr/models/factory/hub.py +++ b/onnxtr/models/factory/hub.py @@ -59,7 +59,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task """Save model and config to disk for pushing to huggingface hub Args: - ---- model: Onnx model to be saved save_dir: directory to save model and config arch: architecture name @@ -91,7 +90,6 @@ def push_to_hf_hub( >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small') Args: - ---- model: Onnx model to be saved model_name: name of the model which is also the repository name task: task name @@ -186,13 +184,11 @@ def from_hub(repo_id: str, engine_cfg: Optional[EngineConfig] = None, **kwargs: >>> model = from_hub("onnxtr/my-model") Args: - ---- repo_id: HuggingFace model hub repo engine_cfg: configuration for the inference engine (optional) - kwargs: kwargs of `hf_hub_download` + **kwargs: kwargs of `hf_hub_download` Returns: - ------- Model loaded with the checkpoint """ # Get the config diff --git a/onnxtr/models/predictor/base.py b/onnxtr/models/predictor/base.py index 613e1ad..a677fe0 100644 --- a/onnxtr/models/predictor/base.py +++ b/onnxtr/models/predictor/base.py @@ -24,7 +24,6 @@ class _OCRPredictor: """Implements an object able to localize and identify text elements in a set of documents Args: - ---- assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages without rotated textual elements. straighten_pages: if True, estimates the page general orientation based on the median line orientation. @@ -205,7 +204,6 @@ def add_hook(self, hook: Callable) -> None: """Add a hook to the predictor Args: - ---- hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds` """ self.hooks.append(hook) diff --git a/onnxtr/models/predictor/predictor.py b/onnxtr/models/predictor/predictor.py index fe43931..7addfeb 100644 --- a/onnxtr/models/predictor/predictor.py +++ b/onnxtr/models/predictor/predictor.py @@ -24,7 +24,6 @@ class OCRPredictor(NestedObject, _OCRPredictor): """Implements an object able to localize and identify text elements in a set of documents Args: - ---- det_predictor: detection module reco_predictor: recognition module assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages diff --git a/onnxtr/models/preprocessor/base.py b/onnxtr/models/preprocessor/base.py index a1c31d8..ce50b16 100644 --- a/onnxtr/models/preprocessor/base.py +++ b/onnxtr/models/preprocessor/base.py @@ -20,7 +20,6 @@ class PreProcessor(NestedObject): """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization. Args: - ---- output_size: expected size of each page in format (H, W) batch_size: the size of page batches mean: mean value of the training distribution by channel @@ -46,11 +45,9 @@ def batch_inputs(self, samples: List[np.ndarray]) -> List[np.ndarray]: """Gather samples into batches for inference purposes Args: - ---- samples: list of samples (tf.Tensor) Returns: - ------- list of batched samples """ num_batches = int(math.ceil(len(samples) / self.batch_size)) @@ -81,11 +78,9 @@ def __call__(self, x: Union[np.ndarray, List[np.ndarray]]) -> List[np.ndarray]: """Prepare document data for model forwarding Args: - ---- x: list of images (np.array) or tensors (already resized and batched) Returns: - ------- list of page batches """ # Input type check diff --git a/onnxtr/models/recognition/core.py b/onnxtr/models/recognition/core.py index 35e1c86..eb4a805 100644 --- a/onnxtr/models/recognition/core.py +++ b/onnxtr/models/recognition/core.py @@ -13,7 +13,6 @@ class RecognitionPostProcessor(NestedObject): """Abstract class to postprocess the raw output of the model Args: - ---- vocab: string containing the ordered sequence of supported characters """ diff --git a/onnxtr/models/recognition/models/crnn.py b/onnxtr/models/recognition/models/crnn.py index dffd56b..ed1806b 100644 --- a/onnxtr/models/recognition/models/crnn.py +++ b/onnxtr/models/recognition/models/crnn.py @@ -49,7 +49,6 @@ class CRNNPostProcessor(RecognitionPostProcessor): """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding Args: - ---- vocab: string containing the ordered sequence of supported characters """ @@ -69,13 +68,11 @@ def ctc_best_path( `_. Args: - ---- logits: model output, shape: N x T x C vocab: vocabulary to use blank: index of blank label Returns: - ------- A list of tuples: (word, confidence) """ # Gather the most confident characters, and assign the smallest conf among those to the sequence prob @@ -94,11 +91,9 @@ def __call__(self, logits): with label_to_idx mapping dictionnary Args: - ---- logits: raw output of the model, shape (N, C + 1, seq_len) Returns: - ------- A tuple of 2 lists: a list of str (words) and a list of float (probs) """ @@ -110,7 +105,6 @@ class CRNN(Engine): """CRNN Onnx loader Args: - ---- model_path: path or url to onnx model file vocab: vocabulary used for encoding engine_cfg: configuration for the inference engine @@ -187,14 +181,12 @@ def crnn_vgg16_bn( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the CRNN architecture Returns: - ------- text recognition architecture """ return _crnn("crnn_vgg16_bn", model_path, load_in_8_bit, engine_cfg, **kwargs) @@ -216,14 +208,12 @@ def crnn_mobilenet_v3_small( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the CRNN architecture Returns: - ------- text recognition architecture """ return _crnn("crnn_mobilenet_v3_small", model_path, load_in_8_bit, engine_cfg, **kwargs) @@ -245,14 +235,12 @@ def crnn_mobilenet_v3_large( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the CRNN architecture Returns: - ------- text recognition architecture """ return _crnn("crnn_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/recognition/models/master.py b/onnxtr/models/recognition/models/master.py index c42c9fc..f3056ac 100644 --- a/onnxtr/models/recognition/models/master.py +++ b/onnxtr/models/recognition/models/master.py @@ -33,7 +33,6 @@ class MASTER(Engine): """MASTER Onnx loader Args: - ---- model_path: path or url to onnx model file vocab: vocabulary, (without EOS, SOS, PAD) engine_cfg: configuration for the inference engine @@ -64,12 +63,10 @@ def __call__( """Call function Args: - ---- x: images return_model_output: if True, return logits Returns: - ------- A dictionnary containing eventually logits and predictions. """ logits = self.run(x) @@ -87,7 +84,6 @@ class MASTERPostProcessor(RecognitionPostProcessor): """Post-processor for the MASTER model Args: - ---- vocab: string containing the ordered sequence of supported characters """ @@ -147,14 +143,12 @@ def master( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keywoard arguments passed to the MASTER architecture Returns: - ------- text recognition architecture """ return _master("master", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/recognition/models/parseq.py b/onnxtr/models/recognition/models/parseq.py index 868edba..dc28603 100644 --- a/onnxtr/models/recognition/models/parseq.py +++ b/onnxtr/models/recognition/models/parseq.py @@ -32,7 +32,6 @@ class PARSeq(Engine): """PARSeq Onnx loader Args: - ---- model_path: path to onnx model file vocab: vocabulary used for encoding engine_cfg: configuration for the inference engine @@ -74,7 +73,6 @@ class PARSeqPostProcessor(RecognitionPostProcessor): """Post processor for PARSeq architecture Args: - ---- vocab: string containing the ordered sequence of supported characters """ @@ -138,14 +136,12 @@ def parseq( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the PARSeq architecture Returns: - ------- text recognition architecture """ return _parseq("parseq", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/recognition/models/sar.py b/onnxtr/models/recognition/models/sar.py index 4758dd4..e65ccd1 100644 --- a/onnxtr/models/recognition/models/sar.py +++ b/onnxtr/models/recognition/models/sar.py @@ -32,7 +32,6 @@ class SAR(Engine): """SAR Onnx loader Args: - ---- model_path: path to onnx model file vocab: vocabulary used for encoding engine_cfg: configuration for the inference engine @@ -75,7 +74,6 @@ class SARPostProcessor(RecognitionPostProcessor): """Post processor for SAR architectures Args: - ---- embedding: string containing the ordered sequence of supported characters """ @@ -137,14 +135,12 @@ def sar_resnet31( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the SAR architecture Returns: - ------- text recognition architecture """ return _sar("sar_resnet31", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/recognition/models/vitstr.py b/onnxtr/models/recognition/models/vitstr.py index 4677439..3833cf2 100644 --- a/onnxtr/models/recognition/models/vitstr.py +++ b/onnxtr/models/recognition/models/vitstr.py @@ -40,7 +40,6 @@ class ViTSTR(Engine): """ViTSTR Onnx loader Args: - ---- model_path: path to onnx model file vocab: vocabulary used for encoding engine_cfg: configuration for the inference engine @@ -83,7 +82,6 @@ class ViTSTRPostProcessor(RecognitionPostProcessor): """Post processor for ViTSTR architecture Args: - ---- vocab: string containing the ordered sequence of supported characters """ @@ -147,14 +145,12 @@ def vitstr_small( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the ViTSTR architecture Returns: - ------- text recognition architecture """ return _vitstr("vitstr_small", model_path, load_in_8_bit, engine_cfg, **kwargs) @@ -176,14 +172,12 @@ def vitstr_base( >>> out = model(input_tensor) Args: - ---- model_path: path to onnx model file, defaults to url in default_cfgs load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments of the ViTSTR architecture Returns: - ------- text recognition architecture """ return _vitstr("vitstr_base", model_path, load_in_8_bit, engine_cfg, **kwargs) diff --git a/onnxtr/models/recognition/predictor/_utils.py b/onnxtr/models/recognition/predictor/_utils.py index 4998556..7397c6d 100644 --- a/onnxtr/models/recognition/predictor/_utils.py +++ b/onnxtr/models/recognition/predictor/_utils.py @@ -22,7 +22,6 @@ def split_crops( """Chunk crops horizontally to match a given aspect ratio Args: - ---- crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise max_ratio: the maximum aspect ratio that won't trigger the chunk target_ratio: when crops are chunked, they will be chunked to match this aspect ratio @@ -30,7 +29,6 @@ def split_crops( channels_last: whether the numpy array has dimensions in channels last order Returns: - ------- a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required """ _remap_required = False diff --git a/onnxtr/models/recognition/predictor/base.py b/onnxtr/models/recognition/predictor/base.py index f4c7ce8..c038459 100644 --- a/onnxtr/models/recognition/predictor/base.py +++ b/onnxtr/models/recognition/predictor/base.py @@ -19,7 +19,6 @@ class RecognitionPredictor(NestedObject): """Implements an object able to identify character sequences in images Args: - ---- pre_processor: transform inputs for easier batched model inference model: core recognition architecture split_wide_crops: wether to use crop splitting for high aspect ratio crops diff --git a/onnxtr/models/recognition/utils.py b/onnxtr/models/recognition/utils.py index 99b6307..e33ec57 100644 --- a/onnxtr/models/recognition/utils.py +++ b/onnxtr/models/recognition/utils.py @@ -14,14 +14,12 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str: """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters. Args: - ---- a: first char seq, suffix should be similar to b's prefix. b: second char seq, prefix should be similar to a's suffix. dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is only used when the mother sequence is splitted on a character repetition Returns: - ------- A merged character sequence. Example:: @@ -65,13 +63,11 @@ def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str: """Recursively merges consecutive string sequences with overlapping characters. Args: - ---- seq_list: list of sequences to merge. Sequences need to be ordered from left to right. dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is only used when the mother sequence is splitted on a character repetition Returns: - ------- A merged character sequence Example:: diff --git a/onnxtr/models/recognition/zoo.py b/onnxtr/models/recognition/zoo.py index 8eb8bd3..d8165ff 100644 --- a/onnxtr/models/recognition/zoo.py +++ b/onnxtr/models/recognition/zoo.py @@ -67,7 +67,6 @@ def recognition_predictor( >>> out = model([input_page]) Args: - ---- arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn') symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right batch_size: number of samples the model processes in parallel @@ -76,7 +75,6 @@ def recognition_predictor( **kwargs: optional parameters to be passed to the architecture Returns: - ------- Recognition predictor """ return _predictor( diff --git a/onnxtr/models/zoo.py b/onnxtr/models/zoo.py index 5700bb8..3836774 100644 --- a/onnxtr/models/zoo.py +++ b/onnxtr/models/zoo.py @@ -88,7 +88,6 @@ def ocr_predictor( >>> out = model([input_page]) Args: - ---- det_arch: name of the detection architecture or the model itself to use (e.g. 'db_resnet50', 'db_mobilenet_v3_large') reco_arch: name of the recognition architecture or the model itself to use @@ -115,7 +114,6 @@ def ocr_predictor( kwargs: keyword args of `OCRPredictor` Returns: - ------- OCR predictor """ return _predictor( diff --git a/onnxtr/transforms/base.py b/onnxtr/transforms/base.py index 2e7203b..5adf433 100644 --- a/onnxtr/transforms/base.py +++ b/onnxtr/transforms/base.py @@ -12,7 +12,14 @@ class Resize: - """Resize the input image to the given size""" + """Resize the input image to the given size + + Args: + size: the target size of the image + interpolation: the interpolation method to use + preserve_aspect_ratio: whether to preserve the aspect ratio of the image + symmetric_pad: whether to symmetrically pad the image + """ def __init__( self, @@ -72,7 +79,12 @@ def __repr__(self) -> str: class Normalize: - """Normalize the input image""" + """Normalize the input image + + Args: + mean: mean values to subtract + std: standard deviation values to divide + """ def __init__( self, diff --git a/onnxtr/utils/data.py b/onnxtr/utils/data.py index 495f2d2..cfe596f 100644 --- a/onnxtr/utils/data.py +++ b/onnxtr/utils/data.py @@ -56,7 +56,6 @@ def download_from_url( >>> download_from_url("https://yoursource.com/yourcheckpoint-yourhash.zip") Args: - ---- url: the URL of the file to download file_name: optional name of the file once downloaded hash_prefix: optional expected SHA256 hash of the file @@ -64,11 +63,9 @@ def download_from_url( cache_subdir: subfolder to use in the cache Returns: - ------- the location of the downloaded file Note: - ---- You can change cache directory location by using `ONNXTR_CACHE_DIR` environment variable. """ if not isinstance(file_name, str): @@ -112,7 +109,7 @@ def download_from_url( except (urllib.error.URLError, IOError) as e: # pragma: no cover if url[:5] == "https": url = url.replace("https:", "http:") - print("Failed download. Trying https -> http instead." f" Downloading {url} to {file_path}") + print(f"Failed download. Trying https -> http instead. Downloading {url} to {file_path}") _urlretrieve(url, file_path) else: raise e diff --git a/onnxtr/utils/fonts.py b/onnxtr/utils/fonts.py index 9e525d9..a44db1b 100644 --- a/onnxtr/utils/fonts.py +++ b/onnxtr/utils/fonts.py @@ -18,12 +18,10 @@ def get_font( """Resolves a compatible ImageFont for the system Args: - ---- font_family: the font family to use font_size: the size of the font upon rendering Returns: - ------- the Pillow font """ # Font selection diff --git a/onnxtr/utils/geometry.py b/onnxtr/utils/geometry.py index 68b9a18..9d2aff2 100644 --- a/onnxtr/utils/geometry.py +++ b/onnxtr/utils/geometry.py @@ -34,11 +34,9 @@ def bbox_to_polygon(bbox: BoundingBox) -> Polygon4P: """Convert a bounding box to a polygon Args: - ---- bbox: a bounding box Returns: - ------- a polygon """ return bbox[0], (bbox[1][0], bbox[0][1]), (bbox[0][0], bbox[1][1]), bbox[1] @@ -48,11 +46,9 @@ def polygon_to_bbox(polygon: Polygon4P) -> BoundingBox: """Convert a polygon to a bounding box Args: - ---- polygon: a polygon Returns: - ------- a bounding box """ x, y = zip(*polygon) @@ -61,11 +57,11 @@ def polygon_to_bbox(polygon: Polygon4P) -> BoundingBox: def detach_scores(boxes: List[np.ndarray]) -> Tuple[List[np.ndarray], List[np.ndarray]]: """Detach the objectness scores from box predictions + Args: - ---- boxes: list of arrays with boxes of shape (N, 5) or (N, 5, 2) + Returns: - ------- a tuple of two lists: the first one contains the boxes without the objectness scores, the second one contains the objectness scores """ @@ -83,12 +79,10 @@ def shape_translate(data: np.ndarray, format: str) -> np.ndarray: """Translate the shape of the input data to the desired format Args: - ---- data: input data in shape (B, C, H, W) or (B, H, W, C) or (C, H, W) or (H, W, C) format: target format ('BCHW', 'BHWC', 'CHW', or 'HWC') Returns: - ------- the reshaped data """ # Get the current shape @@ -124,7 +118,6 @@ def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Unio """Compute enclosing bbox either from: Args: - ---- bboxes: boxes in one of the following formats: - an array of boxes: (*, 4), where boxes have this shape: @@ -133,7 +126,6 @@ def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Unio - a list of BoundingBox Returns: - ------- a (1, 4) array (enclosing boxarray), or a BoundingBox """ if isinstance(bboxes, np.ndarray): @@ -148,7 +140,6 @@ def resolve_enclosing_rbbox(rbboxes: List[np.ndarray], intermed_size: int = 1024 """Compute enclosing rotated bbox either from: Args: - ---- rbboxes: boxes in one of the following formats: - an array of boxes: (*, 4, 2), where boxes have this shape: @@ -158,7 +149,6 @@ def resolve_enclosing_rbbox(rbboxes: List[np.ndarray], intermed_size: int = 1024 intermed_size: size of the intermediate image Returns: - ------- a (4, 2) array (enclosing rotated box) """ cloud: np.ndarray = np.concatenate(rbboxes, axis=0) @@ -172,12 +162,10 @@ def rotate_abs_points(points: np.ndarray, angle: float = 0.0) -> np.ndarray: """Rotate points counter-clockwise. Args: - ---- points: array of size (N, 2) angle: angle between -90 and +90 degrees Returns: - ------- Rotated points """ angle_rad = angle * np.pi / 180.0 # compute radian angle for np functions @@ -191,12 +179,10 @@ def compute_expanded_shape(img_shape: Tuple[int, int], angle: float) -> Tuple[in """Compute the shape of an expanded rotated image Args: - ---- img_shape: the height and width of the image angle: angle between -90 and +90 degrees Returns: - ------- the height and width of the rotated image """ points: np.ndarray = np.array([ @@ -220,14 +206,12 @@ def rotate_abs_geoms( image center. Args: - ---- geoms: (N, 4) or (N, 4, 2) array of ABSOLUTE coordinate boxes angle: anti-clockwise rotation angle in degrees img_shape: the height and width of the image expand: whether the image should be padded to avoid information loss Returns: - ------- A batch of rotated polygons (N, 4, 2) """ # Switch to polygons @@ -259,13 +243,11 @@ def remap_boxes(loc_preds: np.ndarray, orig_shape: Tuple[int, int], dest_shape: coordinates after a resizing of the image. Args: - ---- loc_preds: (N, 4, 2) array of RELATIVE loc_preds orig_shape: shape of the origin image dest_shape: shape of the destination image Returns: - ------- A batch of rotated loc_preds (N, 4, 2) expressed in the destination referencial """ if len(dest_shape) != 2: @@ -294,7 +276,6 @@ def rotate_boxes( is done to remove the padding that is created by rotate_page(expand=True) Args: - ---- loc_preds: (N, 4) or (N, 4, 2) array of RELATIVE boxes angle: angle between -90 and +90 degrees orig_shape: shape of the origin image @@ -302,7 +283,6 @@ def rotate_boxes( target_shape: shape of the destination image Returns: - ------- A batch of rotated boxes (N, 4, 2): or a batch of straight bounding boxes """ # Change format of the boxes to rotated boxes @@ -349,14 +329,12 @@ def rotate_image( """Rotate an image counterclockwise by an given angle. Args: - ---- image: numpy tensor to rotate angle: rotation angle in degrees, between -90 and +90 expand: whether the image should be padded before the rotation preserve_origin_shape: if expand is set to True, resizes the final output to the original image size Returns: - ------- Rotated array, padded by 0 by default. """ # Compute the expanded padding @@ -395,11 +373,9 @@ def remove_image_padding(image: np.ndarray) -> np.ndarray: """Remove black border padding from an image Args: - ---- image: numpy tensor to remove padding from Returns: - ------- Image with padding removed """ # Find the bounding box of the non-black region @@ -433,12 +409,10 @@ def convert_to_relative_coords(geoms: np.ndarray, img_shape: Tuple[int, int]) -> """Convert a geometry to relative coordinates Args: - ---- geoms: a set of polygons of shape (N, 4, 2) or of straight boxes of shape (N, 4) img_shape: the height and width of the image Returns: - ------- the updated geometry """ # Polygon @@ -460,14 +434,12 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True """Created cropped images from list of bounding boxes Args: - ---- img: input image boxes: bounding boxes of shape (N, 4) where N is the number of boxes, and the relative coordinates (xmin, ymin, xmax, ymax) channels_last: whether the channel dimensions is the last one instead of the last one Returns: - ------- list of cropped images """ if boxes.shape[0] == 0: @@ -496,7 +468,6 @@ def extract_rcrops( """Created cropped images from list of rotated bounding boxes Args: - ---- img: input image polys: bounding boxes of shape (N, 4, 2) dtype: target data type of bounding boxes @@ -504,7 +475,6 @@ def extract_rcrops( assume_horizontal: whether the boxes are assumed to be only horizontally oriented Returns: - ------- list of cropped images """ if polys.shape[0] == 0: diff --git a/onnxtr/utils/multithreading.py b/onnxtr/utils/multithreading.py index adb5adb..0f51b84 100644 --- a/onnxtr/utils/multithreading.py +++ b/onnxtr/utils/multithreading.py @@ -22,17 +22,14 @@ def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: Op >>> results = multithread_exec(lambda x: x ** 2, entries) Args: - ---- func: function to be executed on each element of the iterable seq: iterable threads: number of workers to be used for multiprocessing Returns: - ------- iterator of the function's results using the iterable as inputs Notes: - ----- This function uses ThreadPool from multiprocessing package, which uses `/dev/shm` directory for shared memory. If you do not have write permissions for this directory (if you run `onnxtr` on AWS Lambda for instance), you might want to disable multiprocessing. To achieve that, set 'ONNXTR_MULTIPROCESSING_DISABLE' to 'TRUE'. diff --git a/onnxtr/utils/reconstitution.py b/onnxtr/utils/reconstitution.py index 790ee86..94c10bc 100644 --- a/onnxtr/utils/reconstitution.py +++ b/onnxtr/utils/reconstitution.py @@ -121,7 +121,6 @@ def synthesize_page( """Draw a the content of the element page (OCR response) on a blank page. Args: - ---- page: exported Page object to represent draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0 font_family: family of the font @@ -130,7 +129,6 @@ def synthesize_page( max_font_size: maximum font size Returns: - ------- the synthesized page """ # Draw template diff --git a/onnxtr/utils/visualization.py b/onnxtr/utils/visualization.py index 87fb8af..f6aad2b 100644 --- a/onnxtr/utils/visualization.py +++ b/onnxtr/utils/visualization.py @@ -30,7 +30,6 @@ def rect_patch( """Create a matplotlib rectangular patch for the element Args: - ---- geometry: bounding box of the element page_dimensions: dimensions of the Page in format (height, width) label: label to display when hovered @@ -41,7 +40,6 @@ def rect_patch( preserve_aspect_ratio: pass True if you passed True to the predictor Returns: - ------- a rectangular Patch """ if len(geometry) != 2 or any(not isinstance(elt, tuple) or len(elt) != 2 for elt in geometry): @@ -81,7 +79,6 @@ def polygon_patch( """Create a matplotlib polygon patch for the element Args: - ---- geometry: bounding box of the element page_dimensions: dimensions of the Page in format (height, width) label: label to display when hovered @@ -92,7 +89,6 @@ def polygon_patch( preserve_aspect_ratio: pass True if you passed True to the predictor Returns: - ------- a polygon Patch """ if not geometry.shape == (4, 2): @@ -121,13 +117,11 @@ def create_obj_patch( """Create a matplotlib patch for the element Args: - ---- geometry: bounding box (straight or rotated) of the element page_dimensions: dimensions of the page in format (height, width) **kwargs: keyword arguments for the patch Returns: - ------- a matplotlib Patch """ if isinstance(geometry, tuple): @@ -163,7 +157,6 @@ def visualize_page( >>> plt.show() Args: - ---- page: the exported Page of a Document image: np array of the page, needs to have the same shape than page['dimensions'] words_only: whether only words should be displayed @@ -174,7 +167,6 @@ def visualize_page( **kwargs: keyword arguments for the polygon patch Returns: - ------- the matplotlib figure """ # Get proper scale and aspect ratio @@ -270,7 +262,6 @@ def draw_boxes(boxes: np.ndarray, image: np.ndarray, color: Optional[Tuple[int, """Draw an array of relative straight boxes on an image Args: - ---- boxes: array of relative boxes, of shape (*, 4) image: np array, float32 or uint8 color: color to use for bounding box edges diff --git a/pyproject.toml b/pyproject.toml index 2c51a87..35a29bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,7 +156,7 @@ select = [ "E", "W", "F", "I", "N", "Q", "C4", "T10", "LOG", "D101", "D103", "D201","D202","D207","D208","D214","D215","D300","D301","D417", "D419", "D207" # pydocstyle ] -ignore = ["E402", "E203", "F403", "E731", "N812", "N817", "C408"] +ignore = ["E402", "E203", "F403", "E731", "N812", "N817", "C408", "LOG015"] [tool.ruff.lint.isort] known-first-party = ["onnxtr", "utils"] diff --git a/tests/common/test_transforms.py b/tests/common/test_transforms.py index ab642b1..9242d5a 100644 --- a/tests/common/test_transforms.py +++ b/tests/common/test_transforms.py @@ -26,7 +26,7 @@ def test_resize(): # Symetric padding transfo = Resize(output_size, preserve_aspect_ratio=True, symmetric_pad=True) assert repr(transfo) == ( - f"Resize(output_size={output_size}, interpolation='2', " f"preserve_aspect_ratio=True, symmetric_pad=True)" + f"Resize(output_size={output_size}, interpolation='2', preserve_aspect_ratio=True, symmetric_pad=True)" ) out = transfo(input_t) assert out.shape[:2] == output_size