Move the prepare / post-prepare logic of TaskModule into a reusable PreparableMixin.

hf_hub_mixin.py: remove the no-op post_prepare() stub and the call to it at the
end of _from_pretrained().

module_mixins.py: add PreparableMixin with PREPARED_ATTRIBUTES, is_prepared,
prepared_attributes, _prepare(), _post_prepare(), assert_is_prepared(),
post_prepare() and prepare(). Compared to the code removed from TaskModule,
prepare() and _prepare() accept arbitrary *args / **kwargs, and the warning and
error messages name the concrete class (or "module") instead of "taskmodule".

taskmodule.py: TaskModule inherits from PreparableMixin and drops its own copies
of these members. The private _assert_is_prepared() is replaced by the public
assert_is_prepared(). _from_pretrained() now calls post_prepare(), which checks
the prepared attributes before running _post_prepare(), and encode() asserts
that the module is prepared before encoding.

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Line counts match every hunk header, but the
indentation of context lines and the position of blank context lines were
inferred from Python syntax and should be confirmed against the target files
before applying.

diff --git a/src/pytorch_ie/core/hf_hub_mixin.py b/src/pytorch_ie/core/hf_hub_mixin.py
index bc49d2cb..a737a9c5 100644
--- a/src/pytorch_ie/core/hf_hub_mixin.py
+++ b/src/pytorch_ie/core/hf_hub_mixin.py
@@ -407,9 +407,6 @@ def __init__(self, *args, **kwargs):
     def _save_pretrained(self, save_directory):
         return None
 
