Skip to content

Commit

Permalink
Refactoring to use MetaLayer class (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
DiogenesAnalytics committed Dec 17, 2023
1 parent 1642382 commit 15c43bc
Showing 1 changed file with 32 additions and 15 deletions.
47 changes: 32 additions & 15 deletions src/autoencoder/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,36 @@
from typing import Any
from typing import Dict
from typing import Generator
from typing import Iterator
from typing import Optional
from typing import Tuple

import keras
from keras.layers import Layer
from typing_extensions import TypeAlias


# custom types
DefaultParams: TypeAlias = Dict[str, Tuple[Layer, Dict[str, Any]]]
@dataclass
class MetaLayer:
"""Container for a keras Layer and its kwargs."""

layer: Layer
params: Dict[str, Any]

def __iter__(self) -> Iterator[Any]:
"""Define iteration behavior."""
return iter((self.layer, self.params))


class Encode(MetaLayer):
"""Designate a meta layer as an ecoding layer."""

class BaseLayerParams(ABC):
"""Autoencoder layers hyperparameters configuration base class."""

class Decode(MetaLayer):
"""Designate a meta layer as a decoding layer."""


class BaseModelParams(ABC):
"""Autoencoder model layer hyperparameters configuration base class."""

@abstractmethod
def __init__(self, **kwargs: Dict[str, Any]) -> None:
Expand All @@ -28,7 +44,7 @@ def __init__(self, **kwargs: Dict[str, Any]) -> None:

@property
@abstractmethod
def default_parameters(self) -> DefaultParams:
def default_parameters(self) -> Dict[str, MetaLayer]:
"""Defines the required default layer parameters attribute."""
# NOTE: this dictionary sets layer order used to build the keras.Model
pass
Expand All @@ -53,24 +69,25 @@ def _filter_layer_attrs(self) -> Generator[Tuple[str, Dict[str, Any]], None, Non
# finally get value of constructor args
yield layer_id, self.__dict__[layer_id]

def _update_layer_params(
self,
) -> Generator[Tuple[Layer, Dict[str, Any]], None, None]:
def _update_layer_params(self) -> Generator[MetaLayer, None, None]:
"""Update default layer parameters values."""
# get layer instance attrs and their values
for attr, value in self._filter_layer_attrs():
# unpack default parameters
layer, params = self.default_parameters[attr]

# get copy of default params
default_params_copy = params.copy()

# check if none
if value is not None:
# merge instance onto default
params |= value
# update default with any user supplied kwargs
default_params_copy |= value

# generate
yield layer, params
yield MetaLayer(layer, default_params_copy)

def _build_instance_params(self) -> Tuple[Tuple[Layer, Dict[str, Any]], ...]:
def _build_instance_params(self) -> Tuple[MetaLayer, ...]:
"""Create mutable sequence of layer params for instance."""
return tuple(self._update_layer_params())

Expand All @@ -79,7 +96,7 @@ def _build_instance_params(self) -> Tuple[Tuple[Layer, Dict[str, Any]], ...]:
class BaseAutoencoder(ABC):
"""Autoencoder base class."""

model_config: Optional[BaseLayerParams] = None
model_config: Optional[BaseModelParams] = None

def __post_init__(self) -> None:
"""Setup autoencoder model."""
Expand All @@ -93,7 +110,7 @@ def __post_init__(self) -> None:

@property
@abstractmethod
def _default_config(self) -> BaseLayerParams:
def _default_config(self) -> BaseModelParams:
"""Defines the default layer parameters attribute."""
pass

Expand Down

0 comments on commit 15c43bc

Please sign in to comment.