Skip to content

Commit

Permalink
Improve comparison performance
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Nov 27, 2024
1 parent c78f4af commit bd95194
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 21 deletions.
21 changes: 13 additions & 8 deletions marker/builders/layout.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List

import numpy as np
from surya.layout import batch_layout_detection
from surya.schema import LayoutResult
from surya.model.layout.encoderdecoder import SuryaLayoutModel
Expand All @@ -13,6 +14,7 @@
from marker.schema.groups.page import PageGroup
from marker.schema.polygon import PolygonBox
from marker.schema.registry import get_block_class
from marker.util import matrix_intersection_area


class LayoutBuilder(BaseBuilder):
Expand All @@ -35,6 +37,7 @@ class LayoutBuilder(BaseBuilder):
batch_size = None
layout_coverage_min_lines = 1
layout_coverage_threshold = .1
excluded_for_coverage = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup)

def __init__(self, layout_model: SuryaLayoutModel, config=None):
self.layout_model = layout_model
Expand Down Expand Up @@ -91,16 +94,18 @@ def check_layout_coverage(
covered_blocks = 0
total_blocks = 0
large_text_blocks = 0
for layout_block_id in document_page.structure:
layout_block = document_page.get_block(layout_block_id)
if layout_block.block_type in [BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup]:
continue

layout_blocks = [document_page.get_block(block) for block in document_page.structure]
layout_blocks = [b for b in layout_blocks if b.block_type not in self.excluded_for_coverage]

layout_bboxes = [block.polygon.bbox for block in layout_blocks]
provider_bboxes = [line.line.polygon.bbox for line in provider_lines]

intersection_matrix = matrix_intersection_area(layout_bboxes, provider_bboxes)

for idx, layout_block in enumerate(layout_blocks):
total_blocks += 1
intersecting_lines = 0
for provider_line in provider_lines:
if layout_block.polygon.intersection_area(provider_line.line.polygon) > 0:
intersecting_lines += 1
intersecting_lines = np.count_nonzero(intersection_matrix[idx] > 0)

if intersecting_lines > self.layout_coverage_min_lines:
covered_blocks += 1
Expand Down
2 changes: 0 additions & 2 deletions marker/output.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import os

from ftfy import fix_text
from pydantic import BaseModel

from marker.renderers.html import HTMLOutput
Expand Down Expand Up @@ -30,7 +29,6 @@ def text_from_rendered(rendered: BaseModel):

def save_output(rendered: BaseModel, output_dir: str, fname_base: str):
text, ext, images = text_from_rendered(rendered)
text = fix_text(text)

with open(os.path.join(output_dir, f"{fname_base}.{ext}"), "w+") as f:
f.write(text)
Expand Down
5 changes: 3 additions & 2 deletions marker/providers/pdf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import atexit
import functools
import re
from typing import List, Set
from concurrent.futures.thread import ThreadPoolExecutor
from itertools import repeat
from typing import List, Set, Dict

import pypdfium2 as pdfium
from pdftext.extraction import dictionary_output
Expand Down Expand Up @@ -187,7 +189,6 @@ def detect_bad_ocr(self, text):

return False

@functools.lru_cache(maxsize=None)
def get_image(self, idx: int, dpi: int) -> Image.Image:
page = self.doc[idx]
image = page.render(scale=dpi / 72, draw_annots=False).to_pil()
Expand Down
25 changes: 17 additions & 8 deletions marker/schema/groups/page.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from collections import defaultdict
from typing import Dict, List, TYPE_CHECKING, Sequence, Tuple

import numpy as np
from PIL import Image

from marker.providers import ProviderOutput
from marker.schema import BlockTypes
from marker.schema.blocks import Block, BlockId, Text
from marker.schema.groups.base import Group
from marker.schema.polygon import PolygonBox
from marker.util import matrix_intersection_area

LINE_MAPPING_TYPE = List[Tuple[int, ProviderOutput]]

Expand Down Expand Up @@ -75,15 +77,22 @@ def assemble_html(self, child_blocks, parent_structure=None):
def compute_line_block_intersections(self, provider_outputs: List[ProviderOutput]):
max_intersections = {}

blocks = [
block for block in self.children
if block.block_type not in self.excluded_block_types
]
block_bboxes = [block.polygon.bbox for block in blocks]
line_bboxes = [provider_output.line.polygon.bbox for provider_output in provider_outputs]

intersection_matrix = matrix_intersection_area(line_bboxes, block_bboxes)

for line_idx, line in enumerate(provider_outputs):
for block in self.children:
if block.block_type in self.excluded_block_types:
continue
intersection_pct = line.line.polygon.intersection_pct(block.polygon)
if line_idx not in max_intersections:
max_intersections[line_idx] = (intersection_pct, block.id)
elif intersection_pct > max_intersections[line_idx][0]:
max_intersections[line_idx] = (intersection_pct, block.id)
max_intersection = intersection_matrix[line_idx].argmax()
if intersection_matrix[line_idx, max_intersection] > 0:
max_intersections[line_idx] = (
intersection_matrix[line_idx, max_intersection],
blocks[max_intersection].id
)
return max_intersections

def replace_block(self, block: Block, new_block: Block):
Expand Down
1 change: 1 addition & 0 deletions marker/schema/polygon.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
from typing import List

import numpy as np
from pydantic import BaseModel, field_validator, computed_field


Expand Down
4 changes: 3 additions & 1 deletion marker/schema/text/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import re
from typing import List, Literal

from ftfy import fix_text

from marker.schema import BlockTypes
from marker.schema.blocks import Block

Expand Down Expand Up @@ -36,7 +38,7 @@ def assemble_html(self, child_blocks, parent_structure):
if self.ignore_for_output:
return ""

text = self.text
text = fix_text(self.text)

# Remove trailing newlines
replaced_newline = False
Expand Down
19 changes: 19 additions & 0 deletions marker/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from importlib import import_module
from typing import List

import numpy as np
from pydantic import BaseModel


Expand Down Expand Up @@ -57,3 +58,21 @@ def parse_range_str(range_str: str) -> List[int]:
page_lst.append(int(i))
page_lst = sorted(list(set(page_lst))) # Deduplicate page numbers and sort in order
return page_lst


def matrix_intersection_area(boxes1: List[List[float]], boxes2: List[List[float]]) -> np.ndarray:
boxes1 = np.array(boxes1)
boxes2 = np.array(boxes2)

boxes1 = boxes1[:, np.newaxis, :] # Shape: (N, 1, 4)
boxes2 = boxes2[np.newaxis, :, :] # Shape: (1, M, 4)

min_x = np.maximum(boxes1[..., 0], boxes2[..., 0]) # Shape: (N, M)
min_y = np.maximum(boxes1[..., 1], boxes2[..., 1])
max_x = np.minimum(boxes1[..., 2], boxes2[..., 2])
max_y = np.minimum(boxes1[..., 3], boxes2[..., 3])

width = np.maximum(0, max_x - min_x)
height = np.maximum(0, max_y - min_y)

return width * height # Shape: (N, M)

0 comments on commit bd95194

Please sign in to comment.