-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #205 from IndicoDataSolutions/mawelborn/etloutput-…
…dataclasses ETL Output Dataclasses
- Loading branch information
Showing
55 changed files
with
62,222 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
from typing import TYPE_CHECKING | ||
|
||
from .cell import Cell, CellType | ||
from .errors import EtlOutputError, TableCellNotFoundError, TokenNotFoundError | ||
from .etloutput import EtlOutput | ||
from .table import Table | ||
from .token import Token | ||
from .utilities import get, has | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Awaitable, Callable | ||
from typing import Any | ||
|
||
# Public API of the etloutput package: dataclasses, loaders, and errors.
__all__ = (
    "Cell",
    "CellType",
    "EtlOutput",
    "EtlOutputError",
    "load",
    "load_async",
    "Table",
    "TableCellNotFoundError",
    "Token",
    "TokenNotFoundError",
)
|
||
|
||
def load(
    etl_output_url: str,
    *,
    reader: "Callable[..., Any]",
    text: bool = True,
    tokens: bool = True,
    tables: bool = False,
) -> EtlOutput:
    """
    Read the ETL Output at `etl_output_url` and parse it into an `EtlOutput`
    dataclass. `reader` is a caller-supplied function that reads and decodes
    JSON files from disk, storage API, or Indico client.
    The `text`, `tokens`, and `tables` flags select which sections to load.
    ```
    result = results.load(submission.result_file, reader=read_url)
    etl_outputs = {
        document: etloutput.load(document.etl_output_url, reader=read_url)
        for document in result.documents
    }
    ```
    """
    etl_output = reader(etl_output_url)
    # Tables live in a sibling `tables.json` next to `etl_output.json`.
    tables_url = etl_output_url.replace("etl_output.json", "tables.json")

    # v1 ETL Outputs reference a combined `page_info` file per page;
    # v3 ETL Outputs do not, so dispatch on that key's presence.
    is_v1 = has(etl_output, str, "pages", 0, "page_info")
    loader = _load_v1 if is_v1 else _load_v3
    return loader(etl_output, tables_url, reader, text, tokens, tables)
|
||
|
||
async def load_async(
    etl_output_url: str,
    *,
    reader: "Callable[..., Awaitable[Any]]",
    text: bool = True,
    tokens: bool = True,
    tables: bool = False,
) -> EtlOutput:
    """
    Read the ETL Output at `etl_output_url` and parse it into an `EtlOutput`
    dataclass. `reader` is a caller-supplied coroutine that reads and decodes
    JSON files from disk, storage API, or Indico client.
    The `text`, `tokens`, and `tables` flags select which sections to load.
    ```
    result = await results.load_async(submission.result_file, reader=read_url)
    etl_outputs = {
        document: await etloutput.load_async(document.etl_output_url, reader=read_url)
        for document in result.documents
    }
    ```
    """
    etl_output = await reader(etl_output_url)
    # Tables live in a sibling `tables.json` next to `etl_output.json`.
    tables_url = etl_output_url.replace("etl_output.json", "tables.json")

    # v1 ETL Outputs reference a combined `page_info` file per page;
    # v3 ETL Outputs do not, so dispatch on that key's presence.
    if has(etl_output, str, "pages", 0, "page_info"):
        loader = _load_v1_async
    else:
        loader = _load_v3_async
    return await loader(etl_output, tables_url, reader, text, tokens, tables)
|
||
|
||
def _load_v1(
    etl_output: "Any",
    tables_url: str,
    reader: "Callable[..., Any]",
    text: bool,
    tokens: bool,
    tables: bool,
) -> EtlOutput:
    """
    Build an `EtlOutput` from a v1 ETL Output, where each page's text and
    tokens are stored together in a single `page_info` file.
    """
    text_by_page = ()  # type: ignore[assignment]
    tokens_by_page = ()  # type: ignore[assignment]
    tables_by_page = ()

    if text or tokens:
        # v1 keeps text and tokens in one file per page, so both are
        # produced whenever either is requested.
        # NOTE(review): this differs from v3, which honors each flag
        # independently -- presumably acceptable since the page file is
        # already read; confirm before changing.
        page_dicts = tuple(
            reader(get(page, str, "page_info"))
            for page in get(etl_output, list, "pages")
        )
        text_by_page = map(lambda page: get(page, str, "pages", 0, "text"), page_dicts)
        tokens_by_page = map(lambda page: get(page, list, "tokens"), page_dicts)

    if tables:
        tables_by_page = reader(tables_url)

    return EtlOutput.from_pages(text_by_page, tokens_by_page, tables_by_page)
|
||
|
||
def _load_v3(
    etl_output: "Any",
    tables_url: str,
    reader: "Callable[..., Any]",
    text: bool,
    tokens: bool,
    tables: bool,
) -> EtlOutput:
    """
    Build an `EtlOutput` from a v3 ETL Output, where each page stores its
    text and tokens as separately-read files.
    """
    page_dicts = get(etl_output, list, "pages")

    # Text is loaded when either flag is set -- presumably tokens are
    # resolved against the page text; TODO confirm.
    text_by_page = (
        map(lambda page: reader(get(page, str, "text")), page_dicts)
        if text or tokens
        else ()  # type: ignore[assignment]
    )
    tokens_by_page = (
        map(lambda page: reader(get(page, str, "tokens")), page_dicts)
        if tokens
        else ()  # type: ignore[assignment]
    )
    tables_by_page = reader(tables_url) if tables else ()

    return EtlOutput.from_pages(text_by_page, tokens_by_page, tables_by_page)
