Skip to content

Commit

Permalink
Refactor Auto Review (#169)
Browse files Browse the repository at this point in the history
* auto review refactor first pass

* update autoreviewfunction class

* update ar tests

* address comments

---------

Co-authored-by: Nathanael Shim <[email protected]>
  • Loading branch information
nateshim-indico and Nathanael Shim authored Oct 13, 2023
1 parent cc7b653 commit beb9812
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 181 deletions.
33 changes: 20 additions & 13 deletions examples/auto_review_predictions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""
Submit documents to a workflow, auto review them and submit them for human review
"""
from indico_toolkit import auto_review
from indico_toolkit.auto_review import (
ReviewConfiguration,
AutoReviewFunction,
AutoReviewer,
)
from indico_toolkit.auto_review.auto_review_functions import (
remove_by_confidence,
accept_by_confidence
)
from indico_toolkit.indico_wrapper import Workflow
from indico_toolkit import create_client

Expand All @@ -26,17 +29,21 @@
wf_results = wflow.get_submission_results_from_ids(submission_ids)
predictions = wf_results[0].predictions.to_list()

# Set up reviewer and review predictions
review_config = ReviewConfiguration(
field_config=[
{"function": "remove_by_confidence", "kwargs": {"conf_threshold": 0.90}},
{
"function": "accept_by_confidence",
"kwargs": {"conf_threshold": 0.98, "labels": ["Name", "Amount"]},
},
]
)
auto_reviewer = AutoReviewer(predictions, review_config)
# Set up custom review function
def custom_function(predictions, labels: list = None, match_text: str = ""):
    """
    Example custom auto review function: accept predictions whose text matches.

    Follows the (predictions, labels, **kwargs) argument order that
    AutoReviewFunction requires of the callables it wraps.

    Args:
        predictions (List[dict]): predictions to review. Each dict is read via
            its "text" key (and "label" when labels is given) and is updated
            in place
        labels (list): labels the check is restricted to. Defaults to all labels
        match_text (str): exact text a prediction must have to be accepted

    Returns:
        List[dict]: the same list object, with "accepted" set to True on matches
    """
    for pred in predictions:
        # Honour the label filter instead of silently ignoring it, so scoping
        # the function to specific labels actually limits what gets accepted
        if labels is not None and pred.get("label") not in labels:
            continue
        if pred["text"] == match_text:
            pred["accepted"] = True
    return predictions


# Set up review functions and review predictions
functions = [
AutoReviewFunction(remove_by_confidence, kwargs={"conf_threshold": 0.90}), # will default to all labels if labels is not provided
AutoReviewFunction(accept_by_confidence, labels=["Name", "Amount"]),
AutoReviewFunction(custom_function, kwargs={"match_text": "text to match"}) # call custom auto review function
]
auto_reviewer = AutoReviewer(predictions, functions)
auto_reviewer.apply_reviews()

# Submit review
Expand Down
3 changes: 1 addition & 2 deletions indico_toolkit/auto_review/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from .review_config import ReviewConfiguration
from .auto_reviewer import AutoReviewer
from .auto_reviewer import AutoReviewer, AutoReviewFunction
83 changes: 38 additions & 45 deletions indico_toolkit/auto_review/auto_reviewer.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,35 @@
from collections import defaultdict
from typing import List
from typing import Dict, List, Callable

from .review_config import ReviewConfiguration
from .auto_review_functions import (
accept_by_confidence,
reject_by_confidence,
reject_by_min_character_length,
reject_by_max_character_length,
accept_by_all_match_and_confidence,
remove_by_confidence,
)
class AutoReviewFunction:
    """
    Class for hosting functions to manipulate predictions before sending to
    auto review

    Args:
        function (Callable): method to be invoked when applying reviews.
            The Callable must have the following arguments in the following order:
            predictions (List[dict]),
            labels (List[str]),
            **kwargs,
        labels (List[str]): list of labels to invoke method on. Defaults to all labels
        kwargs (Dict[str, str]): dictionary containing additional arguments needed in calling function
    """

    def __init__(
        self,
        function: Callable,
        labels: List[str] = None,
        kwargs: Dict[str, str] = None,
    ):
        self.function = function
        # None defaults plus copies: a mutable default ([] / {}) would be one
        # object shared by every instance created without that argument
        self.labels = list(labels) if labels else []
        self.kwargs = dict(kwargs) if kwargs else {}

    def apply(self, predictions: List[dict]):
        """
        Invoke the wrapped function on the given predictions.

        Args:
            predictions (List[dict]): predictions to review. When no labels were
                configured, each dict must contain a "label" key

        Returns:
            Whatever the wrapped function returns for
            (predictions, labels, **kwargs)
        """
        labels = self.labels
        if predictions and not labels:
            # "All labels" is resolved per call and NOT stored on the instance,
            # so reusing this object on another batch does not keep the first
            # batch's labels. dict.fromkeys de-duplicates in first-seen order.
            labels = list(dict.fromkeys(pred["label"] for pred in predictions))
        return self.function(predictions, labels, **self.kwargs)

REVIEWERS = {
"accept_by_confidence": accept_by_confidence,
"reject_by_confidence": reject_by_confidence,
"reject_by_min_character_length": reject_by_min_character_length,
"reject_by_max_character_length": reject_by_max_character_length,
"accept_by_all_match_and_confidence": accept_by_all_match_and_confidence,
"remove_by_confidence": remove_by_confidence,
}


class AutoReviewer:
Expand All @@ -29,42 +39,25 @@ class AutoReviewer:
Example Usage:
reviewer = AutoReviewer(
predictions, review_config
predictions, functions
)
reviewer.apply_reviews()
# Get your updated predictions
updated_predictions: List[dict] = reviewer.updated_predictions
updated_predictions: List[dict] = reviewer.apply_reviews()
"""

def __init__(
self,
predictions: List[dict],
review_config: ReviewConfiguration,
functions: List[AutoReviewFunction] = []
):
self.field_config = review_config.field_config
self.reviewers = self.add_reviewers(review_config.custom_functions)
self.predictions = predictions
self.updated_predictions = predictions
self.functions = functions

@staticmethod
def add_reviewers(custom_functions):
"""
Add custom functions into reviewers
Overwrites any default reviewers if function names match
"""
for func_name, func in custom_functions.items():
REVIEWERS[func_name] = func
return REVIEWERS
def apply_reviews(self) -> list:
    """
    Run every configured review function over the predictions, in order.

    Each function is applied to the result of the previous one, and the
    running result is stored on self.updated_predictions after every step.

    Returns:
        list: self.updated_predictions after the last function has run
    """
    for review_function in self.functions:
        current = self.updated_predictions
        self.updated_predictions = review_function.apply(current)
    return self.updated_predictions


def apply_reviews(self):
for fn_config in self.field_config:
fn_name = fn_config["function"]
try:
review_fn = REVIEWERS[fn_name]
except KeyError:
raise KeyError(
f"{fn_name} function was not found, did you specify it in FieldConfiguration?"
)
kwargs = fn_config["kwargs"] if fn_config.get("kwargs") else {}
self.updated_predictions = review_fn(self.updated_predictions, **kwargs)
37 changes: 0 additions & 37 deletions indico_toolkit/auto_review/review_config.py

This file was deleted.

91 changes: 64 additions & 27 deletions tests/auto_review/test_auto_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@
from collections import defaultdict
from indico.queries import Job
from indico_toolkit.indico_wrapper import Workflow
from indico_toolkit.auto_review import ReviewConfiguration, AutoReviewer
from indico_toolkit.auto_review import AutoReviewFunction, AutoReviewer
from indico_toolkit.auto_review.auto_review_functions import (
accept_by_confidence,
reject_by_confidence,
accept_by_all_match_and_confidence,
remove_by_confidence,
reject_by_max_character_length,
reject_by_min_character_length
)
from tests.conftest import FILE_PATH


Expand All @@ -22,15 +30,6 @@ def auto_review_preds(testdir_file_path):
return preds


@pytest.fixture(scope="session")
def auto_review_field_config(testdir_file_path):
with open(
os.path.join(testdir_file_path, "data/auto_review/field_config.json"), "r"
) as f:
field_config = json.load(f)
return field_config


@pytest.fixture(scope="function")
def id_pending_scripted(workflow_id, indico_client, pdf_filepath):
"""
Expand Down Expand Up @@ -64,18 +63,14 @@ def test_submit_auto_review(indico_client, id_pending_scripted, model_name):
result = wflow.get_submission_results_from_ids([id_pending_scripted])[0]
predictions = result.predictions.to_list()
# Review the submission
field_config = [
{"function": "accept_by_confidence", "kwargs": {"conf_threshold": 0.99}},
{
"function": "reject_by_min_character_length",
"kwargs": {
"min_length_threshold": 3,
"labels": ["Liability Amount", "Date of Appointment"],
},
},
functions = [
AutoReviewFunction(accept_by_confidence, kwargs={"conf_threshold": 0.99}),
AutoReviewFunction(
reject_by_min_character_length,
labels=["Liability Amount", "Date of Appointment"],
kwargs={"min_length_threshold": 3}),
]
review_config = ReviewConfiguration(field_config)
reviewer = AutoReviewer(predictions, review_config)
reviewer = AutoReviewer(predictions, functions)
reviewer.apply_reviews()
non_rejected_pred_count = len([i for i in reviewer.updated_predictions if "rejected" not in i])
wflow.submit_submission_review(
Expand All @@ -85,8 +80,7 @@ def test_submit_auto_review(indico_client, id_pending_scripted, model_name):
assert result.post_review_predictions.num_predictions == non_rejected_pred_count


def accept_if_match(predictions, match_text: str, labels: list = None):
"""Custom function to pass into ReviewConfiguration"""
def accept_if_match(predictions, labels: list = None, match_text: str = ""):
for pred in predictions:
if REJECTED not in pred:
if labels != None and pred["label"] not in labels:
Expand All @@ -107,10 +101,53 @@ def create_pred_label_map(predictions):
return prediction_label_map


def test_reviewer(auto_review_field_config, auto_review_preds):
custom_functions = {"accept_if_match": accept_if_match}
review_config = ReviewConfiguration(auto_review_field_config, custom_functions)
reviewer = AutoReviewer(auto_review_preds, review_config)
def test_reviewer(auto_review_preds):
custom_functions = [
AutoReviewFunction(
reject_by_confidence,
labels=["reject_by_confidence"],
kwargs={"conf_threshold": 0.7}
),
AutoReviewFunction(
accept_by_all_match_and_confidence,
labels = [
"accept_by_all_match_and_confidence",
"no_match_accept_by_all_match_and_confidence",
"low_conf_accept_by_all_match_and_confidence"
],
kwargs={"conf_threshold": 0.9}
),
AutoReviewFunction(
accept_by_confidence,
labels=[
"accept_by_confidence",
"reject_by_confidence"
],
kwargs={"conf_threshold": 0.8}
),
AutoReviewFunction(
remove_by_confidence,
labels=["remove_by_confidence"],
kwargs={"conf_threshold": 0.8}
),
AutoReviewFunction(
reject_by_min_character_length,
labels=["reject_by_min_character_length"],
kwargs={"min_length_threshold": 6}
),
AutoReviewFunction(
reject_by_max_character_length,
labels=["reject_by_max_character_length"],
kwargs={"max_length_threshold": 6}
),
AutoReviewFunction(
accept_if_match,
labels=["accept_if_match"],
kwargs={"match_text": "matching text"}
)
]

reviewer = AutoReviewer(auto_review_preds, custom_functions)
reviewer.apply_reviews()
preds = reviewer.updated_predictions
pred_map = create_pred_label_map(preds)
Expand Down
55 changes: 0 additions & 55 deletions tests/data/auto_review/field_config.json

This file was deleted.

4 changes: 2 additions & 2 deletions tests/data/auto_review/preds.json
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,13 @@
"label": "reject_by_max_character_length",
"text": "should be rejected",
"confidence": {
"reject_by_min_character_length": 0.83
"reject_by_max_character_length": 0.83
}
},
{
"start": 30,
"end": 40,
"label": "reject_by_max_character_length",
"label": "reject_by_min_character_length",
"text": "short",
"confidence": {
"reject_by_min_character_length": 0.74
Expand Down

2 comments on commit beb9812

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test Coverage

Indico Toolkit Coverage Report
FileStmtsMissCoverMissing
indico_toolkit
   errors.py17194%22
indico_toolkit/association
   association.py35294%20, 40
   extracted_tokens.py58198%88
   line_items.py92397%151–152, 164
   positioning.py118298%231, 251
   split_merged_values.py24196%49
indico_toolkit/auto_populate
   populator.py102991%220, 247, 262–268
   types.py38295%22, 32
indico_toolkit/highlighter
   highlighter.py1291191%39, 50, 131, 208, 242–248
indico_toolkit/indico_wrapper
   dataset.py33779%40–43, 57, 107–108
   doc_extraction.py34197%53
   download.py50394%46, 104, 167
   indico_wrapper.py32584%57–59, 109, 112
   reviewer.py27485%42–43, 51–52
   workflow.py791976%42, 76, 88–93, 137–142, 144, 149, 207–210
indico_toolkit/metrics
   compare_ground_truth.py66494%30, 37, 92, 94
   compare_models.py631084%57, 102–114, 125, 128, 134
   metrics.py1187041%42, 68, 109–134, 160–183, 215–237, 243–252, 277–300, 303, 308
   plotting.py15287%66, 80
indico_toolkit/ocr
   customocr_object.py23387%25, 29, 41
   ondoc_object.py41295%81, 92
indico_toolkit/pipelines
   file_processing.py90594%66, 70, 106, 110, 114
   pdf_manipulation.py33488%16–18, 63
indico_toolkit/snapshots
   snapshot.py1551690%92, 147–148, 185, 263, 281, 284–288, 295–296, 302–303, 307–308
indico_toolkit/staggered_loop
   metrics.py4784398%25–27, 41–45, 59–62, 76–79, 93–96, 114–115, 132–151, 172–222, 250–290, 308, 326–328, 347–349, 369–387, 403–407, 436–450, 474–488, 518–530, 552–565, 597–609, 633–637, 719–821, 839, 862–868, 888–902, 922, 947–958, 984–1002, 1024–1029, 1053–1060, 1101–1226, 1273–1433
   staggered_loop.py23916133%76–78, 81–104, 107–146, 157–173, 194–221, 248–290, 312–325, 378–436, 450–503, 525–532, 535–549, 565–654, 663–679
indico_toolkit/structure
   create_structure.py671775%66–70, 76–81, 114–130, 182
indico_toolkit/types
   classification.py43198%75
   extractions.py115497%151, 166, 169, 179
   workflow_object.py64789%29, 86, 90, 94, 98, 102, 106
TOTAL267881670% 

Tests Skipped Failures Errors Time
234 0 💤 1 ❌ 0 🔥 4m 7s ⏱️

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test Coverage

Indico Toolkit Coverage Report
FileStmtsMissCoverMissing
indico_toolkit
   errors.py17194%22
indico_toolkit/association
   association.py35294%20, 40
   extracted_tokens.py58198%88
   line_items.py92397%151–152, 164
   positioning.py118298%231, 251
   split_merged_values.py24196%49
indico_toolkit/auto_populate
   populator.py102991%220, 247, 262–268
   types.py38295%22, 32
indico_toolkit/highlighter
   highlighter.py1291191%39, 50, 131, 208, 242–248
indico_toolkit/indico_wrapper
   dataset.py33779%40–43, 57, 107–108
   doc_extraction.py34197%53
   download.py50394%46, 104, 167
   indico_wrapper.py32584%57–59, 109, 112
   reviewer.py27485%42–43, 51–52
   workflow.py791976%42, 76, 88–93, 137–142, 144, 149, 207–210
indico_toolkit/metrics
   compare_ground_truth.py66494%30, 37, 92, 94
   compare_models.py631084%57, 102–114, 125, 128, 134
   metrics.py1187041%42, 68, 109–134, 160–183, 215–237, 243–252, 277–300, 303, 308
   plotting.py15287%66, 80
indico_toolkit/ocr
   customocr_object.py23387%25, 29, 41
   ondoc_object.py41295%81, 92
indico_toolkit/pipelines
   file_processing.py90594%66, 70, 106, 110, 114
   pdf_manipulation.py33488%16–18, 63
indico_toolkit/snapshots
   snapshot.py1551690%92, 147–148, 185, 263, 281, 284–288, 295–296, 302–303, 307–308
indico_toolkit/staggered_loop
   metrics.py4784398%25–27, 41–45, 59–62, 76–79, 93–96, 114–115, 132–151, 172–222, 250–290, 308, 326–328, 347–349, 369–387, 403–407, 436–450, 474–488, 518–530, 552–565, 597–609, 633–637, 719–821, 839, 862–868, 888–902, 922, 947–958, 984–1002, 1024–1029, 1053–1060, 1101–1226, 1273–1433
   staggered_loop.py23916133%76–78, 81–104, 107–146, 157–173, 194–221, 248–290, 312–325, 378–436, 450–503, 525–532, 535–549, 565–654, 663–679
indico_toolkit/structure
   create_structure.py671775%66–70, 76–81, 114–130, 182
indico_toolkit/types
   classification.py43198%75
   extractions.py115497%151, 166, 169, 179
   workflow_object.py64789%29, 86, 90, 94, 98, 102, 106
TOTAL267881670% 

Tests Skipped Failures Errors Time
234 0 💤 1 ❌ 0 🔥 4m 9s ⏱️

Please sign in to comment.