diff --git a/src/autoencoder/model/base.py b/src/autoencoder/model/base.py index d648dbe..1f74cab 100644 --- a/src/autoencoder/model/base.py +++ b/src/autoencoder/model/base.py @@ -6,20 +6,36 @@ from typing import Any from typing import Dict from typing import Generator +from typing import Iterator from typing import Optional from typing import Tuple import keras from keras.layers import Layer -from typing_extensions import TypeAlias -# custom types -DefaultParams: TypeAlias = Dict[str, Tuple[Layer, Dict[str, Any]]] +@dataclass +class MetaLayer: + """Container for a keras Layer and its kwargs.""" + + layer: Layer + params: Dict[str, Any] + + def __iter__(self) -> Iterator[Any]: + """Define iteration behavior.""" + return iter((self.layer, self.params)) + +class Encode(MetaLayer): + """Designate a meta layer as an ecoding layer.""" -class BaseLayerParams(ABC): - """Autoencoder layers hyperparameters configuration base class.""" + +class Decode(MetaLayer): + """Designate a meta layer as a decoding layer.""" + + +class BaseModelParams(ABC): + """Autoencoder model layer hyperparameters configuration base class.""" @abstractmethod def __init__(self, **kwargs: Dict[str, Any]) -> None: @@ -28,7 +44,7 @@ def __init__(self, **kwargs: Dict[str, Any]) -> None: @property @abstractmethod - def default_parameters(self) -> DefaultParams: + def default_parameters(self) -> Dict[str, MetaLayer]: """Defines the required default layer parameters attribute.""" # NOTE: this dictionary sets layer order used to build the keras.Model pass @@ -53,24 +69,25 @@ def _filter_layer_attrs(self) -> Generator[Tuple[str, Dict[str, Any]], None, Non # finally get value of constructor args yield layer_id, self.__dict__[layer_id] - def _update_layer_params( - self, - ) -> Generator[Tuple[Layer, Dict[str, Any]], None, None]: + def _update_layer_params(self) -> Generator[MetaLayer, None, None]: """Update default layer parameters values.""" # get layer instance attrs and their values for attr, value in self._filter_layer_attrs(): # unpack default parameters layer, params = self.default_parameters[attr] + # get copy of default params + default_params_copy = params.copy() + # check if none if value is not None: - # merge instance onto default - params |= value + # update default with any user supplied kwargs + default_params_copy |= value # generate - yield layer, params + yield MetaLayer(layer, default_params_copy) - def _build_instance_params(self) -> Tuple[Tuple[Layer, Dict[str, Any]], ...]: + def _build_instance_params(self) -> Tuple[MetaLayer, ...]: """Create mutable sequence of layer params for instance.""" return tuple(self._update_layer_params()) @@ -79,7 +96,7 @@ def _build_instance_params(self) -> Tuple[Tuple[Layer, Dict[str, Any]], ...]: class BaseAutoencoder(ABC): """Autoencoder base class.""" - model_config: Optional[BaseLayerParams] = None + model_config: Optional[BaseModelParams] = None def __post_init__(self) -> None: """Setup autoencoder model.""" @@ -93,7 +110,7 @@ def __post_init__(self) -> None: @property @abstractmethod - def _default_config(self) -> BaseLayerParams: + def _default_config(self) -> BaseModelParams: """Defines the default layer parameters attribute.""" pass