Skip to content

Commit

Permalink
add "sum" to default aggregation functions for LabelCountCollector (#354
Browse files Browse the repository at this point in the history
)
  • Loading branch information
ArneBinder authored Sep 27, 2023
1 parent d38c796 commit 3f07b1c
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/pytorch_ie/metrics/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ class LabelCountCollector(DocumentStatistic):
{("ORG",): [2, 3], ("LOC",): [2]} -> {("ORG",): [2, 3], ("LOC",): [2, 0]}
"""

DEFAULT_AGGREGATION_FUNCTIONS = ["mean", "std", "min", "max", "len"]
DEFAULT_AGGREGATION_FUNCTIONS = ["mean", "std", "min", "max", "len", "sum"]

def __init__(
self, field: str, labels: Union[List[str], str], label_attribute: str = "label", **kwargs
Expand Down
102 changes: 64 additions & 38 deletions tests/core/test_statistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,72 +37,98 @@ def test_statistics(dataset):
statistic = LabelCountCollector(field="entities", labels=["LOC", "PER", "ORG", "MISC"])
values = statistic(dataset)
assert values == {
"test": {
"LOC": {"len": 3, "max": 2, "mean": 1.0, "min": 0, "std": 0.816496580927726},
"MISC": {"len": 3, "max": 0, "mean": 0.0, "min": 0, "std": 0.0},
"ORG": {"len": 3, "max": 0, "mean": 0.0, "min": 0, "std": 0.0},
"PER": {
"len": 3,
"max": 1,
"mean": 0.6666666666666666,
"min": 0,
"std": 0.4714045207910317,
},
},
"train": {
"LOC": {
"len": 3,
"max": 1,
"mean": 0.3333333333333333,
"min": 0,
"std": 0.4714045207910317,
},
"MISC": {
"min": 0,
"max": 1,
"len": 3,
"max": 2,
"mean": 0.6666666666666666,
"sum": 1,
},
"PER": {
"mean": 0.3333333333333333,
"std": 0.4714045207910317,
"min": 0,
"std": 0.9428090415820634,
"max": 1,
"len": 3,
"sum": 1,
},
"ORG": {
"len": 3,
"max": 1,
"mean": 0.3333333333333333,
"min": 0,
"std": 0.4714045207910317,
},
"PER": {
"len": 3,
"min": 0,
"max": 1,
"mean": 0.3333333333333333,
"len": 3,
"sum": 1,
},
"MISC": {
"mean": 0.6666666666666666,
"std": 0.9428090415820634,
"min": 0,
"std": 0.4714045207910317,
"max": 2,
"len": 3,
"sum": 2,
},
},
"validation": {
"LOC": {
"len": 3,
"max": 1,
"mean": 0.3333333333333333,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 1,
},
"PER": {
"mean": 0.3333333333333333,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 1,
},
"ORG": {"mean": 1.0, "std": 0.816496580927726, "min": 0, "max": 2, "len": 3, "sum": 3},
"MISC": {
"len": 3,
"max": 1,
"mean": 0.3333333333333333,
"min": 0,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 1,
},
"ORG": {"len": 3, "max": 2, "mean": 1.0, "min": 0, "std": 0.816496580927726},
},
"test": {
"LOC": {"mean": 1.0, "std": 0.816496580927726, "min": 0, "max": 2, "len": 3, "sum": 3},
"PER": {
"len": 3,
"max": 1,
"mean": 0.3333333333333333,
"min": 0,
"mean": 0.6666666666666666,
"std": 0.4714045207910317,
"min": 0,
"max": 1,
"len": 3,
"sum": 2,
},
"ORG": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 3, "sum": 0},
"MISC": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 3, "sum": 0},
},
}

statistic = LabelCountCollector(field="entities", labels="INFERRED")
values = statistic(dataset)
assert values == {
"train": {
"ORG": {"max": 1, "len": 1, "sum": 1},
"MISC": {"max": 2, "len": 1, "sum": 2},
"PER": {"max": 1, "len": 1, "sum": 1},
"LOC": {"max": 1, "len": 1, "sum": 1},
},
"validation": {
"ORG": {"max": 2, "len": 2, "sum": 3},
"LOC": {"max": 1, "len": 1, "sum": 1},
"MISC": {"max": 1, "len": 1, "sum": 1},
"PER": {"max": 1, "len": 1, "sum": 1},
},
"test": {"LOC": {"max": 2, "len": 2, "sum": 3}, "PER": {"max": 1, "len": 2, "sum": 2}},
}

statistic = FieldLengthCollector(field="text")
Expand Down

0 comments on commit 3f07b1c

Please sign in to comment.