Skip to content

Commit

Permalink
Merge pull request #386 from VikParuchuri/dev-mose/marker-v2
Browse files Browse the repository at this point in the history
Misc Bugfixes
  • Loading branch information
VikParuchuri authored Nov 26, 2024
2 parents fc34530 + 26328ff commit 96d1b81
Show file tree
Hide file tree
Showing 24 changed files with 136 additions and 83 deletions.
20 changes: 9 additions & 11 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,19 @@
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
os.environ["IN_STREAMLIT"] = "true" # Avoid multiprocessing inside surya

import argparse
import math
import traceback

import click
import torch.multiprocessing as mp
from tqdm import tqdm
import math

from marker.config.parser import ConfigParser
from marker.converters.pdf import PdfConverter
from marker.logger import configure_logging
from marker.models import create_model_dict
from marker.output import save_output, output_exists
from marker.config.parser import ConfigParser
from marker.models import create_model_dict
from marker.output import output_exists, save_output
from marker.settings import settings
from marker.logger import configure_logging
import traceback
import json
import click

configure_logging()

Expand All @@ -42,7 +39,7 @@ def process_single_pdf(args):

out_folder = config_parser.get_output_folder(fpath)
base_name = config_parser.get_base_filename(fpath)
if output_exists(out_folder, base_name):
if cli_options.get('skip_existing') and output_exists(out_folder, base_name):
return

try:
Expand All @@ -66,7 +63,8 @@ def process_single_pdf(args):
@click.option("--chunk_idx", type=int, default=0, help="Chunk index to convert")
@click.option("--num_chunks", type=int, default=1, help="Number of chunks being processed in parallel")
@click.option("--max_files", type=int, default=None, help="Maximum number of pdfs to convert")
@click.option("--workers", type=int, default=3, help="Number of worker processes to use.")
@click.option("--workers", type=int, default=5, help="Number of worker processes to use.")
@click.option("--skip_existing", is_flag=True, default=False, help="Skip existing converted files.")
def main(in_folder: str, **kwargs):
in_folder = os.path.abspath(in_folder)
files = [os.path.join(in_folder, f) for f in os.listdir(in_folder)]
Expand Down
5 changes: 3 additions & 2 deletions marker/builders/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def get_recognition_batch_size(self):
if self.recognition_batch_size is not None:
return self.recognition_batch_size
elif settings.TORCH_DEVICE_MODEL == "cuda":
return 128
return 32
elif settings.TORCH_DEVICE_MODEL == "mps":
return 32
return 32
Expand All @@ -71,7 +71,8 @@ def ocr_extraction(self, document: Document, provider: PdfProvider) -> ProviderP
det_processor=self.detection_model.processor,
rec_model=self.recognition_model,
rec_processor=self.recognition_model.processor,
batch_size=int(self.get_recognition_batch_size()),
detection_batch_size=int(self.get_detection_batch_size()),
recognition_batch_size=int(self.get_recognition_batch_size()),
highres_images=[page.highres_image for page in page_list]
)

Expand Down
40 changes: 18 additions & 22 deletions marker/builders/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,32 +40,28 @@ def group_caption_blocks(self, page: PageGroup):
if block.block_type not in [BlockTypes.Table, BlockTypes.Figure, BlockTypes.Picture]:
continue

if block.block_id in remove_ids:
if block.id in remove_ids:
continue

block_structure = [block_id]
selected_polygons = [block.polygon]
for j, prev_block_id in enumerate(page.structure[:i][::-1]):
prev_block = page.get_block(prev_block_id)
if all([
prev_block.block_type in [BlockTypes.Caption, BlockTypes.Footnote],
prev_block.polygon.minimum_gap(block.polygon) < gap_threshold_px
]):
block_structure.insert(0, prev_block_id)
selected_polygons.append(selected_polygons[0])
else:
break

