Skip to content

Commit

Permalink
infer labels in LabelCountCollector (#351)
Browse files Browse the repository at this point in the history
* LabelCountCollector: allow labels = "INFERRED" and add label_attribute parameter

* log a warning and remove "mean", "std", "min" from aggregation_functions if labels are inferred from data

* log a warning and remove "mean", "std", "min" from aggregation_functions if labels are inferred from data

* make pre-commit happy
  • Loading branch information
ArneBinder authored Sep 27, 2023
1 parent 1529734 commit 131931a
Showing 1 changed file with 30 additions and 3 deletions.
33 changes: 30 additions & 3 deletions src/pytorch_ie/metrics/statistics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Type, Union

Expand All @@ -6,6 +7,8 @@
from pytorch_ie.core import Document, DocumentStatistic
from pytorch_ie.documents import TextBasedDocument

# Module-level logger named after this module; LabelCountCollector uses it to
# emit a warning when labels are inferred from the data.
logger = logging.getLogger(__name__)


class TokenCountCollector(DocumentStatistic):
"""Collects the token count of a field when tokenizing its content with a Huggingface tokenizer.
Expand Down Expand Up @@ -112,14 +115,38 @@ class LabelCountCollector(DocumentStatistic):

# Names of the aggregation functions applied to the collected label counts.
# NOTE(review): presumably consumed by the DocumentStatistic base class to
# build `self.aggregation_functions` -- confirm against the base class.
DEFAULT_AGGREGATION_FUNCTIONS = ["mean", "std", "min", "max", "len"]

def __init__(
    self, field: str, labels: Union[List[str], str], label_attribute: str = "label", **kwargs
):
    """Configure which document field is inspected and which labels are counted.

    Args:
        field: Name of the document attribute whose elements are counted.
        labels: Either the list of labels to count, or the string
            ``"INFERRED"`` to take the labels from the elements encountered
            in each document.
        label_attribute: Name of the attribute on each element that holds
            its label.
        **kwargs: Passed through unchanged to the base class constructor.

    Raises:
        ValueError: If ``labels`` is neither a list nor the string
            ``"INFERRED"``.
    """
    super().__init__(**kwargs)
    self.field = field
    self.label_attribute = label_attribute
    if not (isinstance(labels, list) or labels == "INFERRED"):
        raise ValueError("labels must be a list of strings or 'INFERRED'")
    if labels == "INFERRED":
        # With inferred labels a document only reports labels that actually
        # occur in it, so zero counts never reach the aggregation step.
        logger.warning(
            f"Inferring labels with {self.__class__.__name__} from data produces wrong results "
            f"for certain aggregation functions (e.g. 'mean', 'std', 'min') because zero values "
            f"are not included in the calculation. We remove these aggregation functions from "
            f"this collector, but be aware that the results may be wrong for your own aggregation "
            f"functions that rely on zero values."
        )
        # NOTE(review): `self.aggregation_functions` is presumably a
        # name -> callable mapping set up by the base class constructor.
        zero_sensitive = ("mean", "std", "min")
        self.aggregation_functions = {
            name: func
            for name, func in self.aggregation_functions.items()
            if name not in zero_sensitive
        }

    self.labels = labels

def _collect(self, doc: Document) -> Dict[str, int]:
    """Count how often each label occurs in the configured field of ``doc``.

    Args:
        doc: Document whose attribute named ``self.field`` is iterated; the
            label of each element is read from its ``self.label_attribute``
            attribute.

    Returns:
        A plain dict mapping label to number of occurrences. With an
        explicit label list every listed label is present (0 if it does not
        occur). With ``labels == "INFERRED"`` only labels that actually
        occur are present.

    Raises:
        KeyError: If an explicit label list is configured and an element
            carries a label that is not in it.
    """
    field_obj = getattr(doc, self.field)
    counts: Dict[str, int]
    if self.labels == "INFERRED":
        # Unknown label set: create entries on first sight.
        counts = defaultdict(int)
    else:
        # Known label set: pre-fill so absent labels are reported as 0.
        counts = {label: 0 for label in self.labels}
    for elem in field_obj:
        # Exactly one increment per element, read via the configurable attribute.
        label = getattr(elem, self.label_attribute)
        counts[label] += 1
    # Return a plain dict so a defaultdict never leaks to callers.
    return dict(counts)

0 comments on commit 131931a

Please sign in to comment.