diff --git a/marker/builders/layout.py b/marker/builders/layout.py index 0ad89052..657e4bde 100644 --- a/marker/builders/layout.py +++ b/marker/builders/layout.py @@ -1,5 +1,6 @@ from typing import List +import numpy as np from surya.layout import batch_layout_detection from surya.schema import LayoutResult from surya.model.layout.encoderdecoder import SuryaLayoutModel @@ -13,6 +14,7 @@ from marker.schema.groups.page import PageGroup from marker.schema.polygon import PolygonBox from marker.schema.registry import get_block_class +from marker.util import matrix_intersection_area class LayoutBuilder(BaseBuilder): @@ -35,6 +37,7 @@ class LayoutBuilder(BaseBuilder): batch_size = None layout_coverage_min_lines = 1 layout_coverage_threshold = .1 + excluded_for_coverage = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup) def __init__(self, layout_model: SuryaLayoutModel, config=None): self.layout_model = layout_model @@ -91,16 +94,18 @@ def check_layout_coverage( covered_blocks = 0 total_blocks = 0 large_text_blocks = 0 - for layout_block_id in document_page.structure: - layout_block = document_page.get_block(layout_block_id) - if layout_block.block_type in [BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup]: - continue + layout_blocks = [document_page.get_block(block) for block in document_page.structure] + layout_blocks = [b for b in layout_blocks if b.block_type not in self.excluded_for_coverage] + + layout_bboxes = [block.polygon.bbox for block in layout_blocks] + provider_bboxes = [line.line.polygon.bbox for line in provider_lines] + + intersection_matrix = matrix_intersection_area(layout_bboxes, provider_bboxes) + + for idx, layout_block in enumerate(layout_blocks): total_blocks += 1 - intersecting_lines = 0 - for provider_line in provider_lines: - if layout_block.polygon.intersection_area(provider_line.line.polygon) > 0: - intersecting_lines += 1 + intersecting_lines = np.count_nonzero(intersection_matrix[idx] > 0) if intersecting_lines > self.layout_coverage_min_lines: covered_blocks += 1 diff --git a/marker/output.py b/marker/output.py index 84504c50..e47c861f 100644 --- a/marker/output.py +++ b/marker/output.py @@ -1,7 +1,6 @@ import json import os -from ftfy import fix_text from pydantic import BaseModel from marker.renderers.html import HTMLOutput @@ -30,7 +29,6 @@ def text_from_rendered(rendered: BaseModel): def save_output(rendered: BaseModel, output_dir: str, fname_base: str): text, ext, images = text_from_rendered(rendered) - text = fix_text(text) with open(os.path.join(output_dir, f"{fname_base}.{ext}"), "w+") as f: f.write(text) diff --git a/marker/providers/pdf.py b/marker/providers/pdf.py index a1035eec..fd416d1e 100644 --- a/marker/providers/pdf.py +++ b/marker/providers/pdf.py @@ -1,7 +1,9 @@ import atexit import functools import re -from typing import List, Set +from concurrent.futures.thread import ThreadPoolExecutor +from itertools import repeat +from typing import List, Set, Dict import pypdfium2 as pdfium from pdftext.extraction import dictionary_output @@ -187,7 +189,6 @@ def detect_bad_ocr(self, text): return False - @functools.lru_cache(maxsize=None) def get_image(self, idx: int, dpi: int) -> Image.Image: page = self.doc[idx] image = page.render(scale=dpi / 72, draw_annots=False).to_pil() diff --git a/marker/schema/groups/page.py b/marker/schema/groups/page.py index 8e1b63c5..fdd638d2 100644 --- a/marker/schema/groups/page.py +++ b/marker/schema/groups/page.py @@ -1,6 +1,7 @@ from collections import defaultdict from typing import Dict, List, TYPE_CHECKING, Sequence, Tuple +import numpy as np from PIL import Image from marker.providers import ProviderOutput @@ -8,6 +9,7 @@ from marker.schema.blocks import Block, BlockId, Text from marker.schema.groups.base import Group from marker.schema.polygon import PolygonBox +from marker.util import matrix_intersection_area LINE_MAPPING_TYPE = List[Tuple[int, ProviderOutput]] @@ -75,15 +77,22 @@ def assemble_html(self, child_blocks, parent_structure=None): def compute_line_block_intersections(self, provider_outputs: List[ProviderOutput]): max_intersections = {} + blocks = [ + block for block in self.children + if block.block_type not in self.excluded_block_types + ] + block_bboxes = [block.polygon.bbox for block in blocks] + line_bboxes = [provider_output.line.polygon.bbox for provider_output in provider_outputs] + + intersection_matrix = matrix_intersection_area(line_bboxes, block_bboxes) + for line_idx, line in enumerate(provider_outputs): - for block in self.children: - if block.block_type in self.excluded_block_types: - continue - intersection_pct = line.line.polygon.intersection_pct(block.polygon) - if line_idx not in max_intersections: - max_intersections[line_idx] = (intersection_pct, block.id) - elif intersection_pct > max_intersections[line_idx][0]: - max_intersections[line_idx] = (intersection_pct, block.id) + max_intersection = intersection_matrix[line_idx].argmax() + if intersection_matrix[line_idx, max_intersection] > 0: + max_intersections[line_idx] = ( + intersection_matrix[line_idx, max_intersection], + blocks[max_intersection].id + ) return max_intersections def replace_block(self, block: Block, new_block: Block): diff --git a/marker/schema/polygon.py b/marker/schema/polygon.py index 94440c10..0cd53074 100644 --- a/marker/schema/polygon.py +++ b/marker/schema/polygon.py @@ -2,6 +2,7 @@ import copy from typing import List +import numpy as np from pydantic import BaseModel, field_validator, computed_field diff --git a/marker/schema/text/span.py b/marker/schema/text/span.py index d5b9fe8d..cf9b77ca 100644 --- a/marker/schema/text/span.py +++ b/marker/schema/text/span.py @@ -2,6 +2,8 @@ import re from typing import List, Literal +from ftfy import fix_text + from marker.schema import BlockTypes from marker.schema.blocks import Block @@ -36,7 +38,7 @@ def assemble_html(self, child_blocks, parent_structure): if self.ignore_for_output: return "" - text = self.text + text = fix_text(self.text) # Remove trailing newlines replaced_newline = False diff --git a/marker/util.py b/marker/util.py index df98eefe..6f11b04d 100644 --- a/marker/util.py +++ b/marker/util.py @@ -2,6 +2,7 @@ from importlib import import_module from typing import List +import numpy as np from pydantic import BaseModel @@ -57,3 +58,21 @@ def parse_range_str(range_str: str) -> List[int]: page_lst.append(int(i)) page_lst = sorted(list(set(page_lst))) # Deduplicate page numbers and sort in order return page_lst + + +def matrix_intersection_area(boxes1: List[List[float]], boxes2: List[List[float]]) -> np.ndarray: + boxes1 = np.array(boxes1) + boxes2 = np.array(boxes2) + + boxes1 = boxes1[:, np.newaxis, :] # Shape: (N, 1, 4) + boxes2 = boxes2[np.newaxis, :, :] # Shape: (1, M, 4) + + min_x = np.maximum(boxes1[..., 0], boxes2[..., 0]) # Shape: (N, M) + min_y = np.maximum(boxes1[..., 1], boxes2[..., 1]) + max_x = np.minimum(boxes1[..., 2], boxes2[..., 2]) + max_y = np.minimum(boxes1[..., 3], boxes2[..., 3]) + + width = np.maximum(0, max_x - min_x) + height = np.maximum(0, max_y - min_y) + + return width * height # Shape: (N, M)