diff --git a/src/pytorch_ie/metrics/statistics.py b/src/pytorch_ie/metrics/statistics.py index a3cf86d9..53c869d9 100644 --- a/src/pytorch_ie/metrics/statistics.py +++ b/src/pytorch_ie/metrics/statistics.py @@ -1,3 +1,4 @@ +import logging from collections import defaultdict from typing import Any, Dict, List, Optional, Type, Union @@ -6,6 +7,8 @@ from pytorch_ie.core import Document, DocumentStatistic from pytorch_ie.documents import TextBasedDocument +logger = logging.getLogger(__name__) + class TokenCountCollector(DocumentStatistic): """Collects the token count of a field when tokenizing its content with a Huggingface tokenizer. @@ -120,6 +123,18 @@ def __init__( self.label_attribute = label_attribute if not (isinstance(labels, list) or labels == "INFERRED"): raise ValueError("labels must be a list of strings or 'INFERRED'") + if labels == "INFERRED": + logger.warning( + f"Inferring labels with {self.__class__.__name__} from data produces wrong results " + f"for certain aggregation functions (e.g. 'mean', 'std', 'min') because zero values " + f"are not included in the calculation. We remove these aggregation functions from " + f"this collector, but be aware that the results may be wrong for your own aggregation " + f"functions that rely on zero values." + ) + self.aggregation_functions = [ + func for func in self.aggregation_functions if func not in ["mean", "std", "min"] + ] + self.labels = labels def _collect(self, doc: Document) -> Dict[str, int]: