diff --git a/examples/auto_review_predictions.py b/examples/auto_review_predictions.py index 6f5867bd..a9c11d77 100644 --- a/examples/auto_review_predictions.py +++ b/examples/auto_review_predictions.py @@ -1,11 +1,14 @@ """ Submit documents to a workflow, auto review them and submit them for human review """ -from indico_toolkit import auto_review from indico_toolkit.auto_review import ( - ReviewConfiguration, + AutoReviewFunction, AutoReviewer, ) +from indico_toolkit.auto_review.auto_review_functions import ( + remove_by_confidence, + accept_by_confidence +) from indico_toolkit.indico_wrapper import Workflow from indico_toolkit import create_client @@ -26,17 +29,21 @@ wf_results = wflow.get_submission_results_from_ids(submission_ids) predictions = wf_results[0].predictions.to_list() -# Set up reviewer and review predictions -review_config = ReviewConfiguration( - field_config=[ - {"function": "remove_by_confidence", "kwargs": {"conf_threshold": 0.90}}, - { - "function": "accept_by_confidence", - "kwargs": {"conf_threshold": 0.98, "labels": ["Name", "Amount"]}, - }, - ] -) -auto_reviewer = AutoReviewer(predictions, review_config) +# Set up custom review function +def custom_function(predictions, labels: list = None, match_text: str = ""): + for pred in predictions: + if pred["text"] == match_text: + pred["accepted"] = True + return predictions + + +# Set up review functions and review predictions +functions = [ + AutoReviewFunction(remove_by_confidence, kwargs={"conf_threshold": 0.90}), # will default to all labels if labels is not provided + AutoReviewFunction(accept_by_confidence, labels=["Name", "Amount"]), + AutoReviewFunction(custom_function, kwargs={"match_text": "text to match"}) # call custom auto review function +] +auto_reviewer = AutoReviewer(predictions, functions) auto_reviewer.apply_reviews() # Submit review diff --git a/indico_toolkit/auto_review/__init__.py b/indico_toolkit/auto_review/__init__.py index d0bed27d..f2524f90 100644 --- a/indico_toolkit/auto_review/__init__.py +++ b/indico_toolkit/auto_review/__init__.py @@ -1,2 +1 @@ -from .review_config import ReviewConfiguration -from .auto_reviewer import AutoReviewer +from .auto_reviewer import AutoReviewer, AutoReviewFunction diff --git a/indico_toolkit/auto_review/auto_reviewer.py b/indico_toolkit/auto_review/auto_reviewer.py index 36c16e8a..13fa067f 100644 --- a/indico_toolkit/auto_review/auto_reviewer.py +++ b/indico_toolkit/auto_review/auto_reviewer.py @@ -1,25 +1,35 @@ -from collections import defaultdict -from typing import List +from typing import Dict, List, Callable -from .review_config import ReviewConfiguration -from .auto_review_functions import ( - accept_by_confidence, - reject_by_confidence, - reject_by_min_character_length, - reject_by_max_character_length, - accept_by_all_match_and_confidence, - remove_by_confidence, -) +class AutoReviewFunction: + """ + Class for hosting functions to manipulate predictions before sending to + auto review + + Args: + function (Callable): method to be invoked when applying reviews. + The Callable must have the following arguments in the following order: + predictions (List[dict]), + labels (List[str]), + **kwargs, + + labels (List[str]): list of labels to invoke method on. Defaults to all labels + kwargs (Dict[str, str]): dictionary containing additional arguments needed in calling function + """ + def __init__( + self, + function: Callable, + labels: List[str] = [], + kwargs: Dict[str, str] = {}, + ): + self.function = function + self.labels = labels + self.kwargs = kwargs + def apply(self, predictions: List[dict]): + if predictions and not self.labels: + self.labels = list(set([pred["label"] for pred in predictions])) + return self.function(predictions, self.labels, **self.kwargs) -REVIEWERS = { - "accept_by_confidence": accept_by_confidence, - "reject_by_confidence": reject_by_confidence, - "reject_by_min_character_length": reject_by_min_character_length, - "reject_by_max_character_length": reject_by_max_character_length, - "accept_by_all_match_and_confidence": accept_by_all_match_and_confidence, - "remove_by_confidence": remove_by_confidence, -} class AutoReviewer: @@ -29,42 +39,25 @@ class AutoReviewer: Example Usage: reviewer = AutoReviewer( - predictions, review_config + predictions, functions ) - reviewer.apply_review() # Get your updated predictions - updated_predictions: List[dict] = reviewer.updated_predictions + updated_predictions: List[dict] = reviewer.apply_reviews() """ def __init__( self, predictions: List[dict], - review_config: ReviewConfiguration, + functions: List[AutoReviewFunction] = [] ): - self.field_config = review_config.field_config - self.reviewers = self.add_reviewers(review_config.custom_functions) self.predictions = predictions self.updated_predictions = predictions + self.functions = functions - @staticmethod - def add_reviewers(custom_functions): - """ - Add custom functions into reviewers - Overwrites any default reviewers if function names match - """ - for func_name, func in custom_functions.items(): - REVIEWERS[func_name] = func - return REVIEWERS + def apply_reviews(self) -> list: + for function in self.functions: + self.updated_predictions = function.apply(self.updated_predictions) + return self.updated_predictions + - def apply_reviews(self): - for fn_config in self.field_config: - fn_name = fn_config["function"] - try: - review_fn = REVIEWERS[fn_name] - except KeyError: - raise KeyError( - f"{fn_name} function was not found, did you specify it in FieldConfiguration?" - ) - kwargs = fn_config["kwargs"] if fn_config.get("kwargs") else {} - self.updated_predictions = review_fn(self.updated_predictions, **kwargs) diff --git a/indico_toolkit/auto_review/review_config.py b/indico_toolkit/auto_review/review_config.py deleted file mode 100644 index b3ac4b4d..00000000 --- a/indico_toolkit/auto_review/review_config.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Dict, List, Callable - - -REQUIRED_CONF_ARGS = ["function"] - - -class ReviewConfiguration: - def __init__( - self, field_config: List[dict], custom_functions: Dict[str, Callable] = {} - ): - """ - Args: - field_config (List[dict]): list of function config dictionaries. Available functions defined in auto_review_functions.py - function config: - { - "function": "reject_by_confidence", - "kwargs": { - "labels": ["Check Amount", "Name"], - "conf_threshold": 0.98 - }, - } - custom_functions (Dict[str, Callable]): Dictionary with custom functions to - use in auto-review - """ - self.custom_functions = custom_functions - self.field_config = self.validate_field_config(field_config) - - @staticmethod - def validate_field_config(field_config): - for function_config in field_config: - if not isinstance(function_config, dict): - raise TypeError(f"{function_config} value is not type dict") - config_keys = function_config.keys() - for key in REQUIRED_CONF_ARGS: - if key not in config_keys: - raise KeyError(f"{key} key missing from {function_config} config") - return field_config diff --git a/tests/auto_review/test_auto_review.py b/tests/auto_review/test_auto_review.py index 9c44abd6..b12ce21e 100644 --- a/tests/auto_review/test_auto_review.py +++ b/tests/auto_review/test_auto_review.py @@ -6,7 +6,15 @@ from collections import defaultdict from indico.queries import Job from indico_toolkit.indico_wrapper import Workflow -from indico_toolkit.auto_review import ReviewConfiguration, AutoReviewer +from indico_toolkit.auto_review import AutoReviewFunction, AutoReviewer +from indico_toolkit.auto_review.auto_review_functions import ( + accept_by_confidence, + reject_by_confidence, + accept_by_all_match_and_confidence, + remove_by_confidence, + reject_by_max_character_length, + reject_by_min_character_length +) from tests.conftest import FILE_PATH @@ -22,15 +30,6 @@ def auto_review_preds(testdir_file_path): return preds -@pytest.fixture(scope="session") -def auto_review_field_config(testdir_file_path): - with open( - os.path.join(testdir_file_path, "data/auto_review/field_config.json"), "r" - ) as f: - field_config = json.load(f) - return field_config - - @pytest.fixture(scope="function") def id_pending_scripted(workflow_id, indico_client, pdf_filepath): """ @@ -64,18 +63,14 @@ def test_submit_auto_review(indico_client, id_pending_scripted, model_name): result = wflow.get_submission_results_from_ids([id_pending_scripted])[0] predictions = result.predictions.to_list() # Review the submission - field_config = [ - {"function": "accept_by_confidence", "kwargs": {"conf_threshold": 0.99}}, - { - "function": "reject_by_min_character_length", - "kwargs": { - "min_length_threshold": 3, - "labels": ["Liability Amount", "Date of Appointment"], - }, - }, + functions = [ + AutoReviewFunction(accept_by_confidence, kwargs={"conf_threshold": 0.99}), + AutoReviewFunction( + reject_by_min_character_length, + labels=["Liability Amount", "Date of Appointment"], + kwargs={"min_length_threshold": 3}), ] - review_config = ReviewConfiguration(field_config) - reviewer = AutoReviewer(predictions, review_config) + reviewer = AutoReviewer(predictions, functions) reviewer.apply_reviews() non_rejected_pred_count = len([i for i in reviewer.updated_predictions if "rejected" not in i]) wflow.submit_submission_review( @@ -85,8 +80,7 @@ def test_submit_auto_review(indico_client, id_pending_scripted, model_name): assert result.post_review_predictions.num_predictions == non_rejected_pred_count -def accept_if_match(predictions, match_text: str, labels: list = None): - """Custom function to pass into ReviewConfiguration""" +def accept_if_match(predictions, labels: list = None, match_text: str = ""): for pred in predictions: if REJECTED not in pred: if labels != None and pred["label"] not in labels: @@ -107,10 +101,53 @@ def create_pred_label_map(predictions): return prediction_label_map -def test_reviewer(auto_review_field_config, auto_review_preds): - custom_functions = {"accept_if_match": accept_if_match} - review_config = ReviewConfiguration(auto_review_field_config, custom_functions) - reviewer = AutoReviewer(auto_review_preds, review_config) +def test_reviewer(auto_review_preds): + custom_functions = [ + AutoReviewFunction( + reject_by_confidence, + labels=["reject_by_confidence"], + kwargs={"conf_threshold": 0.7} + ), + AutoReviewFunction( + accept_by_all_match_and_confidence, + labels = [ + "accept_by_all_match_and_confidence", + "no_match_accept_by_all_match_and_confidence", + "low_conf_accept_by_all_match_and_confidence" + ], + kwargs={"conf_threshold": 0.9} + ), + AutoReviewFunction( + accept_by_confidence, + labels=[ + "accept_by_confidence", + "reject_by_confidence" + ], + kwargs={"conf_threshold": 0.8} + ), + AutoReviewFunction( + remove_by_confidence, + labels=["remove_by_confidence"], + kwargs={"conf_threshold": 0.8} + ), + AutoReviewFunction( + reject_by_min_character_length, + labels=["reject_by_min_character_length"], + kwargs={"min_length_threshold": 6} + ), + AutoReviewFunction( + reject_by_max_character_length, + labels=["reject_by_max_character_length"], + kwargs={"max_length_threshold": 6} + ), + AutoReviewFunction( + accept_if_match, + labels=["accept_if_match"], + kwargs={"match_text": "matching text"} + ) + ] + + reviewer = AutoReviewer(auto_review_preds, custom_functions) reviewer.apply_reviews() preds = reviewer.updated_predictions pred_map = create_pred_label_map(preds) diff --git a/tests/data/auto_review/field_config.json b/tests/data/auto_review/field_config.json deleted file mode 100644 index fe8cd598..00000000 --- a/tests/data/auto_review/field_config.json +++ /dev/null @@ -1,55 +0,0 @@ -[ - { - "function": "reject_by_confidence", - "kwargs": { - "labels": ["reject_by_confidence"], - "conf_threshold": 0.7 - } - }, - { - "function": "accept_by_all_match_and_confidence", - "kwargs": { - "labels": [ - "accept_by_all_match_and_confidence", - "no_match_accept_by_all_match_and_confidence", - "low_conf_accept_by_all_match_and_confidence" - ], - "conf_threshold": 0.9 - } - }, - { - "function": "accept_by_confidence", - "kwargs": { - "labels": ["accept_by_confidence", "reject_by_confidence"], - "conf_threshold": 0.8 - } - }, - { - "function": "remove_by_confidence", - "kwargs": { - "labels": ["remove_by_confidence"], - "conf_threshold": 0.8 - } - }, - { - "function": "reject_by_min_character_length", - "kwargs": { - "labels": ["reject_by_min_character_length"], - "min_length_threshold": 6 - } - }, - { - "function": "reject_by_max_character_length", - "kwargs": { - "labels": ["reject_by_max_character_length"], - "max_length_threshold": 6 - } - }, - { - "function": "accept_if_match", - "kwargs": { - "match_text": "matching text", - "labels": ["accept_if_match"] - } - } -] \ No newline at end of file diff --git a/tests/data/auto_review/preds.json b/tests/data/auto_review/preds.json index 239b52e0..cfde5c20 100644 --- a/tests/data/auto_review/preds.json +++ b/tests/data/auto_review/preds.json @@ -122,13 +122,13 @@ "label": "reject_by_max_character_length", "text": "should be rejected", "confidence": { - "reject_by_min_character_length": 0.83 + "reject_by_max_character_length": 0.83 } }, { "start": 30, "end": 40, - "label": "reject_by_max_character_length", + "label": "reject_by_min_character_length", "text": "short", "confidence": { "reject_by_min_character_length": 0.74