diff --git a/src/pytorch_ie/core/__init__.py b/src/pytorch_ie/core/__init__.py
index 7f7141a9..989047c8 100644
--- a/src/pytorch_ie/core/__init__.py
+++ b/src/pytorch_ie/core/__init__.py
@@ -1,5 +1,6 @@
 from .document import Annotation, AnnotationList, Document, annotation_field
 from .metric import DocumentMetric
 from .model import PyTorchIEModel
+from .module_mixins import RequiresDocumentTypeMixin
 from .statistic import DocumentStatistic
 from .taskmodule import TaskEncoding, TaskModule
diff --git a/src/pytorch_ie/core/metric.py b/src/pytorch_ie/core/metric.py
index 639bc879..54febd85 100644
--- a/src/pytorch_ie/core/metric.py
+++ b/src/pytorch_ie/core/metric.py
@@ -2,11 +2,12 @@
 from typing import Dict, Generic, Iterable, Optional, TypeVar, Union
 
 from pytorch_ie.core.document import Document
+from pytorch_ie.core.module_mixins import RequiresDocumentTypeMixin
 
 T = TypeVar("T")
 
 
-class DocumentMetric(ABC, Generic[T]):
+class DocumentMetric(ABC, RequiresDocumentTypeMixin, Generic[T]):
     """This defines the interface for a document metric."""
 
     def __init__(self):
diff --git a/src/pytorch_ie/core/module_mixins.py b/src/pytorch_ie/core/module_mixins.py
new file mode 100644
index 00000000..4d151bc8
--- /dev/null
+++ b/src/pytorch_ie/core/module_mixins.py
@@ -0,0 +1,39 @@
+import logging
+from typing import Optional, Type
+
+from pytorch_ie.core.document import Document
+from pytorch_ie.data.dataset_dict import DatasetDict
+
+logger = logging.getLogger(__name__)
+
+
+class RequiresDocumentTypeMixin:
+
+    DOCUMENT_TYPE: Optional[Type[Document]] = None
+
+    @property
+    def document_type(self) -> Optional[Type[Document]]:
+        return self.DOCUMENT_TYPE
+
+    def convert_dataset(self, dataset: DatasetDict) -> DatasetDict:
+        name = type(self).__name__
+        # auto-convert the dataset if a document type is specified
+        if self.document_type is not None:
+            if issubclass(dataset.document_type, self.document_type):
+                logger.info(
+                    f"the dataset is already of the document type that is specified by {name}: "
+                    f"{self.document_type}"
+                )
+            else:
+                logger.info(
+                    f"convert the dataset to the document type that is specified by {name}: "
+                    f"{self.document_type}"
+                )
+                dataset = dataset.to_document_type(self.document_type)
+        else:
+            logger.warning(
+                f"{name} does not specify a document type. The dataset cannot be automatically "
+                f"converted to a document type."
+            )
+
+        return dataset
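A minimal usage sketch of the new mixin (not part of the diff; MyComponent and `dataset` are hypothetical stand-ins): a component declares its DOCUMENT_TYPE, and convert_dataset() aligns a loaded DatasetDict with it.

from pytorch_ie.core import RequiresDocumentTypeMixin
from pytorch_ie.documents import TextBasedDocument


class MyComponent(RequiresDocumentTypeMixin):
    # picked up by the mixin's `document_type` property
    DOCUMENT_TYPE = TextBasedDocument


component = MyComponent()
# assuming `dataset` is a previously loaded pytorch_ie DatasetDict: this logs
# that the dataset already matches, converts it via to_document_type otherwise,
# and only warns (without converting) when no DOCUMENT_TYPE is set
dataset = component.convert_dataset(dataset)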
diff --git a/src/pytorch_ie/core/statistic.py b/src/pytorch_ie/core/statistic.py
index e926d7ad..b1baed4a 100644
--- a/src/pytorch_ie/core/statistic.py
+++ b/src/pytorch_ie/core/statistic.py
@@ -1,11 +1,15 @@
 import logging
 from abc import abstractmethod
 from collections import defaultdict
-from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union
 
 from pytorch_ie.core.document import Document
 from pytorch_ie.core.metric import DocumentMetric
-from pytorch_ie.utils.hydra import InstantiationException, resolve_target
+from pytorch_ie.utils.hydra import (
+    InstantiationException,
+    resolve_optional_document_type,
+    resolve_target,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -152,6 +156,7 @@ def __init__(
         show_as_markdown: bool = False,
         aggregation_functions: Optional[List[str]] = None,
         title: Optional[str] = None,
+        document_type: Optional[Union[Type[Document], str]] = None,
     ) -> None:
         super().__init__()
         self.aggregation_functions = {
@@ -161,6 +166,11 @@
         self.show_histogram = show_histogram
         self.show_as_markdown = show_as_markdown
         self.title = title or self.__class__.__name__
+        self._document_type = resolve_optional_document_type(document_type)
+
+    @property
+    def document_type(self) -> Optional[Type[Document]]:
+        return self._document_type or super().document_type
 
     def reset(self) -> None:
         self._values: List[Any] = []
diff --git a/src/pytorch_ie/core/taskmodule.py b/src/pytorch_ie/core/taskmodule.py
index 3f055352..e137ef5e 100644
--- a/src/pytorch_ie/core/taskmodule.py
+++ b/src/pytorch_ie/core/taskmodule.py
@@ -12,7 +12,6 @@
     Optional,
     Sequence,
     Tuple,
-    Type,
     TypeVar,
     Union,
     overload,
@@ -24,6 +23,7 @@
 
 from pytorch_ie.core.document import Annotation, Document
 from pytorch_ie.core.hf_hub_mixin import PieTaskModuleHFHubMixin
+from pytorch_ie.core.module_mixins import RequiresDocumentTypeMixin
 from pytorch_ie.core.registrable import Registrable
 from pytorch_ie.data import Dataset, IterableDataset
 
@@ -149,6 +149,7 @@ class TaskModule(
     PieTaskModuleHFHubMixin,
     HyperparametersMixin,
     Registrable,
+    RequiresDocumentTypeMixin,
     Generic[
         DocumentType,
         InputEncoding,
@@ -159,16 +160,11 @@
     ],
 ):
     PREPARED_ATTRIBUTES: List[str] = []
-    DOCUMENT_TYPE: Optional[Type[DocumentType]] = None
 
     def __init__(self, encode_document_batch_size: Optional[int] = None, **kwargs):
         super().__init__(**kwargs)
         self.encode_document_batch_size = encode_document_batch_size
 
-    @property
-    def document_type(self) -> Optional[Type[DocumentType]]:
-        return self.DOCUMENT_TYPE
-
     @property
     def is_prepared(self):
         """
diff --git a/src/pytorch_ie/data/dataset.py b/src/pytorch_ie/data/dataset.py
index c64281f4..781b5b49 100644
--- a/src/pytorch_ie/data/dataset.py
+++ b/src/pytorch_ie/data/dataset.py
@@ -159,6 +159,12 @@ def _get_best_dataset_converter_with_types(
     # first try to find an exact match
     if document_type in dataset.document_converters:
         return dataset.document_converters[document_type], document_type, document_type
+
+    # then try to find a match with a superclass (a registered type that is a subclass of the requested type)
+    for registered_dt, candidate_converter in dataset.document_converters.items():
+        if issubclass(registered_dt, document_type):
+            return candidate_converter, document_type, registered_dt
+
     # then try to find a match with a subclass
     for registered_dt, candidate_converter in dataset.document_converters.items():
         if issubclass(document_type, registered_dt):
@@ -218,8 +224,8 @@ def dataset_to_document_type(
     result = result.cast_document_type(
         new_document_type=registered_type, field_mapping=converter, **kwargs
     )
-    # if the requested type is different from the registered type, try to cast (again)
-    if requested_type != registered_type:
+    # if the registered type is neither the requested type nor a subclass of it, try to cast (again)
+    if not issubclass(registered_type, requested_type):
         result = result.cast_document_type(new_document_type=requested_type)
 
     # remove the document converters because they are not valid anymore
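With this change, _get_best_dataset_converter_with_types resolves a requested type in three steps: an exact match, then a registered type that is a subclass of the request, then one that is a superclass; dataset_to_document_type afterwards only re-casts when the registered type is not already a subclass of the requested one. A self-contained toy sketch of the lookup order (class and converter names are made up):

class A: ...


class B(A): ...


converters = {B: "to_B"}  # a converter is registered for the subclass B only


def best_match(requested):
    if requested in converters:  # exact match
        return converters[requested]
    for registered, conv in converters.items():
        if issubclass(registered, requested):  # registered type is a subclass of the request
            return conv
    for registered, conv in converters.items():
        if issubclass(requested, registered):  # registered type is a superclass of the request
            return conv
    return None


print(best_match(A))  # -> "to_B": requesting the superclass A now finds B's converter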
diff --git a/src/pytorch_ie/metrics/statistics.py b/src/pytorch_ie/metrics/statistics.py
index 718e515e..5ff68eba 100644
--- a/src/pytorch_ie/metrics/statistics.py
+++ b/src/pytorch_ie/metrics/statistics.py
@@ -1,25 +1,29 @@
 from collections import defaultdict
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Type, Union
 
 from transformers import AutoTokenizer, PreTrainedTokenizer
 
 from pytorch_ie.core import Document, DocumentStatistic
+from pytorch_ie.documents import TextBasedDocument
 
 
 class TokenCountCollector(DocumentStatistic):
     """Collects the token count of a field when tokenizing its content with a Huggingface tokenizer.
 
-    The field should be a string.
+    The content of the field should be a string.
     """
 
     def __init__(
         self,
         tokenizer: Union[str, PreTrainedTokenizer],
-        text_field: str,
+        text_field: str = "text",
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        document_type: Optional[Type[Document]] = None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        if document_type is None and text_field == "text":
+            document_type = TextBasedDocument
+        super().__init__(document_type=document_type, **kwargs)
         self.tokenizer = (
             AutoTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer
         )
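A usage sketch for the updated collector, assuming a loaded pytorch_ie DatasetDict `dataset` whose documents carry a `text` field; the tokenizer name is only an example:

from pytorch_ie.metrics.statistics import TokenCountCollector

collector = TokenCountCollector(tokenizer="bert-base-uncased")
# text_field defaults to "text", so the collector infers TextBasedDocument as
# its document type and the dataset can be auto-converted to match
dataset = collector.convert_dataset(dataset)
# compute token counts over one split (assumes the DocumentMetric call interface)
collector(dataset["train"])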
diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py
index 2d390e2b..6e0cc40d 100644
--- a/tests/data/test_dataset.py
+++ b/tests/data/test_dataset.py
@@ -440,16 +440,26 @@ class TestDocumentWithLabelAndSpans(TestDocumentWithLabel):
 
 
 def test_to_document_type_not_found(dataset_with_converter_functions):
     assert dataset_with_converter_functions.document_type == TestDocument
+
+    @dataclass
+    class TestDocumentWithSpans(TestDocument):
+        spans: AnnotationList[Span] = annotation_field(target="text")
+
     # The only converter is registered for TestDocumentWithLabel, but we request a conversion to
-    # TextDocument which is a *superclass* of TestDocumentWithLabel. This is not a valid type,
-    # so just a simple cast is performed.
-    converted_dataset = dataset_with_converter_functions.to_document_type(TextDocument)
-    assert converted_dataset.document_type == TextDocument
+    # TestDocumentWithSpans, for which no converter applies because it is neither a subclass nor
+    # a superclass of TestDocumentWithLabel, so just a simple cast is performed.
+    converted_dataset = dataset_with_converter_functions.to_document_type(TestDocumentWithSpans)
+    assert converted_dataset.document_type == TestDocumentWithSpans
     assert len(converted_dataset.document_converters) == 0
     for converted_doc, doc in zip(converted_dataset, dataset_with_converter_functions):
         assert isinstance(doc, TestDocument)
-        assert isinstance(converted_doc, TextDocument)
+        assert isinstance(converted_doc, TestDocumentWithSpans)
         assert converted_doc.text == doc.text
-        annotation_filed_names = {f.name for f in converted_doc.annotation_fields()}
-        assert annotation_filed_names == set()
+        annotation_field_names = {f.name for f in doc.annotation_fields()}
+        assert annotation_field_names == {"sentences", "entities", "relations"}
+        converted_annotation_field_names = {f.name for f in converted_doc.annotation_fields()}
+        assert converted_annotation_field_names == {"sentences", "spans", "entities", "relations"}
+        common_annotation_field_names = annotation_field_names & converted_annotation_field_names
+        for afn in common_annotation_field_names:
+            assert converted_doc[afn] == doc[afn]
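Distilled, the behavior this reworked test pins down (names as in the test above): requesting a type for which neither a subclass nor a superclass converter is registered falls back to a plain cast, and annotation fields shared by both document types carry over unchanged.

converted = dataset_with_converter_functions.to_document_type(TestDocumentWithSpans)
for converted_doc, doc in zip(converted, dataset_with_converter_functions):
    # fields present in both types keep their annotations through the cast
    shared = {f.name for f in doc.annotation_fields()} & {
        f.name for f in converted_doc.annotation_fields()
    }
    assert all(converted_doc[name] == doc[name] for name in shared)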