Skip to content

Commit

Permalink
document_type for DocumentMetrics (#343)
Browse files Browse the repository at this point in the history
* add DOCUMENT_TYPE and document_type to DocumentMetric

* allow setting document_type for DocumentStatistic via an init parameter

* set document_type to TextBasedDocument for TokenCountCollector if text_field=text

* implement RequiresDocumentTypeMixin

* use RequiresDocumentTypeMixin in metrics and taskmodules

* default to super().document_type

* resolve document_type in DocumentStatistic.__init__

* adjust method name

* _get_best_dataset_converter_with_types(): accept matches with superclasses

* fix test_to_document_type_not_found()
  • Loading branch information
ArneBinder authored Sep 14, 2023
1 parent 8b4ba0e commit 4f26eca
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 22 deletions.
1 change: 1 addition & 0 deletions src/pytorch_ie/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .document import Annotation, AnnotationList, Document, annotation_field
from .metric import DocumentMetric
from .model import PyTorchIEModel
from .module_mixins import RequiresDocumentTypeMixin
from .statistic import DocumentStatistic
from .taskmodule import TaskEncoding, TaskModule
3 changes: 2 additions & 1 deletion src/pytorch_ie/core/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from typing import Dict, Generic, Iterable, Optional, TypeVar, Union

from pytorch_ie.core.document import Document
from pytorch_ie.core.module_mixins import RequiresDocumentTypeMixin

T = TypeVar("T")


class DocumentMetric(ABC, Generic[T]):
class DocumentMetric(ABC, RequiresDocumentTypeMixin, Generic[T]):
"""This defines the interface for a document metric."""

def __init__(self):
Expand Down
39 changes: 39 additions & 0 deletions src/pytorch_ie/core/module_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import logging
from typing import Optional, Type

from pytorch_ie.core.document import Document
from pytorch_ie.data.dataset_dict import DatasetDict

logger = logging.getLogger(__name__)


class RequiresDocumentTypeMixin:
    """Mixin for components that can declare the document type they work on.

    The type is declared through the ``DOCUMENT_TYPE`` class attribute and is
    read through the ``document_type`` property, which subclasses may override.
    ``convert_dataset`` uses that declaration to bring a dataset into the
    declared type.
    """

    # Document type declared by the component; ``None`` means "not specified".
    DOCUMENT_TYPE: Optional[Type[Document]] = None

    @property
    def document_type(self) -> Optional[Type[Document]]:
        """Return the declared document type, i.e. the ``DOCUMENT_TYPE`` class attribute."""
        return self.DOCUMENT_TYPE

    def convert_dataset(self, dataset: DatasetDict) -> DatasetDict:
        """Return ``dataset`` in the document type declared by this component.

        Args:
            dataset: The dataset to check and, if required, convert.

        Returns:
            The unchanged ``dataset`` when no document type is declared (a
            warning is logged) or when ``dataset.document_type`` already is the
            declared type or a subclass of it; otherwise the result of
            ``dataset.to_document_type(self.document_type)``.
        """
        name = type(self).__name__

        # Without a declared type there is no conversion target: warn and bail out.
        if self.document_type is None:
            logger.warning(
                f"{name} does not specify a document type. The dataset can not be automatically converted "
                f"to a document type."
            )
            return dataset

        # A dataset whose type is the declared type (or a subclass of it) is used as is.
        if issubclass(dataset.document_type, self.document_type):
            logger.info(
                f"the dataset is already of the document type that is specified by {name}: "
                f"{self.document_type}"
            )
            return dataset

        logger.info(
            f"convert the dataset to the document type that is specified by {name}: "
            f"{self.document_type}"
        )
        return dataset.to_document_type(self.document_type)
14 changes: 12 additions & 2 deletions src/pytorch_ie/core/statistic.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import logging
from abc import abstractmethod
from collections import defaultdict
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union

from pytorch_ie.core.document import Document
from pytorch_ie.core.metric import DocumentMetric
from pytorch_ie.utils.hydra import InstantiationException, resolve_target
from pytorch_ie.utils.hydra import (
InstantiationException,
resolve_optional_document_type,
resolve_target,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -152,6 +156,7 @@ def __init__(
show_as_markdown: bool = False,
aggregation_functions: Optional[List[str]] = None,
title: Optional[str] = None,
document_type: Optional[Union[Type[Document], str]] = None,
) -> None:
super().__init__()
self.aggregation_functions = {
Expand All @@ -161,6 +166,11 @@ def __init__(
self.show_histogram = show_histogram
self.show_as_markdown = show_as_markdown
self.title = title or self.__class__.__name__
self._document_type = resolve_optional_document_type(document_type)

@property
def document_type(self) -> Optional[Type[Document]]:
    """Return the document type of this statistic.

    The value stored in ``self._document_type`` (resolved from the
    ``document_type`` init argument) takes precedence when it is truthy;
    otherwise the parent class's ``document_type`` is returned.
    """
    own_type = self._document_type
    if own_type:
        return own_type
    return super().document_type

def reset(self) -> None:
self._values: List[Any] = []
Expand Down
8 changes: 2 additions & 6 deletions src/pytorch_ie/core/taskmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
Expand All @@ -24,6 +23,7 @@

from pytorch_ie.core.document import Annotation, Document
from pytorch_ie.core.hf_hub_mixin import PieTaskModuleHFHubMixin
from pytorch_ie.core.module_mixins import RequiresDocumentTypeMixin
from pytorch_ie.core.registrable import Registrable
from pytorch_ie.data import Dataset, IterableDataset

Expand Down Expand Up @@ -149,6 +149,7 @@ class TaskModule(
PieTaskModuleHFHubMixin,
HyperparametersMixin,
Registrable,
RequiresDocumentTypeMixin,
Generic[
DocumentType,
InputEncoding,
Expand All @@ -159,16 +160,11 @@ class TaskModule(
],
):
PREPARED_ATTRIBUTES: List[str] = []
DOCUMENT_TYPE: Optional[Type[DocumentType]] = None

def __init__(self, encode_document_batch_size: Optional[int] = None, **kwargs):
super().__init__(**kwargs)
self.encode_document_batch_size = encode_document_batch_size

@property
def document_type(self) -> Optional[Type[DocumentType]]:
return self.DOCUMENT_TYPE

@property
def is_prepared(self):
"""
Expand Down
10 changes: 8 additions & 2 deletions src/pytorch_ie/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ def _get_best_dataset_converter_with_types(
# first try to find an exact match
if document_type in dataset.document_converters:
return dataset.document_converters[document_type], document_type, document_type

# then try to find a registered type that is a subclass of the requested type (the requested type is its superclass)
for registered_dt, candidate_converter in dataset.document_converters.items():
if issubclass(registered_dt, document_type):
return candidate_converter, document_type, registered_dt

# then try to find a registered type that is a superclass of the requested type (the requested type is its subclass)
for registered_dt, candidate_converter in dataset.document_converters.items():
if issubclass(document_type, registered_dt):
Expand Down Expand Up @@ -218,8 +224,8 @@ def dataset_to_document_type(
result = result.cast_document_type(
new_document_type=registered_type, field_mapping=converter, **kwargs
)
# if the requested type is different from the registered type, try to cast (again)
if requested_type != registered_type:
# if the registered type is neither the requested type nor a subclass of it, try to cast (again)
if not issubclass(registered_type, requested_type):
result = result.cast_document_type(new_document_type=requested_type)

# remove the document converters because they are not valid anymore
Expand Down
12 changes: 8 additions & 4 deletions src/pytorch_ie/metrics/statistics.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Type, Union

from transformers import AutoTokenizer, PreTrainedTokenizer

from pytorch_ie.core import Document, DocumentStatistic
from pytorch_ie.documents import TextBasedDocument


class TokenCountCollector(DocumentStatistic):
"""Collects the token count of a field when tokenizing its content with a Huggingface tokenizer.
The field should be a string.
The content of the field should be a string.
"""

def __init__(
self,
tokenizer: Union[str, PreTrainedTokenizer],
text_field: str,
text_field: str = "text",
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
document_type: Optional[Type[Document]] = None,
**kwargs,
):
super().__init__(**kwargs)
if document_type is None and text_field == "text":
document_type = TextBasedDocument
super().__init__(document_type=document_type, **kwargs)
self.tokenizer = (
AutoTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer
)
Expand Down
24 changes: 17 additions & 7 deletions tests/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,16 +440,26 @@ class TestDocumentWithLabelAndSpans(TestDocumentWithLabel):

def test_to_document_type_not_found(dataset_with_converter_functions):
assert dataset_with_converter_functions.document_type == TestDocument

@dataclass
class TestDocumentWithSpans(TestDocument):
spans: AnnotationList[Span] = annotation_field(target="text")

# The only converter is registered for TestDocumentWithLabel, but we request a conversion to
# TextDocument which is a *superclass* of TestDocumentWithLabel. This is not a valid type,
# so just a simple cast is performed.
converted_dataset = dataset_with_converter_functions.to_document_type(TextDocument)
assert converted_dataset.document_type == TextDocument
# TestDocumentWithSpans. This is not a valid type because it is neither a subclass nor a superclass of
# TestDocumentWithLabel, so just a simple cast is performed.
converted_dataset = dataset_with_converter_functions.to_document_type(TestDocumentWithSpans)
assert converted_dataset.document_type == TestDocumentWithSpans

assert len(converted_dataset.document_converters) == 0
for converted_doc, doc in zip(converted_dataset, dataset_with_converter_functions):
assert isinstance(doc, TestDocument)
assert isinstance(converted_doc, TextDocument)
assert isinstance(converted_doc, TestDocumentWithSpans)
assert converted_doc.text == doc.text
annotation_filed_names = {f.name for f in converted_doc.annotation_fields()}
assert annotation_filed_names == set()
annotation_field_names = {f.name for f in doc.annotation_fields()}
assert annotation_field_names == {"sentences", "entities", "relations"}
converted_annotation_filed_names = {f.name for f in converted_doc.annotation_fields()}
assert converted_annotation_filed_names == {"sentences", "spans", "entities", "relations"}
common_annotation_field_names = annotation_field_names & converted_annotation_filed_names
for afn in common_annotation_field_names:
assert converted_doc[afn] == doc[afn]

0 comments on commit 4f26eca

Please sign in to comment.