implement PreparableMixin (#370)
* add PreparableMixin and derive TaskModule from it

* make pre-commit happy

* derive PieTaskModuleHFHubMixin from PreparableMixin

* revert: derive PieTaskModuleHFHubMixin from PreparableMixin

* rename _assert_is_prepared() to assert_is_prepared() and call it at the start of TaskModule.encode(); improve log / error messages

* make pre-commit happy
ArneBinder authored Nov 8, 2023
1 parent ef4c574 commit a25bf0f
Showing 3 changed files with 68 additions and 67 deletions.
src/pytorch_ie/core/hf_hub_mixin.py (5 changes: 0 additions & 5 deletions)
@@ -407,9 +407,6 @@ def __init__(self, *args, **kwargs):
     def _save_pretrained(self, save_directory):
         return None

-    def post_prepare(self) -> None:
-        pass
-
     @classmethod
     def _from_pretrained(
         cls,
@@ -433,6 +430,4 @@ def _from_pretrained(
         config.pop(cls.config_type_key)

         taskmodule = cls(**config)
-        taskmodule.post_prepare()
-
         return taskmodule
src/pytorch_ie/core/module_mixins.py (63 changes: 62 additions & 1 deletion)
@@ -1,5 +1,5 @@
 import logging
-from typing import Optional, Type
+from typing import List, Optional, Type

 from pytorch_ie.core.document import Document

@@ -36,3 +36,64 @@ def convert_dataset(self, dataset: "pie_datasets.DatasetDict") -> "pie_datasets.
         )

     return dataset
+
+
+class PreparableMixin:
+    # list of attribute names that need to be set by _prepare()
+    PREPARED_ATTRIBUTES: List[str] = []
+
+    @property
+    def is_prepared(self):
+        """
+        Returns True, iff all attributes listed in PREPARED_ATTRIBUTES are set.
+        Note: Attributes set to None are not considered to be prepared!
+        """
+        return all(
+            getattr(self, attribute, None) is not None for attribute in self.PREPARED_ATTRIBUTES
+        )
+
+    @property
+    def prepared_attributes(self):
+        if not self.is_prepared:
+            raise Exception("The module is not prepared.")
+        return {param: getattr(self, param) for param in self.PREPARED_ATTRIBUTES}
+
+    def _prepare(self, *args, **kwargs):
+        """
+        This method needs to set all attributes listed in PREPARED_ATTRIBUTES.
+        """
+        pass
+
+    def _post_prepare(self):
+        """
+        Any code to do further one-time setup, but that requires the prepared attributes.
+        """
+        pass
+
+    def assert_is_prepared(self, msg: Optional[str] = None):
+        if not self.is_prepared:
+            attributes_not_prepared = [
+                param for param in self.PREPARED_ATTRIBUTES if getattr(self, param, None) is None
+            ]
+            raise Exception(
+                f"{msg or ''} Required attributes that are not set: {str(attributes_not_prepared)}"
+            )
+
+    def post_prepare(self):
+        self.assert_is_prepared()
+        self._post_prepare()
+
+    def prepare(self, *args, **kwargs) -> None:
+        if self.is_prepared:
+            if len(self.PREPARED_ATTRIBUTES) > 0:
+                msg = f"The {self.__class__.__name__} is already prepared, do not prepare again."
+                for k, v in self.prepared_attributes.items():
+                    msg += f"\n{k} = {str(v)}"
+                logger.warning(msg)
+        else:
+            self._prepare(*args, **kwargs)
+            self.assert_is_prepared(
+                msg=f"_prepare() was called, but the {self.__class__.__name__} is not prepared."
+            )
+            self._post_prepare()
+        return None
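
For orientation, a minimal sketch of how a subclass is meant to use the new mixin. The VocabCounter class, its texts argument, and the vocab/inverse_vocab attributes are invented for illustration and are not part of this commit:

    from typing import Dict, List, Optional

    from pytorch_ie.core.module_mixins import PreparableMixin


    class VocabCounter(PreparableMixin):
        # hypothetical subclass: `vocab` must be set by _prepare()
        PREPARED_ATTRIBUTES: List[str] = ["vocab"]

        def __init__(self, vocab: Optional[Dict[str, int]] = None):
            self.vocab = vocab

        def _prepare(self, texts: List[str]) -> None:
            # set every attribute listed in PREPARED_ATTRIBUTES
            tokens = sorted({token for text in texts for token in text.split()})
            self.vocab = {token: idx for idx, token in enumerate(tokens)}

        def _post_prepare(self) -> None:
            # one-time setup that depends on the prepared attributes
            self.inverse_vocab = {idx: token for token, idx in self.vocab.items()}


    counter = VocabCounter()
    counter.prepare(["a b", "b c"])     # runs _prepare(), verifies, then _post_prepare()
    print(counter.prepared_attributes)  # {'vocab': {'a': 0, 'b': 1, 'c': 2}}
    counter.prepare(["d"])              # already prepared: only logs a warning

Passing vocab to the constructor makes the instance prepared from the start, in which case prepare() skips _prepare() and just warns.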
src/pytorch_ie/core/taskmodule.py (67 changes: 6 additions & 61 deletions)
@@ -10,7 +10,7 @@

 from pytorch_ie.core.document import Annotation, Document
 from pytorch_ie.core.hf_hub_mixin import PieTaskModuleHFHubMixin
-from pytorch_ie.core.module_mixins import WithDocumentTypeMixin
+from pytorch_ie.core.module_mixins import PreparableMixin, WithDocumentTypeMixin
 from pytorch_ie.core.registrable import Registrable

 """
@@ -134,6 +134,7 @@ class TaskModule(
     HyperparametersMixin,
     Registrable,
     WithDocumentTypeMixin,
+    PreparableMixin,
     Generic[
         DocumentType,
         InputEncoding,
@@ -143,68 +144,10 @@
         TaskOutput,
     ],
 ):
-    PREPARED_ATTRIBUTES: List[str] = []
-
     def __init__(self, encode_document_batch_size: Optional[int] = None, **kwargs):
         super().__init__(**kwargs)
         self.encode_document_batch_size = encode_document_batch_size

-    @property
-    def is_prepared(self):
-        """
-        Returns True, iff all attributes listed in PREPARED_ATTRIBUTES are set.
-        Note: Attributes set to None are not considered to be prepared!
-        """
-        return all(
-            getattr(self, attribute, None) is not None for attribute in self.PREPARED_ATTRIBUTES
-        )
-
-    @property
-    def prepared_attributes(self):
-        if not self.is_prepared:
-            raise Exception("The taskmodule is not prepared.")
-        return {param: getattr(self, param) for param in self.PREPARED_ATTRIBUTES}
-
-    def _prepare(self, documents: Sequence[DocumentType]):
-        """
-        This method needs to set all attributes listed in PREPARED_ATTRIBUTES.
-        """
-        pass
-
-    def _post_prepare(self):
-        """
-        Any code to do further one-time setup, but that requires the prepared attributes.
-        """
-        pass
-
-    def _assert_is_prepared(self, msg: Optional[str] = None):
-        if not self.is_prepared:
-            attributes_not_prepared = [
-                param for param in self.PREPARED_ATTRIBUTES if getattr(self, param, None) is None
-            ]
-            raise Exception(
-                f"{msg or ''} Required attributes that are not set: {str(attributes_not_prepared)}"
-            )
-
-    def post_prepare(self):
-        self._assert_is_prepared()
-        self._post_prepare()
-
-    def prepare(self, documents: Sequence[DocumentType]) -> None:
-        if self.is_prepared:
-            if len(self.PREPARED_ATTRIBUTES) > 0:
-                msg = "The taskmodule is already prepared, do not prepare again."
-                for k, v in self.prepared_attributes.items():
-                    msg += f"\n{k} = {str(v)}"
-                logger.warning(msg)
-        else:
-            self._prepare(documents)
-            self._assert_is_prepared(
-                msg="_prepare() was called, but the taskmodule is not prepared."
-            )
-            self._post_prepare()
-        return None
-
     def _config(self) -> Dict[str, Any]:
         config = super()._config() or {}
         config[self.config_type_key] = TaskModule.name_for_object_class(self)
@@ -220,8 +163,8 @@ def _from_pretrained(
         *args,
         **kwargs,
     ):
-        taskmodule = super()._from_pretrained(*args, **kwargs)
-        taskmodule._post_prepare()
+        taskmodule: TaskModule = super()._from_pretrained(*args, **kwargs)
+        taskmodule.post_prepare()
         return taskmodule

     def batch_encode(
def batch_encode(
@@ -290,6 +233,8 @@ def encode(
         TaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
         IterableTaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
     ]:
+        self.assert_is_prepared()
+
         # backwards compatibility
         if as_task_encoding_sequence is None:
             as_task_encoding_sequence = not encode_target
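Because encode() now starts with assert_is_prepared(), an unprepared module fails fast with the improved error message instead of crashing later on a missing attribute. A small sketch of that failure mode, exercising the mixin directly; the NeedsLabels class and its label_to_id attribute are hypothetical:

    from typing import List

    from pytorch_ie.core.module_mixins import PreparableMixin


    class NeedsLabels(PreparableMixin):
        # hypothetical module: requires one attribute that is never set here
        PREPARED_ATTRIBUTES: List[str] = ["label_to_id"]


    module = NeedsLabels()
    print(module.is_prepared)  # False
    try:
        module.assert_is_prepared(msg="The module is not prepared.")
    except Exception as error:
        print(error)
        # The module is not prepared. Required attributes that are not set: ['label_to_id']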