for j, next_block_id in enumerate(page.structure[i + 1:]):
next_block = page.get_block(next_block_id)
if all([
next_block.block_type in [BlockTypes.Caption, BlockTypes.Footnote],
next_block.polygon.minimum_gap(selected_polygons[-1]) < gap_threshold_px
]):
block_structure.append(next_block_id)
selected_polygons.append(next_block.polygon)
else:
break
caption_types = [BlockTypes.Caption, BlockTypes.Footnote]

prev_block = page.get_prev_block(block)
next_block = page.get_next_block(block)

if prev_block and \
prev_block.block_type in caption_types and \
prev_block.polygon.minimum_gap(block.polygon) < gap_threshold_px and \
prev_block.id not in remove_ids:
block_structure.insert(0, prev_block.id)
selected_polygons.append(prev_block.polygon)

if next_block and \
next_block.block_type in caption_types and \
next_block.polygon.minimum_gap(block.polygon) < gap_threshold_px:
block_structure.append(next_block.id)
selected_polygons.append(next_block.polygon)

if len(block_structure) > 1:
# Create a merged block
Expand Down
1 change: 0 additions & 1 deletion marker/config/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def common_options(fn):
help="Path to JSON file with additional configuration.")(fn)
fn = click.option("--languages", type=str, default=None, help="Comma separated list of languages to use for OCR.")(fn)
fn = click.option("--disable_multiprocessing", is_flag=True, default=False, help="Disable multiprocessing.")(fn)
fn = click.option('-l', is_flag=True, help="List available builders, processors and converters")(fn)
return fn

def generate_config_dict(self) -> Dict[str, any]:
Expand Down
9 changes: 6 additions & 3 deletions marker/output.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import json
import os

from ftfy import fix_text
from pydantic import BaseModel
from marker.renderers.markdown import MarkdownOutput

from marker.renderers.html import HTMLOutput
from marker.renderers.json import JSONOutput
from marker.renderers.markdown import MarkdownOutput


def output_exists(output_dir: str, fname_base: str):
Expand All @@ -28,11 +30,12 @@ def text_from_rendered(rendered: BaseModel):

def save_output(rendered: BaseModel, output_dir: str, fname_base: str):
text, ext, images = text_from_rendered(rendered)
text = fix_text(text)

with open(os.path.join(output_dir, f"{fname_base}.{ext}"), "w+") as f:
f.write(text)
with open(os.path.join(output_dir, f"{fname_base}_meta.json"), "w+") as f:
f.write(json.dumps(rendered.metadata, indent=2))

for img_name, img in images.items():
img.save(os.path.join(output_dir, img_name), "PNG")
img.save(os.path.join(output_dir, img_name), "PNG")
4 changes: 4 additions & 0 deletions marker/processors/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def format_block(self, document: Document, block: Code):
min_left = 9999 # will contain x- coord of column 0
total_width = 0
total_chars = 0

if block.structure is None:
return

for line_id in block.structure:
line = document.get_block(line_id)
min_left = min(line.polygon.bbox[0], min_left)
Expand Down
46 changes: 30 additions & 16 deletions marker/processors/ignoretext.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import re
from collections import Counter
from itertools import groupby
from typing import List

from rapidfuzz import fuzz

from marker.processors import BaseProcessor
from marker.schema import BlockTypes
from marker.schema.blocks import Block
from marker.schema.document import Document

from rapidfuzz import fuzz


class IgnoreTextProcessor(BaseProcessor):
"""
Expand All @@ -17,10 +20,13 @@ class IgnoreTextProcessor(BaseProcessor):
The minimum fraction of pages that a block must appear in to be considered a common element.
Default is 0.2.
"""
block_types = (BlockTypes.Text,)
common_element_threshold = .25
block_types = (
BlockTypes.Text, BlockTypes.PageHeader,
BlockTypes.PageFooter, BlockTypes.SectionHeader
)
common_element_threshold = .20
common_element_min_blocks = 3
max_blocks = 1
max_streak = 3 # The number of consecutive blocks with the same cleaned text at which it is treated as a common element
text_match_threshold = 90

def __call__(self, document: Document):
Expand All @@ -31,11 +37,11 @@ def __call__(self, document: Document):
block = None
last_block = None
for block in page.contained_blocks(document, self.block_types):
if initial_block is None:
initial_block = block
if block.structure is not None:
if initial_block is None:
initial_block = block

