diff --git a/src/pie_datasets/__init__.py b/src/pie_datasets/__init__.py index 6c3d2550..2d1b01fc 100644 --- a/src/pie_datasets/__init__.py +++ b/src/pie_datasets/__init__.py @@ -1,22 +1,13 @@ -from .builder import GeneratorBasedBuilder -from .common import ( - EnterDatasetDictMixin, - EnterDatasetMixin, - ExitDatasetDictMixin, - ExitDatasetMixin, -) +from .builder import ArrowBasedBuilder, GeneratorBasedBuilder from .dataset import Dataset, IterableDataset from .dataset_dict import DatasetDict from .document_formatter import DocumentFormatter __all__ = [ "GeneratorBasedBuilder", + "ArrowBasedBuilder", "Dataset", "IterableDataset", "DatasetDict", "DocumentFormatter", - "EnterDatasetMixin", - "ExitDatasetMixin", - "EnterDatasetDictMixin", - "ExitDatasetDictMixin", ] diff --git a/src/pie_datasets/common.py b/src/pie_datasets/common.py deleted file mode 100644 index e3213b2e..00000000 --- a/src/pie_datasets/common.py +++ /dev/null @@ -1,40 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional, Union - -from .dataset import Dataset, IterableDataset - - -class EnterDatasetMixin(ABC): - """Mixin for processors that enter a dataset context.""" - - @abstractmethod - def enter_dataset( - self, dataset: Union[Dataset, IterableDataset], name: Optional[str] = None - ) -> None: - """Enter dataset context.""" - - -class ExitDatasetMixin(ABC): - """Mixin for processors that exit a dataset context.""" - - @abstractmethod - def exit_dataset( - self, dataset: Union[Dataset, IterableDataset], name: Optional[str] = None - ) -> None: - """Exit dataset context.""" - - -class EnterDatasetDictMixin(ABC): - """Mixin for processors that enter a dataset dict context.""" - - @abstractmethod - def enter_dataset_dict(self, dataset_dict) -> None: - """Enter dataset dict context.""" - - -class ExitDatasetDictMixin(ABC): - """Mixin for processors that exit a dataset dict context.""" - - @abstractmethod - def exit_dataset_dict(self, dataset_dict) -> None: - """Exit dataset dict context.""" diff --git a/src/pie_datasets/dataset_dict.py b/src/pie_datasets/dataset_dict.py index ef0d5467..df2ff3c2 100644 --- a/src/pie_datasets/dataset_dict.py +++ b/src/pie_datasets/dataset_dict.py @@ -1,6 +1,7 @@ import json import logging import os +from abc import ABC, abstractmethod from pathlib import Path from typing import ( Any, @@ -15,16 +16,10 @@ ) import datasets -from pytorch_ie.core import Document +from pytorch_ie.core.document import Document from pytorch_ie.utils.hydra import resolve_target, serialize_document_type -from .common import ( - EnterDatasetDictMixin, - EnterDatasetMixin, - ExitDatasetDictMixin, - ExitDatasetMixin, -) -from .dataset import Dataset, IterableDataset, get_pie_dataset_type +from pie_datasets.dataset import Dataset, IterableDataset, get_pie_dataset_type logger = logging.getLogger(__name__) @@ -34,6 +29,42 @@ D = TypeVar("D", bound=Document) +class EnterDatasetMixin(ABC): + """Mixin for processors that enter a dataset context.""" + + @abstractmethod + def enter_dataset( + self, dataset: Union[Dataset, IterableDataset], name: Optional[str] = None + ) -> None: + """Enter dataset context.""" + + +class ExitDatasetMixin(ABC): + """Mixin for processors that exit a dataset context.""" + + @abstractmethod + def exit_dataset( + self, dataset: Union[Dataset, IterableDataset], name: Optional[str] = None + ) -> None: + """Exit dataset context.""" + + +class EnterDatasetDictMixin(ABC): + """Mixin for processors that enter a dataset dict context.""" + + @abstractmethod + def enter_dataset_dict(self, 
dataset_dict) -> None:
+        """Enter dataset dict context."""
+
+
+class ExitDatasetDictMixin(ABC):
+    """Mixin for processors that exit a dataset dict context."""
+
+    @abstractmethod
+    def exit_dataset_dict(self, dataset_dict) -> None:
+        """Exit dataset dict context."""
+
+
 class DatasetDict(datasets.DatasetDict):
     def __getitem__(self, k) -> Union[Dataset, IterableDataset]:  # type: ignore
         """Returns an individual dataset split."""
diff --git a/src/pie_datasets/document/__init__.py b/src/pie_datasets/document/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/pie_datasets/document/processing/regex_partitioner.py b/src/pie_datasets/document/processing/regex_partitioner.py
index ae831db9..02e76b74 100644
--- a/src/pie_datasets/document/processing/regex_partitioner.py
+++ b/src/pie_datasets/document/processing/regex_partitioner.py
@@ -9,7 +9,8 @@
 from pytorch_ie.annotations import LabeledSpan
 from pytorch_ie.documents import TextBasedDocument

-from pie_datasets import Dataset, EnterDatasetMixin, ExitDatasetMixin, IterableDataset
+from pie_datasets import Dataset, IterableDataset
+from pie_datasets.dataset_dict import EnterDatasetMixin, ExitDatasetMixin

 logger = logging.getLogger(__name__)
diff --git a/src/pie_datasets/statistics.py b/src/pie_datasets/statistics.py
new file mode 100644
index 00000000..0a1ae0e4
--- /dev/null
+++ b/src/pie_datasets/statistics.py
@@ -0,0 +1,247 @@
+import logging
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Optional, Type, Union
+
+from pytorch_ie.annotations import Span
+from pytorch_ie.core import Document, DocumentStatistic
+from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument
+from pytorch_ie.utils.hydra import resolve_optional_document_type
+from transformers import AutoTokenizer, PreTrainedTokenizer
+
+from pie_datasets.document.conversion import tokenize_document
+
+logger = logging.getLogger(__name__)
+
+
+class TokenCountCollector(DocumentStatistic):
+    """Collects the token count of a field when tokenizing its content with a Huggingface
+    tokenizer.
+
+    The content of the field should be a string.
+    """
+
+    def __init__(
+        self,
+        tokenizer: Union[str, PreTrainedTokenizer],
+        text_field: str = "text",
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        document_type: Optional[Type[Document]] = None,
+        **kwargs,
+    ):
+        if document_type is None and text_field == "text":
+            document_type = TextBasedDocument
+        super().__init__(document_type=document_type, **kwargs)
+        self.tokenizer = (
+            AutoTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer
+        )
+        self.tokenizer_kwargs = tokenizer_kwargs or {}
+        self.text_field = text_field
+
+    def _collect(self, doc: Document) -> int:
+        text = getattr(doc, self.text_field)
+        encodings = self.tokenizer(text, **self.tokenizer_kwargs)
+        tokens = encodings.tokens()
+        return len(tokens)
+
+
+class FieldLengthCollector(DocumentStatistic):
+    """Collects the length of a field, e.g. to collect the number of characters in the input text.
+
+    The field value must be a sized object, e.g. a string or a list.
+    """
+
+    def __init__(self, field: str, **kwargs):
+        super().__init__(**kwargs)
+        self.field = field
+
+    def _collect(self, doc: Document) -> int:
+        field_obj = getattr(doc, self.field)
+        return len(field_obj)
+
+
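+# A minimal usage sketch for the two collectors above. The call pattern mirrors the unit tests
+# added in this change; the data_dir path and the MyDocument type are placeholders:
+#
+#     from pie_datasets import DatasetDict
+#     from pie_datasets.statistics import FieldLengthCollector, TokenCountCollector
+#
+#     dataset = DatasetDict.from_json(
+#         data_dir="path/to/serialized/dataset",  # placeholder path
+#         document_type=MyDocument,  # placeholder document type
+#     )
+#     # number of characters in the "text" field, aggregated per split
+#     char_counts = FieldLengthCollector(field="text")(dataset)
+#     # number of tokens in the "text" field when tokenized with a Huggingface tokenizer
+#     token_counts = TokenCountCollector(
+#         tokenizer="bert-base-uncased",
+#         tokenizer_kwargs=dict(add_special_tokens=False),
+#     )(dataset)
+
+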
+class SubFieldLengthCollector(DocumentStatistic):
+    """Collects the length of a subfield in a field, e.g. to collect the number of arguments of
+    N-ary relations."""
+
+    def __init__(self, field: str, subfield: str, **kwargs):
+        super().__init__(**kwargs)
+        self.field = field
+        self.subfield = subfield
+
+    def _collect(self, doc: Document) -> List[int]:
+        field_obj = getattr(doc, self.field)
+        lengths = []
+        for entry in field_obj:
+            subfield_obj = getattr(entry, self.subfield)
+            lengths.append(len(subfield_obj))
+        return lengths
+
+
+class SpanLengthCollector(DocumentStatistic):
+    """Collects the lengths of Span annotations. If labels are provided, the lengths are
+    collected per label.
+
+    If a tokenizer is provided, the span length is calculated in terms of tokens, otherwise in
+    terms of characters.
+    """
+
+    DEFAULT_AGGREGATION_FUNCTIONS = ["len", "mean", "std", "min", "max"]
+
+    def __init__(
+        self,
+        layer: str,
+        tokenize: bool = False,
+        tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
+        tokenized_document_type: Optional[Union[str, Type[TokenBasedDocument]]] = None,
+        labels: Optional[Union[List[str], str]] = None,
+        label_attribute: str = "label",
+        tokenize_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.layer = layer
+        if isinstance(labels, str) and labels != "INFERRED":
+            raise ValueError("labels must be a list of strings or 'INFERRED'")
+        if labels == "INFERRED":
+            logger.warning(
+                f"Inferring labels with {self.__class__.__name__} from data produces wrong results "
+                f"for certain aggregation functions (e.g. 'mean', 'std', 'min') because zero values "
+                f"are not included in the calculation. We remove these aggregation functions from "
+                f"this collector, but be aware that the results may be wrong for your own aggregation "
+                f"functions that rely on zero values."
+            )
+            self.aggregation_functions: Dict[str, Callable[[List], Any]] = {
+                name: func
+                for name, func in self.aggregation_functions.items()
+                if name not in ["mean", "std", "min"]
+            }
+        self.labels = labels
+        self.label_field = label_attribute
+        self.tokenize = tokenize
+        if self.tokenize:
+            if tokenizer is None:
+                raise ValueError(
+                    "tokenizer must be provided to calculate the span length in terms of tokens"
+                )
+            if isinstance(tokenizer, str):
+                tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+            self.tokenizer = tokenizer
+            resolved_tokenized_document_type = resolve_optional_document_type(
+                tokenized_document_type
+            )
+            if resolved_tokenized_document_type is None:
+                raise ValueError(
+                    "tokenized_document_type must be provided to calculate the span length in terms of tokens"
+                )
+            if not (
+                isinstance(resolved_tokenized_document_type, type)
+                and issubclass(resolved_tokenized_document_type, TokenBasedDocument)
+            ):
+                raise TypeError(
+                    f"tokenized_document_type must be a subclass of TokenBasedDocument, but it is: "
+                    f"{resolved_tokenized_document_type}"
+                )
+            self.tokenized_document_type = resolved_tokenized_document_type
+            self.tokenize_kwargs = tokenize_kwargs or {}
+
+    def _collect(self, doc: Document) -> Union[List[int], Dict[str, List[int]]]:
+        docs: Union[List[Document], List[TokenBasedDocument]]
+        if self.tokenize:
+            if not isinstance(doc, TextBasedDocument):
+                raise ValueError(
+                    "doc must be a TextBasedDocument to calculate the span length in terms of tokens"
+                )
+            docs = tokenize_document(
+                doc,
+                tokenizer=self.tokenizer,
+                result_document_type=self.tokenized_document_type,
+                **self.tokenize_kwargs,
+            )
+        else:
+            docs = [doc]
+
+        values: Dict[str, List[int]]
+        if isinstance(self.labels, str):
+            values = defaultdict(list)
+        else:
+            values = {label: [] for label in self.labels or ["ALL"]}
+        for doc in docs:
+            layer_obj = getattr(doc, self.layer)
+            for span in layer_obj:
+                if not isinstance(span, Span):
+                    raise TypeError(
+                        f"span length calculation is not yet supported for {type(span)}"
+                    )
+                length = span.end - span.start
+                if self.labels is None:
+                    label = "ALL"
+                else:
+                    label = getattr(span, self.label_field)
+                values[label].append(length)
+
+        return values if self.labels is not None else values["ALL"]
+
+
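+# A minimal usage sketch for SpanLengthCollector, following the unit tests added in this
+# change; MyTokenDocumentWithEntities is a placeholder for a TokenBasedDocument subclass:
+#
+#     from pie_datasets.statistics import SpanLengthCollector
+#
+#     # span lengths in characters, collected from the "entities" layer
+#     char_lengths = SpanLengthCollector(layer="entities")(dataset)
+#
+#     # span lengths in tokens; requires a tokenizer and a token-based document type
+#     token_lengths = SpanLengthCollector(
+#         layer="entities",
+#         tokenize=True,
+#         tokenizer="bert-base-uncased",
+#         tokenized_document_type=MyTokenDocumentWithEntities,  # placeholder type
+#     )(dataset)
+
+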
+class DummyCollector(DocumentStatistic):
+    """A dummy collector that always returns 1.
+
+    Can be used to count the number of documents.
+    """
+
+    DEFAULT_AGGREGATION_FUNCTIONS = ["sum"]
+
+    def _collect(self, doc: Document) -> int:
+        return 1
+
+
+class LabelCountCollector(DocumentStatistic):
+    """Collects the number of field entries per label, e.g. to collect the number of entities per
+    type.
+
+    The field should be a list of elements with a label attribute.
+
+    Important: To aggregate the result data correctly, missing values need to be filled with 0, e.g.:
+    {("ORG",): [2, 3], ("LOC",): [2]} -> {("ORG",): [2, 3], ("LOC",): [2, 0]}
+    """
+
+    DEFAULT_AGGREGATION_FUNCTIONS = ["mean", "std", "min", "max", "len", "sum"]
+
+    def __init__(
+        self, field: str, labels: Union[List[str], str], label_attribute: str = "label", **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.field = field
+        self.label_attribute = label_attribute
+        if not (isinstance(labels, list) or labels == "INFERRED"):
+            raise ValueError("labels must be a list of strings or 'INFERRED'")
+        if labels == "INFERRED":
+            logger.warning(
+                f"Inferring labels with {self.__class__.__name__} from data produces wrong results "
+                f"for certain aggregation functions (e.g. 'mean', 'std', 'min') because zero values "
+                f"are not included in the calculation. We remove these aggregation functions from "
+                f"this collector, but be aware that the results may be wrong for your own aggregation "
+                f"functions that rely on zero values."
+ ) + self.aggregation_functions: Dict[str, Callable[[List], Any]] = { + name: func + for name, func in self.aggregation_functions.items() + if name not in ["mean", "std", "min"] + } + + self.labels = labels + + def _collect(self, doc: Document) -> Dict[str, int]: + field_obj = getattr(doc, self.field) + counts: Dict[str, int] + if self.labels == "INFERRED": + counts = defaultdict(int) + else: + counts = {label: 0 for label in self.labels} + for elem in field_obj: + label = getattr(elem, self.label_attribute) + counts[label] += 1 + return dict(counts) diff --git a/tests/unit/test_dataset_dict.py b/tests/unit/test_dataset_dict.py index 2f7b97eb..aa382343 100644 --- a/tests/unit/test_dataset_dict.py +++ b/tests/unit/test_dataset_dict.py @@ -9,14 +9,12 @@ from pytorch_ie.core import AnnotationList, Document, annotation_field from pytorch_ie.documents import TextBasedDocument, TextDocument -from pie_datasets import ( - Dataset, - DatasetDict, +from pie_datasets import Dataset, DatasetDict, IterableDataset +from pie_datasets.dataset_dict import ( EnterDatasetDictMixin, EnterDatasetMixin, ExitDatasetDictMixin, ExitDatasetMixin, - IterableDataset, ) from tests import DATASET_BUILDERS_ROOT, FIXTURES_ROOT from tests.conftest import CREATE_FIXTURE_DATA, TestDocument diff --git a/tests/unit/test_statistics.py b/tests/unit/test_statistics.py new file mode 100644 index 00000000..d201dd65 --- /dev/null +++ b/tests/unit/test_statistics.py @@ -0,0 +1,224 @@ +import dataclasses + +import pytest +from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument + +from pie_datasets import DatasetDict +from pie_datasets.statistics import ( + DummyCollector, + FieldLengthCollector, + LabelCountCollector, + SpanLengthCollector, + SubFieldLengthCollector, + TokenCountCollector, +) +from tests import FIXTURES_ROOT + + +@pytest.fixture +def dataset(): + @dataclasses.dataclass + class Conll2003Document(TextBasedDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + + return DatasetDict.from_json( + data_dir=FIXTURES_ROOT / "dataset_dict" / "conll2003_extract", + document_type=Conll2003Document, + ) + + +def test_statistics(dataset): + statistic = DummyCollector() + values = statistic(dataset) + assert values == {"train": {"sum": 3}, "test": {"sum": 3}, "validation": {"sum": 3}} + + statistic = LabelCountCollector(field="entities", labels=["LOC", "PER", "ORG", "MISC"]) + values = statistic(dataset) + assert values == { + "train": { + "LOC": { + "mean": 0.3333333333333333, + "std": 0.4714045207910317, + "min": 0, + "max": 1, + "len": 3, + "sum": 1, + }, + "PER": { + "mean": 0.3333333333333333, + "std": 0.4714045207910317, + "min": 0, + "max": 1, + "len": 3, + "sum": 1, + }, + "ORG": { + "mean": 0.3333333333333333, + "std": 0.4714045207910317, + "min": 0, + "max": 1, + "len": 3, + "sum": 1, + }, + "MISC": { + "mean": 0.6666666666666666, + "std": 0.9428090415820634, + "min": 0, + "max": 2, + "len": 3, + "sum": 2, + }, + }, + "validation": { + "LOC": { + "mean": 0.3333333333333333, + "std": 0.4714045207910317, + "min": 0, + "max": 1, + "len": 3, + "sum": 1, + }, + "PER": { + "mean": 0.3333333333333333, + "std": 0.4714045207910317, + "min": 0, + "max": 1, + "len": 3, + "sum": 1, + }, + "ORG": {"mean": 1.0, "std": 0.816496580927726, "min": 0, "max": 2, "len": 3, "sum": 3}, + "MISC": { + "mean": 0.3333333333333333, + "std": 0.4714045207910317, + "min": 0, + "max": 1, + "len": 
3, + "sum": 1, + }, + }, + "test": { + "LOC": {"mean": 1.0, "std": 0.816496580927726, "min": 0, "max": 2, "len": 3, "sum": 3}, + "PER": { + "mean": 0.6666666666666666, + "std": 0.4714045207910317, + "min": 0, + "max": 1, + "len": 3, + "sum": 2, + }, + "ORG": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 3, "sum": 0}, + "MISC": {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "len": 3, "sum": 0}, + }, + } + + statistic = LabelCountCollector(field="entities", labels="INFERRED") + values = statistic(dataset) + assert values == { + "train": { + "ORG": {"max": 1, "len": 1, "sum": 1}, + "MISC": {"max": 2, "len": 1, "sum": 2}, + "PER": {"max": 1, "len": 1, "sum": 1}, + "LOC": {"max": 1, "len": 1, "sum": 1}, + }, + "validation": { + "ORG": {"max": 2, "len": 2, "sum": 3}, + "LOC": {"max": 1, "len": 1, "sum": 1}, + "MISC": {"max": 1, "len": 1, "sum": 1}, + "PER": {"max": 1, "len": 1, "sum": 1}, + }, + "test": {"LOC": {"max": 2, "len": 2, "sum": 3}, "PER": {"max": 1, "len": 2, "sum": 2}}, + } + + statistic = FieldLengthCollector(field="text") + values = statistic(dataset) + assert values == { + "test": {"max": 57, "mean": 36.0, "min": 11, "std": 18.991226044325487}, + "train": {"max": 48, "mean": 27.333333333333332, "min": 15, "std": 14.70449666674185}, + "validation": {"max": 187, "mean": 89.66666666666667, "min": 17, "std": 71.5603863103665}, + } + + statistic = SpanLengthCollector(layer="entities") + values = statistic(dataset) + assert values == { + "train": {"len": 5, "mean": 7.6, "std": 4.223742416388575, "min": 2, "max": 15}, + "validation": { + "len": 6, + "mean": 10.833333333333334, + "std": 2.9674156357941426, + "min": 6, + "max": 14, + }, + "test": {"len": 5, "mean": 9.4, "std": 5.748043145279966, "min": 5, "max": 20}, + } + + statistic = SpanLengthCollector(layer="entities", labels="INFERRED") + values = statistic(dataset) + assert values == { + "train": { + "ORG": {"max": 2, "len": 1}, + "MISC": {"max": 7, "len": 2}, + "PER": {"max": 15, "len": 1}, + "LOC": {"max": 8, "len": 1}, + }, + "test": { + "LOC": { + "max": 20, + "len": 3, + }, + "PER": {"max": 11, "len": 2}, + }, + "validation": { + "ORG": {"max": 14, "len": 3}, + "LOC": {"max": 6, "len": 1}, + "MISC": {"max": 11, "len": 1}, + "PER": {"max": 12, "len": 1}, + }, + } + + # this is not super useful, we just collect the lengths of the labels, but it is enough to test the code + statistic = SubFieldLengthCollector(field="entities", subfield="label") + values = statistic(dataset) + assert values == { + "test": {"max": 3, "mean": 3.0, "min": 3, "std": 0.0}, + "train": {"max": 4, "mean": 3.4, "min": 3, "std": 0.4898979485566356}, + "validation": {"max": 4, "mean": 3.1666666666666665, "min": 3, "std": 0.3726779962499649}, + } + + +def test_statistics_with_tokenize(dataset): + statistic = TokenCountCollector( + text_field="text", + tokenizer="bert-base-uncased", + tokenizer_kwargs=dict(add_special_tokens=False), + ) + values = statistic(dataset) + assert values == { + "test": {"max": 12, "mean": 9.333333333333334, "min": 4, "std": 3.7712361663282534}, + "train": {"max": 9, "mean": 5.666666666666667, "min": 2, "std": 2.8674417556808756}, + "validation": {"max": 38, "mean": 18.333333333333332, "min": 6, "std": 14.055445761538678}, + } + + @dataclasses.dataclass + class TokenDocumentWithLabeledEntities(TokenBasedDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens") + + statistic = SpanLengthCollector( + layer="entities", + tokenize=True, + tokenizer="bert-base-uncased", + 
tokenized_document_type=TokenDocumentWithLabeledEntities, + ) + values = statistic(dataset) + assert values == { + "test": {"len": 5, "max": 4, "mean": 2.4, "min": 1, "std": 1.2000000000000002}, + "train": {"len": 5, "max": 2, "mean": 1.2, "min": 1, "std": 0.4}, + "validation": { + "len": 6, + "max": 2, + "mean": 1.3333333333333333, + "min": 1, + "std": 0.4714045207910317, + }, + }
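+
+# Note on the expected values above: a DocumentStatistic collector is called on a DatasetDict,
+# runs _collect() once per document, and aggregates the collected values per split with its
+# aggregation functions (e.g. mean/std/min/max/len/sum), so the expected dictionaries are keyed
+# by split name and then by aggregation name (and, for label-aware collectors, by label).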