Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chunk JSON output #371

Merged
merged 4 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion marker/v2/builders/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_recognition_batch_size(self):
if self.recognition_batch_size is not None:
return self.recognition_batch_size
elif settings.TORCH_DEVICE_MODEL == "cuda":
return 32
return 128
elif settings.TORCH_DEVICE_MODEL == "mps":
return 32
return 32
Expand Down
35 changes: 29 additions & 6 deletions marker/v2/converters/pdf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import json

from marker.v2.providers.pdf import PdfProvider
import os

from marker.v2.renderers.json import JSONRenderer

os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning

import tempfile
Expand Down Expand Up @@ -30,7 +34,7 @@
class PdfConverter(BaseConverter):
override_map: Dict[BlockTypes, Type[Block]] = defaultdict()

def __init__(self, config=None):
def __init__(self, config=None, output_format="markdown"):
super().__init__(config)

for block_type, override_block_type in self.override_map.items():
Expand All @@ -42,6 +46,11 @@ def __init__(self, config=None):
self.table_rec_model = setup_table_rec_model()
self.detection_model = setup_detection_model()

if output_format == "markdown":
self.renderer = MarkdownRenderer(self.config)
elif output_format == "json":
self.renderer = JSONRenderer(self.config)

def __call__(self, filepath: str):
pdf_provider = PdfProvider(filepath, self.config)

Expand All @@ -60,18 +69,18 @@ def __call__(self, filepath: str):
for processor in processor_list:
processor(document)

renderer = MarkdownRenderer(self.config)
return renderer(document)
return self.renderer(document)


@click.command()
@click.option("--output", type=click.Path(exists=False), required=False, default="temp")
@click.option("--fname", type=str, default="adversarial.pdf")
@click.option("--debug", is_flag=True)
def main(output: str, fname: str, debug: bool):
@click.option("--output_format", type=click.Choice(["markdown", "json"]), default="markdown")
def main(output: str, fname: str, debug: bool, output_format: str):
dataset = datasets.load_dataset("datalab-to/pdfs", split="train")
idx = dataset['filename'].index(fname)
out_filename = fname.rsplit(".", 1)[0] + ".md"
fname_base = fname.rsplit(".", 1)[0]
os.makedirs(output, exist_ok=True)

config = {}
Expand All @@ -84,14 +93,28 @@ def main(output: str, fname: str, debug: bool):
temp_pdf.write(dataset['pdf'][idx])
temp_pdf.flush()

converter = PdfConverter(config)
converter = PdfConverter(config=config, output_format=output_format)
rendered = converter(temp_pdf.name)

if output_format == "markdown":
out_filename = f"{fname_base}.md"
with open(os.path.join(output, out_filename), "w+") as f:
f.write(rendered.markdown)

meta_filename = f"{fname_base}_meta.json"
with open(os.path.join(output, meta_filename), "w+") as f:
f.write(json.dumps(rendered.metadata, indent=2))

for img_name, img in rendered.images.items():
img.save(os.path.join(output, img_name), "PNG")
elif output_format == "json":
out_filename = f"{fname_base}.json"
with open(os.path.join(output, out_filename), "w+") as f:
f.write(rendered.model_dump_json(indent=2))

meta_filename = f"{fname_base}_meta.json"
with open(os.path.join(output, meta_filename), "w+") as f:
f.write(json.dumps(rendered.metadata, indent=2))


if __name__ == "__main__":
Expand Down
88 changes: 87 additions & 1 deletion marker/v2/renderers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,103 @@
import base64
import io
import re
from typing import Optional

from bs4 import BeautifulSoup
from pydantic import BaseModel

from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks.base import BlockOutput, BlockId
from marker.v2.util import assign_config


