Skip to content

Commit

Permalink
Merge pull request #201 from IndicoDataSolutions/mawelborn/extraction-types
Browse files Browse the repository at this point in the history

Result File Dataclasses: Improve Extraction API
  • Loading branch information
mawelborn authored Oct 4, 2024
2 parents 4dd0202 + 5bd92c7 commit 2f97bb4
Show file tree
Hide file tree
Showing 13 changed files with 290 additions and 215 deletions.
37 changes: 34 additions & 3 deletions examples/results_dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Overview of dataclasses and functionality available in the results module.
"""

from operator import attrgetter
from pathlib import Path

Expand Down Expand Up @@ -116,14 +117,44 @@
# Extraction Dataclass (Subclass of Prediction)
extraction = predictions.extractions[0]
extraction.text
extraction.start
extraction.end
extraction.page
extraction.groups # Any linked label groups this prediction is a part of
extraction.accepted
extraction.rejected

extraction.accept() # Mark this extraction as accepted for auto review
extraction.reject() # Mark this extraction as rejected for auto review
extraction.unaccept() # Mark this extraction as not accepted for auto review
extraction.unreject() # Mark this extraction as not rejected for auto review


# DocumentExtraction Dataclass (Subclass of Extraction)
document_extraction = predictions.document_extractions[0]
document_extraction.text
document_extraction.page
document_extraction.start
document_extraction.end
document_extraction.groups # Any linked label groups this prediction is a part of
document_extraction.accepted
document_extraction.rejected

document_extraction.accept() # Mark this extraction as accepted for auto review
document_extraction.reject() # Mark this extraction as rejected for auto review
document_extraction.unaccept() # Mark this extraction as not accepted for auto review
document_extraction.unreject() # Mark this extraction as not rejected for auto review


# FormExtraction Dataclass (Subclass of Extraction)
form_extraction = predictions.form_extractions[0]
form_extraction.text
form_extraction.page
form_extraction.top
form_extraction.left
form_extraction.right
form_extraction.bottom
form_extraction.accepted
form_extraction.rejected

form_extraction.accept() # Mark this extraction as accepted for auto review
form_extraction.reject() # Mark this extraction as rejected for auto review
form_extraction.unaccept() # Mark this extraction as not accepted for auto review
form_extraction.unreject() # Mark this extraction as not rejected for auto review
11 changes: 7 additions & 4 deletions examples/results_to_csv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Dump classifications and extractions from result files to a CSV file.
"""

