Skip to content

Commit

Permalink
Merge pull request #159 from IndicoDataSolutions/unbundling_metrics
Browse files Browse the repository at this point in the history
adding unbundling metrics support
  • Loading branch information
Scott771 authored Oct 2, 2023
2 parents f140b54 + bf726d1 commit 7970e80
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 0 deletions.
126 changes: 126 additions & 0 deletions indico_toolkit/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,26 @@


class ExtractionMetrics(IndicoWrapper):
"""
Example usage:
metrics = ExtractionMetrics(client)
metrics.get_extraction_metrics(MODEL_GROUP_ID)
# get a pandas dataframe of field level results
df = metrics.get_metrics_df()
print(df.head())
# get metrics for a specific span type and/or model ID
df = metrics.get_metrics_df(span_type="exact", select_model_id=102)
print(df.head())
# write the results to a CSV (can also optionally pass span_type/model ID here as well)
metrics.to_csv("./my_metrics.pdf")
# get an interactive bar plot to visualize model improvement over time
metrics.bar_plot("./my_bar_plot.html")
"""
def __init__(self, client: IndicoClient):
self.client = client
self.raw_metrics: List[dict] = None
Expand Down Expand Up @@ -169,6 +189,112 @@ def to_csv(
df.to_csv(output_path, index=False)



class UnbundlingMetrics(ExtractionMetrics):
    """
    Metrics for Unbundling models.

    Example Usage:
    um = UnbundlingMetrics(client)
    um.get_metrics(1232)
    um.line_plot("./my_metric_plot.html", metric="recall", plot_title="Insurance Model Recall Improvement")
    """
    def get_metrics(self, model_group_id: int):
        """
        Collect all metrics available based on a Model Group ID for an Unbundling model
        Args:
            model_group_id (int): Model Group ID that you're interested in (available within the Explain UI)
        Raises:
            ToolkitInputError: if no models are associated with the given model group ID
        """
        results = self.graphQL_request(METRIC_QUERY, {"modelGroupId": model_group_id})
        if len(results["modelGroups"]["modelGroups"]) == 0:
            raise ToolkitInputError(
                f"There are no models associated with ID: {model_group_id}"
            )
        results = results["modelGroups"]["modelGroups"][0]["models"]
        raw_metrics = []
        included_models = []
        labeled_samples = []
        for r in results:
            model_info = json.loads(r["modelInfo"])
            if "total_number_of_examples" not in model_info or "metrics" not in model_info:
                # some dictionaries don't come back with required fields...
                continue
            labeled_samples.append(model_info["total_number_of_examples"])
            included_models.append(r["id"])
            raw_metrics.append(model_info["metrics"]["per_class_metrics"])
        self.raw_metrics = raw_metrics
        self.included_models = included_models
        # map of model ID -> number of labeled samples that model was trained on
        self.number_of_samples = dict(zip(included_models, labeled_samples))


    def get_metrics_df(self) -> pd.DataFrame:
        """
        Return a DataFrame with one row per (model, field): the page-level scores for
        that field plus 'field_name', 'model_id' and 'number_of_samples' columns,
        sorted descending by field name then model ID.
        Call get_metrics() first to populate the underlying data.
        """
        cleaned_metrics = []
        for model_id, metrics in zip(self.included_models, self.raw_metrics):
            for class_name in metrics:
                # copy so we annotate the row without mutating self.raw_metrics in place
                scores = dict(metrics[class_name]["page"])
                scores["field_name"] = class_name
                scores["model_id"] = model_id
                scores["number_of_samples"] = self.number_of_samples[model_id]
                cleaned_metrics.append(scores)
        df = pd.DataFrame(cleaned_metrics)
        return df.sort_values(by=["field_name", "model_id"], ascending=False)


    def line_plot(
        self,
        output_path: str,
        metric: str = "f1_score",
        plot_title: str = "",
        ids_to_exclude: List[int] = None,
        fields_to_exclude: List[str] = None,
    ):
        """
        Write an html line plot to disk with # of samples on x-axis, a metric on the y-axis and
        each line representing a distinct field.
        Will also open the plot automatically in your browser, where you will have interactive
        functionality and the ability to download a copy as a PNG as well.
        Args:
            output_path (str): where you want to write plot, e.g. "./myplot.html"
            metric (str, optional): possible values are 'precision', 'recall', 'f1_score', 'false_positives',
                                    'false_negatives', 'true_positives'. Defaults to "f1_score".
            plot_title (str, optional): Title of the plot. Defaults to "".
            ids_to_exclude (List[int], optional): Model Ids to exclude from plot.
            fields_to_exclude (List[str], optional): Field Names to exclude from plot.
        """
        # avoid mutable default arguments; None means "exclude nothing"
        ids_to_exclude = ids_to_exclude or []
        fields_to_exclude = fields_to_exclude or []
        df = self.get_metrics_df()
        if ids_to_exclude:
            df = df.drop(df.loc[df["model_id"].isin(ids_to_exclude)].index)
        if fields_to_exclude:
            df = df.drop(df.loc[df["field_name"].isin(fields_to_exclude)].index)
        df = df.sort_values(by=["field_name", "number_of_samples", metric])
        plotting = Plotting()
        for field in sorted(df["field_name"].unique()):
            sub_df = df.loc[df["field_name"] == field].copy()
            # ensure only one value per # of samples
            sub_df = sub_df.drop_duplicates(subset=["number_of_samples"])
            plotting.add_line_data(
                sub_df["number_of_samples"],
                sub_df[metric],
                name=field,
                color=None,
            )
        plotting.define_layout(
            xaxis_title="Number of Samples",
            legend_title="Field",
            plot_title=plot_title,
            yaxis_title=metric,
        )
        plotting.plot(output_path)

    def bar_plot(self):
        # inherited from ExtractionMetrics but not supported for unbundling models
        raise NotImplementedError("Bar Plot is not currently implemented for unbundling")

    def get_extraction_metrics(self, model_group_id: int):
        # inherited from ExtractionMetrics; use get_metrics() for unbundling models
        raise NotImplementedError("Not available for unbundling class")


METRIC_QUERY = """
query modelGroupMetrics($modelGroupId: Int!){
modelGroups(
Expand Down
3 changes: 3 additions & 0 deletions tests/indico_wrapper/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""
Test Datasets class methods
"""
import pytest
from indico_toolkit.indico_wrapper import Datasets
from indico.types import Dataset
Expand Down

2 comments on commit 7970e80

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test Coverage

Indico Toolkit Coverage Report
FileStmtsMissCoverMissing
indico_toolkit
   errors.py14193%22
indico_toolkit/association
   association.py35294%20, 40
   extracted_tokens.py58198%88
   line_items.py92397%151–152, 164
   positioning.py118298%231, 251
   split_merged_values.py24196%49
indico_toolkit/auto_class
   classifier.py77791%62, 65–69, 151
indico_toolkit/auto_review
   auto_reviewer.py24292%65–66
   review_config.py15287%32, 36
indico_toolkit/highlighter
   highlighter.py1291191%39, 50, 131, 208, 242–248
indico_toolkit/indico_wrapper
   dataset.py34779%42–45, 56, 98–99
   doc_extraction.py34197%53
   download.py50394%46, 104, 167
   indico_wrapper.py32584%57–59, 109, 112
   reviewer.py27485%42–43, 51–52
   workflow.py791976%42, 76, 88–93, 137–142, 144, 149, 207–210
indico_toolkit/metrics
   compare_ground_truth.py66494%30, 37, 92, 94
   compare_models.py631084%57, 102–114, 125, 128, 134
   metrics.py1187041%42, 68, 109–134, 160–183, 215–237, 243–252, 277–300, 303, 308
   plotting.py15287%66, 80
indico_toolkit/ocr
   customocr_object.py23387%25, 29, 41
   ondoc_object.py41295%81, 92
indico_toolkit/pipelines
   file_processing.py90397%66, 70, 106
   pdf_manipulation.py33488%16–18, 63
indico_toolkit/snapshots
   snapshot.py1551690%92, 147–148, 185, 263, 281, 284–288, 295–296, 302–303, 307–308
indico_toolkit/staggered_loop
   metrics.py4784398%25–27, 41–45, 59–62, 76–79, 93–96, 114–115, 132–151, 172–222, 250–290, 308, 326–328, 347–349, 369–387, 403–407, 436–450, 474–488, 518–530, 552–565, 597–609, 633–637, 719–821, 839, 862–868, 888–902, 922, 947–958, 984–1002, 1024–1029, 1053–1060, 1101–1226, 1273–1433
   staggered_loop.py23916133%76–78, 81–104, 107–146, 157–173, 194–221, 248–290, 312–325, 378–436, 450–503, 525–532, 535–549, 565–654, 663–679
indico_toolkit/structure
   create_structure.py58580%1–206
   utils.py990%1–13
indico_toolkit/types
   classification.py43198%75
   extractions.py115497%151, 166, 169, 179
   workflow_object.py64789%29, 86, 90, 94, 98, 102, 106
TOTAL260686467% 

Tests Skipped Failures Errors Time
238 0 💤 0 ❌ 0 🔥 3m 23s ⏱️

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test Coverage

Indico Toolkit Coverage Report
FileStmtsMissCoverMissing
indico_toolkit
   errors.py14193%22
indico_toolkit/association
   association.py35294%20, 40
   extracted_tokens.py58198%88
   line_items.py92397%151–152, 164
   positioning.py118298%231, 251
   split_merged_values.py24196%49
indico_toolkit/auto_class
   classifier.py77791%62, 65–69, 151
indico_toolkit/auto_review
   auto_reviewer.py24292%65–66
   review_config.py15287%32, 36
indico_toolkit/highlighter
   highlighter.py1291191%39, 50, 131, 208, 242–248
indico_toolkit/indico_wrapper
   dataset.py34779%42–45, 56, 98–99
   doc_extraction.py34197%53
   download.py50394%46, 104, 167
   indico_wrapper.py32584%57–59, 109, 112
   reviewer.py27485%42–43, 51–52
   workflow.py791976%42, 76, 88–93, 137–142, 144, 149, 207–210
indico_toolkit/metrics
   compare_ground_truth.py66494%30, 37, 92, 94
   compare_models.py631084%57, 102–114, 125, 128, 134
   metrics.py1187041%42, 68, 109–134, 160–183, 215–237, 243–252, 277–300, 303, 308
   plotting.py15287%66, 80
indico_toolkit/ocr
   customocr_object.py23387%25, 29, 41
   ondoc_object.py41295%81, 92
indico_toolkit/pipelines
   file_processing.py90397%66, 70, 106
   pdf_manipulation.py33488%16–18, 63
indico_toolkit/snapshots
   snapshot.py1551690%92, 147–148, 185, 263, 281, 284–288, 295–296, 302–303, 307–308
indico_toolkit/staggered_loop
   metrics.py4784398%25–27, 41–45, 59–62, 76–79, 93–96, 114–115, 132–151, 172–222, 250–290, 308, 326–328, 347–349, 369–387, 403–407, 436–450, 474–488, 518–530, 552–565, 597–609, 633–637, 719–821, 839, 862–868, 888–902, 922, 947–958, 984–1002, 1024–1029, 1053–1060, 1101–1226, 1273–1433
   staggered_loop.py23916133%76–78, 81–104, 107–146, 157–173, 194–221, 248–290, 312–325, 378–436, 450–503, 525–532, 535–549, 565–654, 663–679
indico_toolkit/structure
   create_structure.py58580%1–206
   utils.py990%1–13
indico_toolkit/types
   classification.py43198%75
   extractions.py115497%151, 166, 169, 179
   workflow_object.py64789%29, 86, 90, 94, 98, 102, 106
TOTAL260686467% 

Tests Skipped Failures Errors Time
238 0 💤 0 ❌ 0 🔥 3m 13s ⏱️

Please sign in to comment.