class BaseRenderer:
    """Common base for document renderers.

    Subclasses implement ``__call__`` to turn a document into their output
    format. This class supplies shared helpers: cropping image blocks out of
    page images, merging adjacent inline tags, inlining ``content-ref``
    placeholders, and building table-of-contents metadata.
    """
    block_type: BlockTypes | None = None
    # Not read by any helper in this class; presumably consulted by subclasses.
    remove_blocks: list = [BlockTypes.PageHeader, BlockTypes.PageFooter]
    # Block types whose content is emitted as a cropped page image.
    image_blocks: list = [BlockTypes.Picture, BlockTypes.Figure]

    def __init__(self, config: Optional[BaseModel | dict] = None):
        """Apply ``config`` to this renderer via ``assign_config``."""
        assign_config(self, config)

    def __call__(self, document):
        """Render ``document``; must be implemented by subclasses."""
        # Children are in reading order
        raise NotImplementedError

    @staticmethod
    def extract_image(document, image_id, to_base64=False):
        """Crop an image block out of its page's high-resolution image.

        Args:
            document: Object exposing ``get_block`` and ``get_page``.
            image_id: Id of the block to crop.
            to_base64: When true, return the crop PNG-encoded as base64 text.

        Returns:
            The cropped image object, or a UTF-8 string holding the
            base64-encoded PNG when ``to_base64`` is true.
        """
        image_block = document.get_block(image_id)
        page = document.get_page(image_block.page_id)
        page_img = page.highres_image
        # The block polygon is expressed at the page polygon's size; rescale
        # it to the image's pixel size before cropping.
        image_box = image_block.polygon.rescale(page.polygon.size, page_img.size)
        cropped = page_img.crop(image_box.bbox)
        if not to_base64:
            return cropped

        image_buffer = io.BytesIO()
        cropped.save(image_buffer, format='PNG')
        return base64.b64encode(image_buffer.getvalue()).decode('utf-8')

    @staticmethod
    def merge_consecutive_tags(html, tag):
        """Merge adjacent ``<tag>`` elements separated only by whitespace.

        ``</tag>`` followed by optional whitespace and ``<tag>`` is replaced
        by just that whitespace. The substitution repeats until the string
        stops changing, because one merge can expose another. Falsy input is
        returned unchanged.
        """
        if not html:
            return html

        def replace_whitespace(match):
            return match.group(1)

        pattern = fr'</{tag}>(\s*)<{tag}>'

        while True:
            new_merged = re.sub(pattern, replace_whitespace, html)
            if new_merged == html:
                break
            html = new_merged

        return html

    def compute_toc(self, document, block_output: BlockOutput):
        """Collect table-of-contents entries from ``block_output`` and its descendants.

        Returns:
            A list of dicts with ``title`` (the header's extracted HTML),
            ``level`` (the block's ``heading_level``) and ``page`` (the
            ``page_id`` of the block id), in traversal order.
        """
        toc = []
        if hasattr(block_output, "id") and block_output.id.block_type == BlockTypes.SectionHeader:
            toc.append({
                "title": self.extract_block_html(document, block_output)[0],
                "level": document.get_block(block_output.id).heading_level,
                "page": block_output.id.page_id
            })

        for child in block_output.children:
            # Extending with an empty list is a no-op, so no emptiness guard is needed.
            toc.extend(self.compute_toc(document, child))
        return toc

    def generate_document_metadata(self, document, document_output):
        """Return document-level metadata; currently only the table of contents."""
        toc = self.compute_toc(document, document_output)
        return {
            "table_of_contents": toc
        }

    def extract_block_html(self, document, block_output):
        """Inline ``content-ref`` placeholders of a block and gather its images.

        Each ``<content-ref src=...>`` in ``block_output.html`` is matched to
        the child whose id equals ``src``. A non-image child is rendered
        recursively and substituted for the placeholder, and its images are
        merged into the result. An image child is recorded as a base64 PNG
        keyed by its block id, and the placeholder is left in the HTML.

        A placeholder with no matching child is left untouched. Previously
        that case either raised (first placeholder) or silently reused the id
        and HTML of the previously resolved placeholder.

        Returns:
            A ``(html, images)`` tuple: the resulting HTML string and a dict
            mapping block ids to base64-encoded PNG data.
        """
        soup = BeautifulSoup(block_output.html, 'html.parser')

        images = {}
        for ref in soup.find_all('content-ref'):
            src = ref.get('src')
            matched = next(
                (item for item in block_output.children if item.id == src),
                None,
            )
            if matched is None:
                # Never fall back to state left over from an earlier
                # placeholder; keep the unresolved reference visible instead.
                continue

            ref_block_id: BlockId = matched.id
            content, sub_images = self.extract_block_html(document, matched)

            if ref_block_id.block_type in self.image_blocks:
                images[ref_block_id] = self.extract_image(document, ref_block_id, to_base64=True)
            else:
                images.update(sub_images)
                ref.replace_with(BeautifulSoup(content, 'html.parser'))

        if block_output.id.block_type in self.image_blocks:
            images[block_output.id] = self.extract_image(document, block_output.id, to_base64=True)

        return str(soup), images

29 changes: 6 additions & 23 deletions marker/v2/renderers/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,10 @@
class HTMLOutput(BaseModel):
html: str
images: dict


def merge_consecutive_tags(html, tag):
    """Merge adjacent ``<tag>`` elements that are separated only by whitespace.

    A closing tag followed by optional whitespace and an opening tag of the
    same name is replaced by that whitespace alone. Passes repeat until one
    finds nothing to merge, because a merge can expose a new adjacent pair.
    Falsy input (empty string, ``None``) is returned unchanged.
    """
    if not html:
        return html

    seam = re.compile(fr'</{tag}>(\s*)<{tag}>')

    merged_count = 1
    while merged_count:
        # Each replacement removes the two tags, so a pass that matched
        # anything always changes the string.
        html, merged_count = seam.subn(lambda found: found.group(1), html)

    return html
metadata: dict


class HTMLRenderer(BaseRenderer):
remove_blocks: list = [BlockTypes.PageHeader, BlockTypes.PageFooter]
image_blocks: list = [BlockTypes.Picture, BlockTypes.Figure]
page_blocks: list = [BlockTypes.Page]
paginate_output: bool = False

Expand All @@ -60,7 +41,8 @@ def extract_html(self, document, document_output, level=0):
sub_images = {}
for item in document_output.children:
if item.id == src:
content, sub_images = self.extract_html(document, item, level + 1)
content, sub_images_ = self.extract_html(document, item, level + 1)
sub_images.update(sub_images_)
ref_block_id: BlockId = item.id
break

