# Post-patch state of the parts of src/pytorch_ie/metrics/statistics.py that are
# visible in the diff hunks. Code outside the hunks (e.g. the body of
# TokenCountCollector and any further imports) is not shown here.
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Type, Union

from pytorch_ie.core import Document, DocumentStatistic
from pytorch_ie.documents import TextBasedDocument

logger = logging.getLogger(__name__)


class LabelCountCollector(DocumentStatistic):
    """Count, per document, how often each label occurs among the elements of a field.

    NOTE(review): the original class docstring lies outside the visible hunk;
    this summary is derived from ``_collect`` only.
    """

    DEFAULT_AGGREGATION_FUNCTIONS = ["mean", "std", "min", "max", "len"]

    def __init__(
        self, field: str, labels: Union[List[str], str], label_attribute: str = "label", **kwargs
    ):
        """Configure which field is inspected and which labels are counted.

        Args:
            field: Name of the document attribute whose elements are counted.
            labels: Either an explicit list of labels (each one starts at a
                count of zero for every document) or the string ``"INFERRED"``
                to count whatever labels occur in the data.
            label_attribute: Name of the attribute read from each element to
                obtain its label.
            **kwargs: Forwarded unchanged to the base class initializer.

        Raises:
            ValueError: If ``labels`` is neither a list nor ``"INFERRED"``.
        """
        super().__init__(**kwargs)
        self.field = field
        self.label_attribute = label_attribute
        if not (isinstance(labels, list) or labels == "INFERRED"):
            raise ValueError("labels must be a list of strings or 'INFERRED'")
        if labels == "INFERRED":
            # Lazy %-style arguments: the message is only rendered if the
            # record is actually emitted. The resulting text is unchanged.
            logger.warning(
                "Inferring labels with %s from data produces wrong results "
                "for certain aggregation functions (e.g. 'mean', 'std', 'min') because zero values "
                "are not included in the calculation. We remove these aggregation functions from "
                "this collector, but be aware that the results may be wrong for your own aggregation "
                "functions that rely on zero values.",
                self.__class__.__name__,
            )
            # Labels absent from a document produce no entry (instead of a 0)
            # in inferred mode, so statistics that depend on zeros are dropped.
            # `self.aggregation_functions` is presumably populated by the base
            # class initializer -- not visible here, confirm against
            # DocumentStatistic.
            zero_sensitive = ("mean", "std", "min")
            self.aggregation_functions = {
                name: func
                for name, func in self.aggregation_functions.items()
                if name not in zero_sensitive
            }

        self.labels = labels

    def _collect(self, doc: Document) -> Dict[str, int]:
        """Return a mapping from label to its number of occurrences in ``doc``.

        With an explicit label list every listed label is present in the
        result (possibly with count 0). In ``"INFERRED"`` mode only labels
        that actually occur are present.

        Raises:
            KeyError: If an explicit label list was configured and an element
                carries a label that is not in it.
        """
        field_obj = getattr(doc, self.field)
        infer_labels = self.labels == "INFERRED"
        counts: Dict[str, int]
        if infer_labels:
            counts = defaultdict(int)
        else:
            counts = {label: 0 for label in self.labels}
        for elem in field_obj:
            label = getattr(elem, self.label_attribute)
            if not infer_labels and label not in counts:
                # Same exception type as the plain dict lookup would raise,
                # but with enough context to locate the offending data.
                raise KeyError(
                    f"unexpected label {label!r} in field {self.field!r}; "
                    f"expected one of {self.labels}"
                )
            counts[label] += 1
        return dict(counts)