|
||
|
||
async def _load_v1_async(
    etl_output: "Any",
    tables_url: str,
    reader: "Callable[..., Awaitable[Any]]",
    text: bool,
    tokens: bool,
    tables: bool,
) -> EtlOutput:
    """
    Build an `EtlOutput` from a v1 ETL Output, where each page's text and
    tokens are stored together in a single `page_info` file.
    """
    text_by_page = ()  # type: ignore[assignment]
    tokens_by_page = ()  # type: ignore[assignment]
    tables_by_page = ()

    if text or tokens:
        # v1 keeps text and tokens in one file per page, so both are
        # produced whenever either is requested (matches `_load_v1`).
        page_dicts = [
            await reader(get(page, str, "page_info"))
            for page in get(etl_output, list, "pages")
        ]
        text_by_page = map(lambda page: get(page, str, "pages", 0, "text"), page_dicts)
        tokens_by_page = map(lambda page: get(page, list, "tokens"), page_dicts)

    if tables:
        tables_by_page = await reader(tables_url)

    return EtlOutput.from_pages(text_by_page, tokens_by_page, tables_by_page)
|
||
|
||
async def _load_v3_async(
    etl_output: "Any",
    tables_url: str,
    reader: "Callable[..., Awaitable[Any]]",
    text: bool,
    tokens: bool,
    tables: bool,
) -> EtlOutput:
    """
    Build an `EtlOutput` from a v3 ETL Output, where each page stores its
    text and tokens as separately-read files.
    """
    page_dicts = get(etl_output, list, "pages")
    text_by_page = ()  # type: ignore[assignment]
    tokens_by_page = ()  # type: ignore[assignment]
    tables_by_page = ()

    # Text is loaded when either flag is set -- presumably tokens are
    # resolved against the page text; TODO confirm.
    if text or tokens:
        text_by_page = [await reader(get(page, str, "text")) for page in page_dicts]
    if tokens:
        tokens_by_page = [await reader(get(page, str, "tokens")) for page in page_dicts]
    if tables:
        tables_by_page = await reader(tables_url)

    return EtlOutput.from_pages(text_by_page, tokens_by_page, tables_by_page)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from dataclasses import dataclass | ||
from enum import Enum | ||
|
||
from .utilities import get, has | ||
|
||
|
||
class CellType(Enum):
    """
    Kind of table cell: a column/row header or ordinary content.
    """

    HEADER = "header"
    CONTENT = "content"
