import logging
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Type, Union

from transformers import AutoTokenizer, PreTrainedTokenizer

from pytorch_ie.core import Document, DocumentStatistic
from pytorch_ie.documents import TextBasedDocument

logger = logging.getLogger(__name__)


class TokenCountCollector(DocumentStatistic):
    """Collects the token count of a field when tokenizing its content with a Huggingface
    tokenizer.

    The content of the field should be a string.
    """

    def __init__(
        self,
        tokenizer: Union[str, PreTrainedTokenizer],
        text_field: str = "text",
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        document_type: Optional[Type[Document]] = None,
        **kwargs,
    ):
        # Default to TextBasedDocument only when the default text field is requested:
        # a custom text_field implies the caller works with a custom document type.
        if document_type is None and text_field == "text":
            document_type = TextBasedDocument
        super().__init__(document_type=document_type, **kwargs)
        # Accept either a model identifier (resolved via AutoTokenizer) or an
        # already-instantiated tokenizer object.
        self.tokenizer = (
            AutoTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer
        )
        self.tokenizer_kwargs = tokenizer_kwargs or {}
        self.text_field = text_field

    def _collect(self, doc: Document) -> int:
        """Tokenize the configured text field of the document and return the token count."""
        text = getattr(doc, self.text_field)
        encodings = self.tokenizer(text, **self.tokenizer_kwargs)
        tokens = encodings.tokens()
        return len(tokens)


class FieldLengthCollector(DocumentStatistic):
    """Collects the length of a field, e.g. to collect the number the characters in the input text.

    The field should be a list of sized elements.
    """

    def __init__(self, field: str, **kwargs):
        super().__init__(**kwargs)
        self.field = field

    def _collect(self, doc: Document) -> int:
        """Return len() of the configured field of the document."""
        field_obj = getattr(doc, self.field)
        return len(field_obj)


class SubFieldLengthCollector(DocumentStatistic):
    """Collects the length of a subfield in a field, e.g. to collect the number of arguments of
    N-ary relations."""

    def __init__(self, field: str, subfield: str, **kwargs):
        super().__init__(**kwargs)
        self.field = field
        self.subfield = subfield

    def _collect(self, doc: Document) -> List[int]:
        """Return one length per entry of the field: len() of that entry's subfield."""
        field_obj = getattr(doc, self.field)
        lengths = []
        for entry in field_obj:
            subfield_obj = getattr(entry, self.subfield)
            lengths.append(len(subfield_obj))
        return lengths


class DummyCollector(DocumentStatistic):
    """A dummy collector that always returns 1, e.g. to count the number of documents.

    Can be used to count the number of documents.
    """

    DEFAULT_AGGREGATION_FUNCTIONS = ["sum"]

    def _collect(self, doc: Document) -> int:
        return 1


class LabelCountCollector(DocumentStatistic):
    """Collects the number of field entries per label, e.g. to collect the number of entities per
    type.

    The field should be a list of elements with a label attribute.

    Important: To make correct use of the result data, missing values need to be filled with 0, e.g.:
    {("ORG",): [2, 3], ("LOC",): [2]} -> {("ORG",): [2, 3], ("LOC",): [2, 0]}
    """

    DEFAULT_AGGREGATION_FUNCTIONS = ["mean", "std", "min", "max", "len", "sum"]

    def __init__(
        self, field: str, labels: Union[List[str], str], label_attribute: str = "label", **kwargs
    ):
        super().__init__(**kwargs)
        self.field = field
        self.label_attribute = label_attribute
        if not (isinstance(labels, list) or labels == "INFERRED"):
            raise ValueError("labels must be a list of strings or 'INFERRED'")
        if labels == "INFERRED":
            logger.warning(
                f"Inferring labels with {self.__class__.__name__} from data produces wrong results "
                f"for certain aggregation functions (e.g. 'mean', 'std', 'min') because zero values "
                f"are not included in the calculation. We remove these aggregation functions from "
                f"this collector, but be aware that the results may be wrong for your own aggregation "
                f"functions that rely on zero values."
            )
            # Drop the aggregations that would be skewed by the missing zero counts;
            # self.aggregation_functions is set up by DocumentStatistic.__init__.
            self.aggregation_functions: Dict[str, Callable[[List], Any]] = {
                name: func
                for name, func in self.aggregation_functions.items()
                if name not in ["mean", "std", "min"]
            }

        self.labels = labels

    def _collect(self, doc: Document) -> Dict[str, int]:
        """Count field entries per label value.

        With labels='INFERRED' only labels actually present in the document appear in the
        result; with an explicit label list, every label is present (possibly with count 0).
        """
        field_obj = getattr(doc, self.field)
        counts: Dict[str, int]
        if self.labels == "INFERRED":
            counts = defaultdict(int)
        else:
            counts = {label: 0 for label in self.labels}
        for elem in field_obj:
            label = getattr(elem, self.label_attribute)
            counts[label] += 1
        return dict(counts)