Expand All @@ -82,8 +64,8 @@ def extract_html(self, document, document_output, level=0):

output = str(soup)
if level == 0:
output = merge_consecutive_tags(output, 'b')
output = merge_consecutive_tags(output, 'i')
output = self.merge_consecutive_tags(output, 'b')
output = self.merge_consecutive_tags(output, 'i')

return output, images

Expand All @@ -93,4 +75,5 @@ def __call__(self, document) -> HTMLOutput:
return HTMLOutput(
html=full_html,
images=images,
metadata=self.generate_document_metadata(document, document_output)
)
79 changes: 79 additions & 0 deletions marker/v2/renderers/json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from __future__ import annotations

import base64
import io
from typing import List, Dict

from bs4 import BeautifulSoup
from pydantic import BaseModel

from marker.v2.schema.blocks import Block
from marker.v2.renderers import BaseRenderer
from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import BlockId
from marker.v2.schema.registry import get_block_class


class JSONBlockOutput(BaseModel):
    # One node of the JSON output tree produced by JSONRenderer.extract_json.
    # Field order is kept as declared because it determines serialization order.

    # Stringified block id.
    id: str
    # Stringified block type of the block id.
    block_type: str
    # HTML for this block.
    html: str
    # Polygon copied from the rendered block output.
    polygon: List[List[float]]
    # Nested block outputs; set only for blocks rendered as containers.
    children: List[JSONBlockOutput] | None = None
    # Integer-keyed mapping with values converted to strings.
    section_hierarchy: Dict[int, str] | None = None
    # Images gathered while extracting the block's HTML; set only for leaf blocks.
    images: dict | None = None


class JSONOutput(BaseModel):
    # Top-level result returned by JSONRenderer.__call__.

    # One entry per top-level child of the rendered document.
    children: List[JSONBlockOutput]
    # Marks the root of the tree; defaults to the Document block type.
    block_type: BlockTypes = BlockTypes.Document
    # Document-level metadata built by generate_document_metadata.
    metadata: dict


def reformat_section_hierarchy(section_hierarchy):
    """Return a copy of ``section_hierarchy`` with every value converted to ``str``.

    Keys are kept as they are. ``None`` is passed through as ``None`` instead
    of raising, since the output model declares the field as optional.

    Args:
        section_hierarchy: Mapping to convert, or ``None``.

    Returns:
        A new dict with the same keys and stringified values, or ``None``
        when the input is ``None``.
    """
    if section_hierarchy is None:
        return None
    return {key: str(value) for key, value in section_hierarchy.items()}


class JSONRenderer(BaseRenderer):
    """Renderer that emits the document as a tree of ``JSONBlockOutput`` nodes."""
    image_blocks: list = [BlockTypes.Picture, BlockTypes.Figure]
    page_blocks: list = [BlockTypes.Page]

    def extract_json(self, document, block_output):
        """Convert one rendered block, and recursively its children, to JSON output.

        A block whose class derives directly from ``Block`` is emitted as a
        leaf: its content references are inlined and its images attached.
        Any other block is emitted as a container holding its converted
        children and its own unmodified HTML.
        """
        block_id = block_output.id
        block_cls = get_block_class(block_id.block_type)

        if block_cls.__base__ == Block:
            leaf_html, leaf_images = self.extract_block_html(document, block_output)
            return JSONBlockOutput(
                html=leaf_html,
                polygon=block_output.polygon.polygon,
                id=str(block_id),
                block_type=str(block_id.block_type),
                images=leaf_images,
                section_hierarchy=reformat_section_hierarchy(block_output.section_hierarchy)
            )

        nested = [self.extract_json(document, child) for child in block_output.children]
        return JSONBlockOutput(
            html=block_output.html,
            polygon=block_output.polygon.polygon,
            id=str(block_id),
            block_type=str(block_id.block_type),
            children=nested,
            section_hierarchy=reformat_section_hierarchy(block_output.section_hierarchy)
        )

    def __call__(self, document) -> JSONOutput:
        """Render ``document`` and convert each top-level child into JSON output."""
        document_output = document.render()
        top_level = [
            self.extract_json(document, page_output)
            for page_output in document_output.children
        ]
        return JSONOutput(
            children=top_level,
            metadata=self.generate_document_metadata(document, document_output)
        )
4 changes: 3 additions & 1 deletion marker/v2/renderers/markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def convert_div(self, el, text, convert_as_inline):
class MarkdownOutput(BaseModel):
    # Result returned by MarkdownRenderer.__call__.

    # Markdown text converted from the document's HTML.
    markdown: str
    # Images collected during rendering.
    images: dict
    # Document-level metadata built by generate_document_metadata.
    metadata: dict


class MarkdownRenderer(HTMLRenderer):
Expand All @@ -43,5 +44,6 @@ def __call__(self, document: Document) -> MarkdownOutput:
markdown = md_cls.convert(full_html)
return MarkdownOutput(
markdown=markdown,
images=images
images=images,
metadata=self.generate_document_metadata(document, document_output)
)
Loading