import csv
from collections.abc import Iterable, Iterator
from pathlib import Path
Expand All @@ -14,16 +15,16 @@ def final_predictions(result: results.Result) -> Iterator[dict[str, object]]:
for prediction in result.final:
if isinstance(prediction, results.Classification):
yield {
"submission_id": result.id,
"submission_id": result.submission_id,
"document_id": prediction.document.id,
"model": prediction.model.name,
"field": "Classification",
"value": prediction.label,
"confidence": prediction.confidence,
}
else:
elif isinstance(prediction, results.Extraction):
yield {
"submission_id": result.id,
"submission_id": result.submission_id,
"document_id": prediction.document.id,
"model": prediction.model.name,
"field": prediction.label,
Expand Down Expand Up @@ -57,7 +58,9 @@ def convert_to_csv(
"value",
"confidence",
],
).writerows(predictions_from_files(result_files))
).writerows(
predictions_from_files(result_files),
)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions indico_toolkit/results/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from .lists import PredictionList
from .models import ModelGroup, TaskType
from .predictions import (
AutoReviewable,
Classification,
DocumentExtraction,
Extraction,
FormExtraction,
FormExtractionType,
Expand All @@ -20,9 +20,9 @@
from .utils import get

__all__ = (
"AutoReviewable",
"Classification",
"Document",
"DocumentExtraction",
"Extraction",
"FormExtraction",
"FormExtractionType",
Expand Down
54 changes: 41 additions & 13 deletions indico_toolkit/results/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from .models import TaskType
from .predictions import (
AutoReviewable,
Classification,
DocumentExtraction,
Extraction,
FormExtraction,
Prediction,
Expand All @@ -15,7 +15,7 @@
from .utils import nfilter

if TYPE_CHECKING:
from collections.abc import Callable, Container, Iterable
from collections.abc import Callable, Collection, Container, Iterable
from typing import Any, SupportsIndex

from typing_extensions import Self
Expand All @@ -39,6 +39,10 @@ class PredictionList(List[PredictionType]):
def classifications(self) -> "PredictionList[Classification]":
return self.oftype(Classification)

@property
def document_extractions(self) -> "PredictionList[DocumentExtraction]":
return self.oftype(DocumentExtraction)

@property
def extractions(self) -> "PredictionList[Extraction]":
return self.oftype(Extraction)
Expand Down Expand Up @@ -114,6 +118,8 @@ def where(
label_in: "Container[str] | None" = None,
min_confidence: "float | None" = None,
max_confidence: "float | None" = None,
page: "int | None" = None,
page_in: "Collection[int] | None" = None,
accepted: "bool | None" = None,
rejected: "bool | None" = None,
checked: "bool | None" = None,
Expand All @@ -131,7 +137,9 @@ def where(
label_in: predictions with one of these labels,
min_confidence: predictions with confidence >= this threshold,
max_confidence: predictions with confidence <= this threshold,
accepted: extractions that have accepted,
page: extractions/unbundlings on this page,
page_in: extractions/unbundlings on one of these pages,
accepted: extractions that have been accepted,
rejected: extractions that have been rejected,
checked: form extractions that are checked,
signed: form extractions that are signed,
Expand Down Expand Up @@ -180,15 +188,35 @@ def where(
lambda prediction: prediction.confidence <= max_confidence
)

if page is not None:
predicates.append(
lambda prediction: (
(isinstance(prediction, Extraction) and prediction.page == page)
or (isinstance(prediction, Unbundling) and page in prediction.pages)
)
)

if page_in is not None:
page_in = set(page_in)
predicates.append(
lambda prediction: (
(isinstance(prediction, Extraction) and prediction.page in page_in)
or (
isinstance(prediction, Unbundling)
and bool(page_in & set(prediction.pages))
)
)
)

if accepted is not None:
predicates.append(
lambda prediction: isinstance(prediction, AutoReviewable)
lambda prediction: isinstance(prediction, Extraction)
and prediction.accepted == accepted
)

if rejected is not None:
predicates.append(
lambda prediction: isinstance(prediction, AutoReviewable)
lambda prediction: isinstance(prediction, Extraction)
and prediction.rejected == rejected
)

Expand All @@ -208,30 +236,30 @@ def where(

def accept(self) -> "Self":
"""
Mark predictions as accepted for auto-review.
Mark extractions as accepted for auto review.
"""
self.oftype(AutoReviewable).apply(AutoReviewable.accept)
self.oftype(Extraction).apply(Extraction.accept)
return self

def unaccept(self) -> "Self":
"""
Mark predictions as not accepted for auto-review.
Mark extractions as not accepted for auto review.
"""
self.oftype(AutoReviewable).apply(AutoReviewable.unaccept)
self.oftype(Extraction).apply(Extraction.unaccept)
return self

def reject(self) -> "Self":
"""
Mark predictions as rejected for auto-review.
Mark extractions as rejected for auto review.
"""
self.oftype(AutoReviewable).apply(AutoReviewable.reject)
self.oftype(Extraction).apply(Extraction.reject)
return self

def unreject(self) -> "Self":
"""
Mark predictions as not rejected for auto-review.
Mark extractions as not rejected for auto review.
"""
self.oftype(AutoReviewable).apply(AutoReviewable.unreject)
self.oftype(Extraction).apply(Extraction.unreject)
return self

def to_changes(self, result: "Result") -> "Any":
Expand Down
6 changes: 3 additions & 3 deletions indico_toolkit/results/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class TaskType(Enum):
CLASSIFICATION = "classification"
EXTRACTION = "annotation"
DOCUMENT_EXTRACTION = "annotation"
FORM_EXTRACTION = "form_extraction"
UNBUNDLING = "classification_unbundling"

Expand All @@ -31,12 +31,12 @@ def from_v1_section(section: "tuple[str, object]") -> "ModelGroup":
if has(prediction, str, "type"):
task_type = TaskType.FORM_EXTRACTION
elif has(prediction, str, "text"):
task_type = TaskType.EXTRACTION
task_type = TaskType.DOCUMENT_EXTRACTION
else:
task_type = TaskType.CLASSIFICATION
else:
# Likely an extraction model that produced no predictions.
task_type = TaskType.EXTRACTION
task_type = TaskType.DOCUMENT_EXTRACTION

return ModelGroup(
# v1 result files don't include model IDs.
Expand Down
8 changes: 4 additions & 4 deletions indico_toolkit/results/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ def normalize_v1_result(result: "Any") -> None:
prediction["right"] = 0
prediction["bottom"] = 0

# Prior to 6.11, some extractions lack a `normalized` section after review.
# Prior to 6.11, some Extractions lack a `normalized` section after review.
if "text" in prediction and "normalized" not in prediction:
prediction["normalized"] = {"formatted": prediction["text"]}

# Document extractions that didn't go through a linked labels transformer
# Document Extractions that didn't go through a linked labels transformer
# lack a `groupings` section.
if (
"text" in prediction
Expand Down Expand Up @@ -123,12 +123,12 @@ def normalize_v3_result(result: "Any") -> None:
prediction["right"] = 0
prediction["bottom"] = 0

# Prior to 6.11, some extractions lack a `normalized` section after
# Prior to 6.11, some Extractions lack a `normalized` section after
# review.
if "text" in prediction and "normalized" not in prediction:
prediction["normalized"] = {"formatted": prediction["text"]}

# Document extractions that didn't go through a linked labels
# Document Extractions that didn't go through a linked labels
# transformer lack a `groupings` section.
if (
"text" in prediction
Expand Down
12 changes: 6 additions & 6 deletions indico_toolkit/results/predictions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import TYPE_CHECKING

from ..models import TaskType
from .autoreviewable import AutoReviewable
from .classifications import Classification
from .documentextractions import DocumentExtraction
from .extractions import Extraction
from .formextractions import FormExtraction, FormExtractionType
from .groups import Group
Expand All @@ -16,8 +16,8 @@
from ..reviews import Review

__all__ = (
"AutoReviewable",
"Classification",
"DocumentExtraction",
"Extraction",
"FormExtraction",
"FormExtractionType",
Expand All @@ -38,8 +38,8 @@ def from_v1_dict(
"""
if model.task_type == TaskType.CLASSIFICATION:
return Classification.from_v1_dict(document, model, review, prediction)
elif model.task_type == TaskType.EXTRACTION:
return Extraction.from_v1_dict(document, model, review, prediction)
elif model.task_type == TaskType.DOCUMENT_EXTRACTION:
return DocumentExtraction.from_v1_dict(document, model, review, prediction)
elif model.task_type == TaskType.FORM_EXTRACTION:
return FormExtraction.from_v1_dict(document, model, review, prediction)
else:
Expand All @@ -57,8 +57,8 @@ def from_v3_dict(
"""
if model.task_type == TaskType.CLASSIFICATION:
return Classification.from_v3_dict(document, model, review, prediction)
elif model.task_type == TaskType.EXTRACTION:
return Extraction.from_v3_dict(document, model, review, prediction)
elif model.task_type == TaskType.DOCUMENT_EXTRACTION:
return DocumentExtraction.from_v3_dict(document, model, review, prediction)
elif model.task_type == TaskType.FORM_EXTRACTION:
return FormExtraction.from_v3_dict(document, model, review, prediction)
elif model.task_type == TaskType.UNBUNDLING:
Expand Down
23 changes: 0 additions & 23 deletions indico_toolkit/results/predictions/autoreviewable.py

This file was deleted.

Loading

0 comments on commit 2f97bb4

Please sign in to comment.