Skip to content

Commit

Permalink
log a warning and remove "mean", "std", "min" from aggregation_functi…
Browse files Browse the repository at this point in the history
…ons if labels are inferred from data
  • Loading branch information
ArneBinder committed Sep 27, 2023
1 parent 622134a commit e9f6829
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/pytorch_ie/metrics/statistics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Type, Union

Expand All @@ -6,6 +7,8 @@
from pytorch_ie.core import Document, DocumentStatistic
from pytorch_ie.documents import TextBasedDocument

logger = logging.getLogger(__name__)


class TokenCountCollector(DocumentStatistic):
"""Collects the token count of a field when tokenizing its content with a Huggingface tokenizer.
Expand Down Expand Up @@ -120,6 +123,18 @@ def __init__(
self.label_attribute = label_attribute
if not (isinstance(labels, list) or labels == "INFERRED"):
raise ValueError("labels must be a list of strings or 'INFERRED'")
if labels == "INFERRED":
logger.warning(
f"Inferring labels with {self.__class__.__name__} from data produces wrong results "
f"for certain aggregation functions (e.g. 'mean', 'std', 'min') because zero values "
f"are not included in the calculation. We remove these aggregation functions from "
f"this collector, but be aware that the results may be wrong for your own aggregation "
f"functions that rely on zero values."
)
self.aggregation_functions = [
func for func in self.aggregation_functions if func not in ["mean", "std", "min"]
]

self.labels = labels

def _collect(self, doc: Document) -> Dict[str, int]:
Expand Down

0 comments on commit e9f6829

Please sign in to comment.