Skip to content

Commit

Permalink
Merge pull request #205 from IndicoDataSolutions/mawelborn/etloutput-…
Browse files Browse the repository at this point in the history
…dataclasses

ETL Output Dataclasses
  • Loading branch information
mawelborn authored Dec 20, 2024
2 parents eb26c5f + 4379304 commit befb3d8
Show file tree
Hide file tree
Showing 55 changed files with 62,222 additions and 0 deletions.
201 changes: 201 additions & 0 deletions indico_toolkit/etloutput/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
from typing import TYPE_CHECKING

from .cell import Cell, CellType
from .errors import EtlOutputError, TableCellNotFoundError, TokenNotFoundError
from .etloutput import EtlOutput
from .table import Table
from .token import Token
from .utilities import get, has

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from typing import Any

__all__ = (
"Cell",
"CellType",
"EtlOutput",
"EtlOutputError",
"load",
"load_async",
"Table",
"TableCellNotFoundError",
"Token",
"TokenNotFoundError",
)


def load(
etl_output_url: str,
*,
reader: "Callable[..., Any]",
text: bool = True,
tokens: bool = True,
tables: bool = False,
) -> EtlOutput:
"""
Load `etl_output_url` as an ETL Output dataclass. A `reader` function must be
supplied to read JSON files from disk, storage API, or Indico client.
Use `text`, `tokens`, and `tables` to specify what to load.
```
result = results.load(submission.result_file, reader=read_url)
etl_outputs = {
document: etloutput.load(document.etl_output_url, reader=read_url)
for document in result.documents
}
```
"""
etl_output = reader(etl_output_url)
tables_url = etl_output_url.replace("etl_output.json", "tables.json")

if has(etl_output, str, "pages", 0, "page_info"):
return _load_v1(etl_output, tables_url, reader, text, tokens, tables)
else:
return _load_v3(etl_output, tables_url, reader, text, tokens, tables)


async def load_async(
etl_output_url: str,
*,
reader: "Callable[..., Awaitable[Any]]",
text: bool = True,
tokens: bool = True,
tables: bool = False,
) -> EtlOutput:
"""
Load `etl_output_url` as an ETL Output dataclass. A `reader` coroutine must be
supplied to read JSON files from disk, storage API, or Indico client.
Use `text`, `tokens`, and `tables` to specify what to load.
```
result = await results.load_async(submission.result_file, reader=read_url)
etl_outputs = {
document: await etloutput.load_async(document.etl_output_url, reader=read_url)
for document in result.documents
}
```
"""
etl_output = await reader(etl_output_url)
tables_url = etl_output_url.replace("etl_output.json", "tables.json")

if has(etl_output, str, "pages", 0, "page_info"):
return await _load_v1_async(
etl_output, tables_url, reader, text, tokens, tables
)
else:
return await _load_v3_async(
etl_output, tables_url, reader, text, tokens, tables
)


def _load_v1(
etl_output: "Any",
tables_url: str,
reader: "Callable[..., Any]",
text: bool,
tokens: bool,
tables: bool,
) -> EtlOutput:
if text or tokens:
pages = tuple(
reader(get(page, str, "page_info"))
for page in get(etl_output, list, "pages")
)
text_by_page = map(lambda page: get(page, str, "pages", 0, "text"), pages)
tokens_by_page = map(lambda page: get(page, list, "tokens"), pages)
else:
text_by_page = () # type: ignore[assignment]
tokens_by_page = () # type: ignore[assignment]

if tables:
tables_by_page = reader(tables_url)
else:
tables_by_page = ()

return EtlOutput.from_pages(text_by_page, tokens_by_page, tables_by_page)


def _load_v3(
etl_output: "Any",
tables_url: str,
reader: "Callable[..., Any]",
text: bool,
tokens: bool,
tables: bool,
) -> EtlOutput:
pages = get(etl_output, list, "pages")

if text or tokens:
text_by_page = map(lambda page: reader(get(page, str, "text")), pages)
else:
text_by_page = () # type: ignore[assignment]