|
||
|
||
@dataclass(frozen=True)
class Cell:
    """
    A single table cell: its text, document span, bounding box, and
    position within the table grid.
    """

    type: CellType
    text: str
    # Span (document character offsets)
    start: int
    end: int
    # Bounding box
    page: int
    top: int
    left: int
    right: int
    bottom: int
    # Table coordinates
    row: int
    rowspan: int
    rows: "tuple[int, ...]"
    column: int
    columnspan: int
    columns: "tuple[int, ...]"

    def __lt__(self, other: "Cell") -> bool:
        """
        By default, cells are sorted in table order (by row, then column).
        Cells can also be sorted in span order: `tokens.sort(key=attrgetter("start"))`.
        """
        return (self.row, self.column) < (other.row, other.column)

    @staticmethod
    def from_dict(cell: object, page: int) -> "Cell":
        """
        Create a `Cell` from a v1 or v3 ETL Output cell dictionary.
        """

        def doc_offset(bound: str) -> int:
            # Empty cells have no start and end; use [0:0] for a valid slice.
            if has(cell, int, "doc_offsets", 0, bound):
                return get(cell, int, "doc_offsets", 0, bound)
            return 0

        return Cell(
            type=CellType(get(cell, str, "cell_type")),
            text=get(cell, str, "text"),
            start=doc_offset("start"),
            end=doc_offset("end"),
            page=page,
            top=get(cell, int, "position", "top"),
            left=get(cell, int, "position", "left"),
            right=get(cell, int, "position", "right"),
            bottom=get(cell, int, "position", "bottom"),
            row=get(cell, int, "rows", 0),
            rowspan=len(get(cell, list, "rows")),
            rows=tuple(get(cell, list, "rows")),
            column=get(cell, int, "columns", 0),
            columnspan=len(get(cell, list, "columns")),
            columns=tuple(get(cell, list, "columns")),
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
class EtlOutputError(Exception):
    """
    Base error for failures while loading or querying an ETL Output file.
    """
|
||
|
||
class TokenNotFoundError(EtlOutputError):
    """
    Raised when no Token can be found for a Prediction.
    """
|
||
|
||
class TableCellNotFoundError(EtlOutputError):
    """
    Raised when no Table Cell can be found for a Token.
    """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import itertools | ||
from bisect import bisect_left, bisect_right | ||
from dataclasses import dataclass | ||
from operator import attrgetter | ||
from typing import TYPE_CHECKING | ||
|
||
from .errors import TableCellNotFoundError, TokenNotFoundError | ||
from .table import Table | ||
from .token import Token | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Iterable | ||
|
||
from ..results import DocumentExtraction | ||
from .cell import Cell | ||
|
||
|
||
@dataclass(frozen=True)
class EtlOutput:
    """
    Parsed ETL Output for a single document: the document text, OCR tokens,
    and tables, each available document-wide and broken out per page.
    """

    # Full document text, and each page's text (index = page number).
    text: str
    text_on_page: "tuple[str, ...]"

    # All tokens in document order, and each page's tokens.
    tokens: "tuple[Token, ...]"
    tokens_on_page: "tuple[tuple[Token, ...], ...]"

    # All tables in document order, and each page's tables.
    tables: "tuple[Table, ...]"
    tables_on_page: "tuple[tuple[Table, ...], ...]"

    @staticmethod
    def from_pages(
        text_by_page: "Iterable[str]",
        token_dicts_by_page: "Iterable[Iterable[object]]",
        table_dicts_by_page: "Iterable[Iterable[object]]",
    ) -> "EtlOutput":
        """
        Create an `EtlOutput` from v1 or v3 ETL Output pages.
        """
        text_by_page = tuple(text_by_page)
        tokens_by_page = tuple(
            tuple(map(Token.from_dict, token_dict_page))
            for token_dict_page in token_dicts_by_page
        )
        tables_by_page = tuple(
            tuple(map(Table.from_dict, table_dict_page))
            for table_dict_page in table_dicts_by_page
        )

        return EtlOutput(
            # Pages are joined with a single newline; token/cell document
            # offsets presumably account for that separator -- TODO confirm.
            text="\n".join(text_by_page),
            text_on_page=text_by_page,
            tokens=tuple(itertools.chain.from_iterable(tokens_by_page)),
            tokens_on_page=tokens_by_page,
            tables=tuple(itertools.chain.from_iterable(tables_by_page)),
            tables_on_page=tables_by_page,
        )

    def token_for(self, extraction: "DocumentExtraction") -> Token:
        """
        Return a `Token` that contains every character from `extraction`.
        Raise `TokenNotFoundError` if one can't be produced.
        """
        try:
            tokens = self.tokens_on_page[extraction.page]
            # Page tokens are assumed sorted by document span -- TODO confirm.
            # `first` is the leftmost token whose end is past the extraction's
            # start; `last` is the leftmost token whose start is at or past the
            # extraction's end. Tokens in [first:last) overlap the extraction.
            first = bisect_right(tokens, extraction.start, key=attrgetter("end"))
            last = bisect_left(tokens, extraction.end, lo=first, key=attrgetter("start"))  # fmt: skip # noqa: E501
            tokens = tokens[first:last]

            # Synthesize a token covering the extraction: its box is the union
            # of the overlapping tokens' boxes. min()/max() raise ValueError
            # when no tokens overlap, reported as "not found" below.
            return Token(
                text=self.text[extraction.start : extraction.end],
                start=extraction.start,
                end=extraction.end,
                page=min(token.page for token in tokens),
                top=min(token.top for token in tokens),
                left=min(token.left for token in tokens),
                right=max(token.right for token in tokens),
                bottom=max(token.bottom for token in tokens),
            )
        except (IndexError, ValueError) as error:
            raise TokenNotFoundError(f"no token contains {extraction!r}") from error

    def table_cell_for(self, token: Token) -> "tuple[Table, Cell]":
        """
        Return the `Table` and `Cell` that contain the midpoint of `token`.
        Raise `TableCellNotFoundError` if it's not inside a table cell.
        """
        # Use the token's midpoint so a token slightly overhanging a border
        # still resolves to the cell it is visually inside.
        token_vmid = (token.top + token.bottom) // 2
        token_hmid = (token.left + token.right) // 2

        # Linear scan for a table whose bounding box contains the midpoint;
        # the for/else raises when no table on the page matches.
        for table in self.tables_on_page[token.page]:
            if (table.top <= token_vmid <= table.bottom) and (
                table.left <= token_hmid <= table.right
            ):
                break
        else:
            raise TableCellNotFoundError(f"no table contains {token!r}")

        try:
            # Rows are assumed sorted top-to-bottom and each row's cells
            # left-to-right -- TODO confirm. Find the first row whose bottom
            # edge reaches the midpoint, then the first cell in that row whose
            # right edge reaches it.
            row_index = bisect_left(table.rows, token_vmid, key=lambda row: row[0].bottom)  # fmt: skip # noqa: E501
            row = table.rows[row_index]

            cell_index = bisect_left(row, token_hmid, key=attrgetter("right"))
            cell = row[cell_index]
        except (IndexError, ValueError) as error:
            raise TableCellNotFoundError(f"no cell contains {token!r}") from error

        return table, cell
Oops, something went wrong.