if block is not None:
last_block = block
last_block = block

if initial_block is not None:
first_blocks.append(initial_block)
Expand All @@ -47,24 +53,32 @@ def __call__(self, document: Document):

@staticmethod
def clean_text(text):
return re.sub(r"\s+", "", text)
text = text.replace("\n", "").strip()
text = re.sub(r"^\d+\s*", "", text) # remove numbers at the start of the line
text = re.sub(r"\s*\d+$", "", text) # remove numbers at the end of the line
return text

def filter_common_elements(self, document, blocks):
def filter_common_elements(self, document, blocks: List[Block]):
# We can't filter if we don't have enough pages to find common elements
if len(blocks) < self.common_element_min_blocks:
return

text = [self.clean_text(b.raw_text(document)) for b in blocks]

streaks = {}
for key, group in groupby(text):
streaks[key] = max(streaks.get(key, 0), len(list(group)))

counter = Counter(text)
common = [
k for k, v in counter.items()
if v > len(blocks) * self.common_element_threshold
if (v >= len(blocks) * self.common_element_threshold or streaks[k] >= self.max_streak)
and v > self.common_element_min_blocks
]
if len(common) == 0:
return

for b in blocks:
if fuzz.ratio(self.clean_text(b.raw_text(document)), common[0]) > self.text_match_threshold:
for span in b.contained_blocks(document, [BlockTypes.Span]):
span.ignore_for_output = True
for t, b in zip(text, blocks):
# Check against all common elements
if any(fuzz.ratio(t, common_element) > self.text_match_threshold for common_element in common):
b.ignore_for_output = True
3 changes: 1 addition & 2 deletions marker/processors/line_numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ def ignore_line_number_blocks(self, document: Document):
sum(tokens_are_numbers) / len(tokens) > self.strip_numbers_threshold,
block.polygon.height > block.polygon.width # Ensure block is taller than it is wide, like vertical page numbers
]):
for span in block.contained_blocks(document, [BlockTypes.Span]):
span.ignore_for_output = True
block.ignore_for_output = True


def ignore_line_starts_ends(self, document: Document):
Expand Down
8 changes: 4 additions & 4 deletions marker/processors/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,16 @@ def get_table_rec_batch_size(self):
if self.table_rec_batch_size is not None:
return self.table_rec_batch_size
elif settings.TORCH_DEVICE_MODEL == "mps":
return 16
return 6
elif settings.TORCH_DEVICE_MODEL == "cuda":
return 64
return 8
return 6
return 6

def get_recognition_batch_size(self):
if self.recognition_batch_size is not None:
return self.recognition_batch_size
elif settings.TORCH_DEVICE_MODEL == "mps":
return 32
elif settings.TORCH_DEVICE_MODEL == "cuda":
return 128
return 32
return 32
4 changes: 1 addition & 3 deletions marker/renderers/markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ def convert_p(self, el, text, *args):
if has_continuation:
if regex.compile(rf'.*[\p{{Ll}}|\d][{hyphens}]\s?$', regex.DOTALL).match(text): # handle hyphenation across pages
return regex.split(rf"[{hyphens}]\s?$", text)[0]
if regex.search(r'[^\w\s]$', text): # Ends with non-word character and so we add a space after text, e.g "However,"
return f"{text} "
return text
return f"{text} "
return f"{text}\n\n" if text else "" # default convert_p behavior


Expand Down
10 changes: 8 additions & 2 deletions marker/schema/blocks/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List, Literal, Optional, Dict, Tuple, Sequence
from typing import TYPE_CHECKING, List, Literal, Optional, Dict, Sequence

from pydantic import BaseModel, ConfigDict, field_validator

Expand Down Expand Up @@ -64,6 +64,7 @@ class Block(BaseModel):
page_id: Optional[int] = None
text_extraction_method: Optional[Literal['pdftext', 'surya']] = None
structure: List[BlockId] | None = None # The top-level page structure, which is the block ids in order
ignore_for_output: bool = False # Whether this block should be ignored in output

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand All @@ -76,6 +77,8 @@ def id(self) -> BlockId:
)