if tokens:
tokens_by_page = map(lambda page: reader(get(page, str, "tokens")), pages)
else:
tokens_by_page = () # type: ignore[assignment]

if tables:
tables_by_page = reader(tables_url)
else:
tables_by_page = ()

return EtlOutput.from_pages(text_by_page, tokens_by_page, tables_by_page)


async def _load_v1_async(
etl_output: "Any",
tables_url: str,
reader: "Callable[..., Awaitable[Any]]",
text: bool,
tokens: bool,
tables: bool,
) -> EtlOutput:
if text or tokens:
pages = [
await reader(get(page, str, "page_info"))
for page in get(etl_output, list, "pages")
]
text_by_page = map(lambda page: get(page, str, "pages", 0, "text"), pages)
tokens_by_page = map(lambda page: get(page, list, "tokens"), pages)
else:
text_by_page = () # type: ignore[assignment]
tokens_by_page = () # type: ignore[assignment]

if tables:
tables_by_page = await reader(tables_url)
else:
tables_by_page = ()

return EtlOutput.from_pages(text_by_page, tokens_by_page, tables_by_page)


async def _load_v3_async(
etl_output: "Any",
tables_url: str,
reader: "Callable[..., Awaitable[Any]]",
text: bool,
tokens: bool,
tables: bool,
) -> EtlOutput:
pages = get(etl_output, list, "pages")

if text or tokens:
text_by_page = [await reader(get(page, str, "text")) for page in pages]
else:
text_by_page = () # type: ignore[assignment]

if tokens:
tokens_by_page = [await reader(get(page, str, "tokens")) for page in pages]
else:
tokens_by_page = () # type: ignore[assignment]

if tables:
tables_by_page = await reader(tables_url)
else:
tables_by_page = ()

return EtlOutput.from_pages(text_by_page, tokens_by_page, tables_by_page)
72 changes: 72 additions & 0 deletions indico_toolkit/etloutput/cell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from dataclasses import dataclass
from enum import Enum

from .utilities import get, has


class CellType(Enum):
HEADER = "header"
CONTENT = "content"


@dataclass(frozen=True)
class Cell:
type: CellType
text: str
# Span
start: int
end: int
# Bounding box
page: int
top: int
left: int
right: int
bottom: int
# Table coordinates
row: int
rowspan: int
rows: "tuple[int, ...]"
column: int
columnspan: int
columns: "tuple[int, ...]"

def __lt__(self, other: "Cell") -> bool:
"""
By default, cells are sorted in table order (by row, then column).
Cells can also be sorted in span order: `tokens.sort(key=attrgetter("start"))`.
"""
return self.row < other.row or (
self.row == other.row and self.column < other.column
)

@staticmethod
def from_dict(cell: object, page: int) -> "Cell":
"""
Create a `Cell` from a v1 or v3 ETL Ouput cell dictionary.
"""
return Cell(
type=CellType(get(cell, str, "cell_type")),
text=get(cell, str, "text"),
# Empty cells have no start and end; so use [0:0] for a valid slice.
start=(
get(cell, int, "doc_offsets", 0, "start")
if has(cell, int, "doc_offsets", 0, "start")
else 0
),
end=(
get(cell, int, "doc_offsets", 0, "end")
if has(cell, int, "doc_offsets", 0, "end")
else 0
),
page=page,
top=get(cell, int, "position", "top"),
left=get(cell, int, "position", "left"),
right=get(cell, int, "position", "right"),
bottom=get(cell, int, "position", "bottom"),
row=get(cell, int, "rows", 0),
rowspan=len(get(cell, list, "rows")),
rows=tuple(get(cell, list, "rows")),
column=get(cell, int, "columns", 0),
columnspan=len(get(cell, list, "columns")),
columns=tuple(get(cell, list, "columns")),
)
16 changes: 16 additions & 0 deletions indico_toolkit/etloutput/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
class EtlOutputError(Exception):
"""
Raised when an error occurs while loading an ETL Output file.
"""


class TokenNotFoundError(EtlOutputError):
"""
Raised when a Token can't be found for a Prediction.
"""


class TableCellNotFoundError(EtlOutputError):
"""
Raised when a Table Cell can't be found for a Token.
"""
106 changes: 106 additions & 0 deletions indico_toolkit/etloutput/etloutput.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import itertools
from bisect import bisect_left, bisect_right
from dataclasses import dataclass
from operator import attrgetter
from typing import TYPE_CHECKING

from .errors import TableCellNotFoundError, TokenNotFoundError
from .table import Table
from .token import Token

if TYPE_CHECKING:
from collections.abc import Iterable

from ..results import DocumentExtraction
from .cell import Cell


@dataclass(frozen=True)
class EtlOutput:
text: str
text_on_page: "tuple[str, ...]"

tokens: "tuple[Token, ...]"
tokens_on_page: "tuple[tuple[Token, ...], ...]"

tables: "tuple[Table, ...]"
tables_on_page: "tuple[tuple[Table, ...], ...]"

@staticmethod
def from_pages(
text_by_page: "Iterable[str]",
token_dicts_by_page: "Iterable[Iterable[object]]",
table_dicts_by_page: "Iterable[Iterable[object]]",
) -> "EtlOutput":
"""
Create an `EtlOutput` from v1 or v3 ETL Ouput pages.
"""
text_by_page = tuple(text_by_page)
tokens_by_page = tuple(
tuple(map(Token.from_dict, token_dict_page))
for token_dict_page in token_dicts_by_page
)
tables_by_page = tuple(
tuple(map(Table.from_dict, table_dict_page))
for table_dict_page in table_dicts_by_page
)

return EtlOutput(
text="\n".join(text_by_page),
text_on_page=text_by_page,
tokens=tuple(itertools.chain.from_iterable(tokens_by_page)),
tokens_on_page=tokens_by_page,
tables=tuple(itertools.chain.from_iterable(tables_by_page)),
tables_on_page=tables_by_page,
)

def token_for(self, extraction: "DocumentExtraction") -> Token:
"""
Return a `Token` that contains every character from `extraction`.
Raise `TokenNotFoundError` if one can't be produced.
"""
try:
tokens = self.tokens_on_page[extraction.page]
first = bisect_right(tokens, extraction.start, key=attrgetter("end"))
last = bisect_left(tokens, extraction.end, lo=first, key=attrgetter("start")) # fmt: skip # noqa: E501
tokens = tokens[first:last]

return Token(
text=self.text[extraction.start : extraction.end],
start=extraction.start,
end=extraction.end,
page=min(token.page for token in tokens),
top=min(token.top for token in tokens),
left=min(token.left for token in tokens),
right=max(token.right for token in tokens),
bottom=max(token.bottom for token in tokens),
)
except (IndexError, ValueError) as error:
raise TokenNotFoundError(f"no token contains {extraction!r}") from error

def table_cell_for(self, token: Token) -> "tuple[Table, Cell]":
"""
Return the `Table` and `Cell` that contain the midpoint of `token`.
Raise `TableCellNotFoundError` if it's not inside a table cell.
"""
token_vmid = (token.top + token.bottom) // 2
token_hmid = (token.left + token.right) // 2

for table in self.tables_on_page[token.page]:
if (table.top <= token_vmid <= table.bottom) and (
table.left <= token_hmid <= table.right
):
break
else:
raise TableCellNotFoundError(f"no table contains {token!r}")

try:
row_index = bisect_left(table.rows, token_vmid, key=lambda row: row[0].bottom) # fmt: skip # noqa: E501
row = table.rows[row_index]

cell_index = bisect_left(row, token_hmid, key=attrgetter("right"))
cell = row[cell_index]
except (IndexError, ValueError) as error:
raise TableCellNotFoundError(f"no cell contains {token!r}") from error

return table, cell
Loading

0 comments on commit befb3d8

Please sign in to comment.