from pytorch_ie.metrics.statistics import (
    DummyCollector,
    FieldLengthCollector,
    LabelCountCollector,
    SubFieldLengthCollector,
    TokenCountCollector,
)


def test_statistics(document_dataset):
    """Exercise the non-tokenizing statistic collectors on the document_dataset fixture."""
    statistic = DummyCollector()
    values = statistic(document_dataset)
    assert values == {"test": {"sum": 2}, "train": {"sum": 8}, "val": {"sum": 2}}

    # note that we check for labels=["LOC", "PER", "ORG"], but the actual labels in the data are just ["PER", "ORG"]
    statistic = LabelCountCollector(field="entities", labels=["LOC", "PER", "ORG"])
    values = statistic(document_dataset)
    assert values == {
        "test": {
            "LOC": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 2, "sum": 0},
            "PER": {"mean": 0.5, "std": 0.5, "min": 0, "max": 1, "len": 2, "sum": 1},
            "ORG": {"mean": 1.0, "std": 1.0, "min": 0, "max": 2, "len": 2, "sum": 2},
        },
        "val": {
            "LOC": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 2, "sum": 0},
            "PER": {"mean": 0.5, "std": 0.5, "min": 0, "max": 1, "len": 2, "sum": 1},
            "ORG": {"mean": 1.0, "std": 1.0, "min": 0, "max": 2, "len": 2, "sum": 2},
        },
        "train": {
            "LOC": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 8, "sum": 0},
            "PER": {
                "mean": 0.875,
                "std": 0.5994789404140899,
                "min": 0,
                "max": 2,
                "len": 8,
                "sum": 7,
            },
            "ORG": {
                "mean": 1.125,
                "std": 0.7806247497997998,
                "min": 0,
                "max": 2,
                "len": 8,
                "sum": 9,
            },
        },
    }

    statistic = LabelCountCollector(field="entities", labels="INFERRED")
    values = statistic(document_dataset)
    assert values == {
        "test": {"PER": {"max": 1, "len": 1, "sum": 1}, "ORG": {"max": 2, "len": 1, "sum": 2}},
        "val": {"PER": {"max": 1, "len": 1, "sum": 1}, "ORG": {"max": 2, "len": 1, "sum": 2}},
        "train": {"PER": {"max": 2, "len": 6, "sum": 7}, "ORG": {"max": 2, "len": 6, "sum": 9}},
    }

    statistic = FieldLengthCollector(field="text")
    values = statistic(document_dataset)
    assert values == {
        "test": {"max": 51, "mean": 34.5, "min": 18, "std": 16.5},
        "train": {"max": 54, "mean": 28.25, "min": 15, "std": 14.694812009685595},
        "val": {"max": 51, "mean": 34.5, "min": 18, "std": 16.5},
    }

    # this is not super useful, we just collect the lengths of the labels, but it is enough to test the code
    statistic = SubFieldLengthCollector(field="entities", subfield="label")
    values = statistic(document_dataset)
    assert values == {
        "test": {"max": 3, "mean": 3.0, "min": 3, "std": 0.0},
        "train": {"max": 3, "mean": 3.0, "min": 3, "std": 0.0},
        "val": {"max": 3, "mean": 3.0, "min": 3, "std": 0.0},
    }


def test_statistics_with_tokenize(document_dataset):
    """Exercise TokenCountCollector with a bert-base-uncased tokenizer (no special tokens)."""
    statistic = TokenCountCollector(
        text_field="text",
        tokenizer="bert-base-uncased",
        tokenizer_kwargs=dict(add_special_tokens=False),
    )
    values = statistic(document_dataset)
    assert values == {
        "test": {"max": 13, "mean": 8.5, "min": 4, "std": 4.5},
        "train": {"max": 14, "mean": 7.75, "min": 4, "std": 3.6314597615834874},
        "val": {"max": 13, "mean": 8.5, "min": 4, "std": 4.5},
    }