diff --git a/src/pytorch_ie/core/metric.py b/src/pytorch_ie/core/metric.py index d7046202..54febd85 100644 --- a/src/pytorch_ie/core/metric.py +++ b/src/pytorch_ie/core/metric.py @@ -1,28 +1,19 @@ from abc import ABC, abstractmethod -from typing import Dict, Generic, Iterable, Optional, Type, TypeVar, Union +from typing import Dict, Generic, Iterable, Optional, TypeVar, Union from pytorch_ie.core.document import Document +from pytorch_ie.core.module_mixins import RequiresDocumentTypeMixin T = TypeVar("T") -class DocumentMetric(ABC, Generic[T]): +class DocumentMetric(ABC, RequiresDocumentTypeMixin, Generic[T]): """This defines the interface for a document metric.""" - # The document type that this metric can process. Will be used for auto-conversion, if available. - # Overwrite this if the metric requires a specific document type, e.g. a TextBasedDocument. - DOCUMENT_TYPE: Optional[Type[Document]] = None - def __init__(self): self.reset() self._current_split: Optional[str] = None - @property - def document_type(self) -> Optional[Type[Document]]: - """The document type that this metric can process. Will be used for auto-conversion, if available. - Overwrite this if the document type depends on some parameters of the metric.""" - return self.DOCUMENT_TYPE - @abstractmethod def reset(self) -> None: """Any reset logic that needs to be performed before the metric is called again.""" diff --git a/src/pytorch_ie/core/taskmodule.py b/src/pytorch_ie/core/taskmodule.py index 3f055352..e137ef5e 100644 --- a/src/pytorch_ie/core/taskmodule.py +++ b/src/pytorch_ie/core/taskmodule.py @@ -12,7 +12,6 @@ Optional, Sequence, Tuple, - Type, TypeVar, Union, overload, @@ -24,6 +23,7 @@ from pytorch_ie.core.document import Annotation, Document from pytorch_ie.core.hf_hub_mixin import PieTaskModuleHFHubMixin +from pytorch_ie.core.module_mixins import RequiresDocumentTypeMixin from pytorch_ie.core.registrable import Registrable from pytorch_ie.data import Dataset, IterableDataset @@ -149,6 +149,7 @@ class TaskModule( PieTaskModuleHFHubMixin, HyperparametersMixin, Registrable, + RequiresDocumentTypeMixin, Generic[ DocumentType, InputEncoding, @@ -159,16 +160,11 @@ class TaskModule( ], ): PREPARED_ATTRIBUTES: List[str] = [] - DOCUMENT_TYPE: Optional[Type[DocumentType]] = None def __init__(self, encode_document_batch_size: Optional[int] = None, **kwargs): super().__init__(**kwargs) self.encode_document_batch_size = encode_document_batch_size - @property - def document_type(self) -> Optional[Type[DocumentType]]: - return self.DOCUMENT_TYPE - @property def is_prepared(self): """