-    def post_prepare(self) -> None:
-        pass
-
     @classmethod
     def _from_pretrained(
         cls,
@@ -433,6 +430,4 @@ def _from_pretrained(
             config.pop(cls.config_type_key)
 
         taskmodule = cls(**config)
-        taskmodule.post_prepare()
-
         return taskmodule
diff --git a/src/pytorch_ie/core/module_mixins.py b/src/pytorch_ie/core/module_mixins.py
index 7811947a..01c4bc3f 100644
--- a/src/pytorch_ie/core/module_mixins.py
+++ b/src/pytorch_ie/core/module_mixins.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Optional, Type
+from typing import List, Optional, Type
 
 from pytorch_ie.core.document import Document
 
@@ -36,3 +36,64 @@ def convert_dataset(self, dataset: "pie_datasets.DatasetDict") -> "pie_datasets.
             )
 
         return dataset
+
+
+class PreparableMixin:
+    # list of attribute names that need to be set by _prepare()
+    PREPARED_ATTRIBUTES: List[str] = []
+
+    @property
+    def is_prepared(self):
+        """
+        Returns True, iff all attributes listed in PREPARED_ATTRIBUTES are set.
+        Note: Attributes set to None are not considered to be prepared!
+        """
+        return all(
+            getattr(self, attribute, None) is not None for attribute in self.PREPARED_ATTRIBUTES
+        )
+
+    @property
+    def prepared_attributes(self):
+        if not self.is_prepared:
+            raise Exception("The module is not prepared.")
+        return {param: getattr(self, param) for param in self.PREPARED_ATTRIBUTES}
+
+    def _prepare(self, *args, **kwargs):
+        """
+        This method needs to set all attributes listed in PREPARED_ATTRIBUTES.
+        """
+        pass
+
+    def _post_prepare(self):
+        """
+        Any code to do further one-time setup, but that requires the prepared attributes.
+        """
+        pass
+
+    def assert_is_prepared(self, msg: Optional[str] = None):
+        if not self.is_prepared:
+            attributes_not_prepared = [
+                param for param in self.PREPARED_ATTRIBUTES if getattr(self, param, None) is None
+            ]
+            raise Exception(
+                f"{msg or ''} Required attributes that are not set: {str(attributes_not_prepared)}"
+            )
+
+    def post_prepare(self):
+        self.assert_is_prepared()
+        self._post_prepare()
+
+    def prepare(self, *args, **kwargs) -> None:
+        if self.is_prepared:
+            if len(self.PREPARED_ATTRIBUTES) > 0:
+                msg = f"The {self.__class__.__name__} is already prepared, do not prepare again."
+                for k, v in self.prepared_attributes.items():
+                    msg += f"\n{k} = {str(v)}"
+                logger.warning(msg)
+        else:
+            self._prepare(*args, **kwargs)
+            self.assert_is_prepared(
+                msg=f"_prepare() was called, but the {self.__class__.__name__} is not prepared."
+            )
+        self._post_prepare()
+        return None
diff --git a/src/pytorch_ie/core/taskmodule.py b/src/pytorch_ie/core/taskmodule.py
index 64314346..b09d19d5 100644
--- a/src/pytorch_ie/core/taskmodule.py
+++ b/src/pytorch_ie/core/taskmodule.py
@@ -10,7 +10,7 @@
 
 from pytorch_ie.core.document import Annotation, Document
 from pytorch_ie.core.hf_hub_mixin import PieTaskModuleHFHubMixin
-from pytorch_ie.core.module_mixins import WithDocumentTypeMixin
+from pytorch_ie.core.module_mixins import PreparableMixin, WithDocumentTypeMixin
 from pytorch_ie.core.registrable import Registrable
 
 """
@@ -134,6 +134,7 @@ class TaskModule(
     HyperparametersMixin,
     Registrable,
     WithDocumentTypeMixin,
+    PreparableMixin,
     Generic[
         DocumentType,
         InputEncoding,
@@ -143,68 +144,10 @@ class TaskModule(
         TaskOutput,
     ],
 ):
-    PREPARED_ATTRIBUTES: List[str] = []
-
     def __init__(self, encode_document_batch_size: Optional[int] = None, **kwargs):
         super().__init__(**kwargs)
         self.encode_document_batch_size = encode_document_batch_size
 
-    @property
-    def is_prepared(self):
-        """
-        Returns True, iff all attributes listed in PREPARED_ATTRIBUTES are set.
-        Note: Attributes set to None are not considered to be prepared!
-        """
-        return all(
-            getattr(self, attribute, None) is not None for attribute in self.PREPARED_ATTRIBUTES
-        )
-
-    @property
-    def prepared_attributes(self):
-        if not self.is_prepared:
-            raise Exception("The taskmodule is not prepared.")
-        return {param: getattr(self, param) for param in self.PREPARED_ATTRIBUTES}
-
-    def _prepare(self, documents: Sequence[DocumentType]):
-        """
-        This method needs to set all attributes listed in PREPARED_ATTRIBUTES.
-        """
-        pass
-
-    def _post_prepare(self):
-        """
-        Any code to do further one-time setup, but that requires the prepared attributes.
-        """
-        pass
-
-    def _assert_is_prepared(self, msg: Optional[str] = None):
-        if not self.is_prepared:
-            attributes_not_prepared = [
-                param for param in self.PREPARED_ATTRIBUTES if getattr(self, param, None) is None
-            ]
-            raise Exception(
-                f"{msg or ''} Required attributes that are not set: {str(attributes_not_prepared)}"
-            )
-
-    def post_prepare(self):
-        self._assert_is_prepared()
-        self._post_prepare()
-
-    def prepare(self, documents: Sequence[DocumentType]) -> None:
-        if self.is_prepared:
-            if len(self.PREPARED_ATTRIBUTES) > 0:
-                msg = "The taskmodule is already prepared, do not prepare again."
-                for k, v in self.prepared_attributes.items():
-                    msg += f"\n{k} = {str(v)}"
-                logger.warning(msg)
-        else:
-            self._prepare(documents)
-            self._assert_is_prepared(
-                msg="_prepare() was called, but the taskmodule is not prepared."
-            )
-        self._post_prepare()
-        return None
-
     def _config(self) -> Dict[str, Any]:
         config = super()._config() or {}
         config[self.config_type_key] = TaskModule.name_for_object_class(self)
@@ -220,8 +163,8 @@ def _from_pretrained(
         *args,
         **kwargs,
     ):
-        taskmodule = super()._from_pretrained(*args, **kwargs)
-        taskmodule._post_prepare()
+        taskmodule: TaskModule = super()._from_pretrained(*args, **kwargs)
+        taskmodule.post_prepare()
         return taskmodule
 
     def batch_encode(
@@ -290,6 +233,8 @@ def encode(
         TaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
         IterableTaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
     ]:
+        self.assert_is_prepared()
+
         # backwards compatibility
         if as_task_encoding_sequence is None:
             as_task_encoding_sequence = not encode_target