Skip to content

Commit

Permalink
use RequiresDocumentTypeMixin in metrics and taskmodules
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Sep 14, 2023
1 parent c20ca66 commit 56478ef
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 18 deletions.
15 changes: 3 additions & 12 deletions src/pytorch_ie/core/metric.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,19 @@
from abc import ABC, abstractmethod
from typing import Dict, Generic, Iterable, Optional, Type, TypeVar, Union
from typing import Dict, Generic, Iterable, Optional, TypeVar, Union

from pytorch_ie.core.document import Document
from pytorch_ie.core.module_mixins import RequiresDocumentTypeMixin

T = TypeVar("T")


class DocumentMetric(ABC, Generic[T]):
class DocumentMetric(ABC, RequiresDocumentTypeMixin, Generic[T]):
"""This defines the interface for a document metric."""

# The document type that this metric can process. Will be used for auto-conversion, if available.
# Overwrite this if the metric requires a specific document type, e.g. a TextBasedDocument.
DOCUMENT_TYPE: Optional[Type[Document]] = None

def __init__(self):
self.reset()
self._current_split: Optional[str] = None

@property
def document_type(self) -> Optional[Type[Document]]:
"""The document type that this metric can process. Will be used for auto-conversion, if available.
Overwrite this if the document type depends on some parameters of the metric."""
return self.DOCUMENT_TYPE

@abstractmethod
def reset(self) -> None:
"""Any reset logic that needs to be performed before the metric is called again."""
Expand Down
8 changes: 2 additions & 6 deletions src/pytorch_ie/core/taskmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
Expand All @@ -24,6 +23,7 @@

from pytorch_ie.core.document import Annotation, Document
from pytorch_ie.core.hf_hub_mixin import PieTaskModuleHFHubMixin
from pytorch_ie.core.module_mixins import RequiresDocumentTypeMixin
from pytorch_ie.core.registrable import Registrable
from pytorch_ie.data import Dataset, IterableDataset

Expand Down Expand Up @@ -149,6 +149,7 @@ class TaskModule(
PieTaskModuleHFHubMixin,
HyperparametersMixin,
Registrable,
RequiresDocumentTypeMixin,
Generic[
DocumentType,
InputEncoding,
Expand All @@ -159,16 +160,11 @@ class TaskModule(
],
):
PREPARED_ATTRIBUTES: List[str] = []
DOCUMENT_TYPE: Optional[Type[DocumentType]] = None

def __init__(self, encode_document_batch_size: Optional[int] = None, **kwargs):
super().__init__(**kwargs)
self.encode_document_batch_size = encode_document_batch_size

@property
def document_type(self) -> Optional[Type[DocumentType]]:
return self.DOCUMENT_TYPE

@property
def is_prepared(self):
"""
Expand Down

0 comments on commit 56478ef

Please sign in to comment.