def structure_blocks(self, document_page: Document | PageGroup) -> List[Block]:
# Resolve every block id listed in self.structure to its block object by
# calling document_page.get_block, keeping the order of self.structure.
# When structure is None there is nothing to iterate, so return an empty
# list instead of failing on the comprehension below.
if self.structure is None:
return []
return [document_page.get_block(block_id) for block_id in self.structure]

def add_structure(self, block: Block):
Expand Down Expand Up @@ -114,6 +117,9 @@ def raw_text(self, document: Document) -> str:
return text

def assemble_html(self, child_blocks: List[BlockOutput], parent_structure: Optional[List[str]] = None):
if self.ignore_for_output:
return ""

template = ""
for c in child_blocks:
template += f"<content-ref src='{c.id}'></content-ref>"
Expand All @@ -129,7 +135,7 @@ def assign_section_hierarchy(self, section_hierarchy):

return section_hierarchy

def contained_blocks(self, document: Document, block_types: Sequence[BlockTypes] = None):
def contained_blocks(self, document: Document, block_types: Sequence[BlockTypes] = None) -> List[Block]:
if self.structure is None:
return []

Expand Down
3 changes: 1 addition & 2 deletions marker/schema/blocks/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@ class Form(Block):
cells: List[SpanTableCell] | None = None

def assemble_html(self, child_blocks, parent_structure=None):
return html_format(self.cells)

return str(html_format(self.cells))
3 changes: 3 additions & 0 deletions marker/schema/blocks/inlinemath.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ class InlineMath(Block):
has_continuation: bool = False

def assemble_html(self, child_blocks, parent_structure):
if self.ignore_for_output:
return ""

template = super().assemble_html(child_blocks, parent_structure)
template = template.replace("\n", " ")

Expand Down
3 changes: 3 additions & 0 deletions marker/schema/blocks/pagefooter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ class PageFooter(Block):
block_type: str = BlockTypes.PageFooter

def assemble_html(self, child_blocks, parent_structure):
# Render this block's content wrapped in a single <p> element.
# Return early for a block flagged ignore_for_output so that no <p> wrapper
# is emitted for it at all.
if self.ignore_for_output:
return ""

# Build the inner markup with the parent implementation, then replace
# newlines with spaces before wrapping it in the paragraph tag.
template = super().assemble_html(child_blocks, parent_structure)
template = template.replace("\n", " ")
return f"<p>{template}</p>"
3 changes: 3 additions & 0 deletions marker/schema/blocks/pageheader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ class PageHeader(Block):
block_type: str = BlockTypes.PageHeader

def assemble_html(self, child_blocks, parent_structure):
# Render this block's content wrapped in a single <p> element.
# Return early for a block flagged ignore_for_output so that no <p> wrapper
# is emitted for it at all.
if self.ignore_for_output:
return ""

# Build the inner markup with the parent implementation, then replace
# newlines with spaces before wrapping it in the paragraph tag.
template = super().assemble_html(child_blocks, parent_structure)
template = template.replace("\n", " ")
return f"<p>{template}</p>"
3 changes: 3 additions & 0 deletions marker/schema/blocks/sectionheader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ class SectionHeader(Block):
heading_level: int | None = None

def assemble_html(self, child_blocks, parent_structure):
if self.ignore_for_output:
return ""

template = super().assemble_html(child_blocks, parent_structure)
template = template.replace("\n", " ")
tag = f"h{self.heading_level}" if self.heading_level else "h2"
Expand Down
2 changes: 1 addition & 1 deletion marker/schema/blocks/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Table(Block):

def assemble_html(self, child_blocks, parent_structure=None):
if self.cells:
return html_format(self.cells)
return str(html_format(self.cells))
else:
template = super().assemble_html(child_blocks, parent_structure)
return f"<p>{template}</p>"
Loading

0 comments on commit 96d1b81

Please sign in to comment.