diff --git a/src/jaximal/nn.py b/src/jaximal/nn.py
index 6c305f9..bbefecb 100644
--- a/src/jaximal/nn.py
+++ b/src/jaximal/nn.py
@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
 from enum import Enum, auto
 from typing import Any, Callable
 
@@ -65,13 +66,24 @@ def init(
     return jax.random.normal(key, shape, dtype=dtype) * scaling
 
 
-class JaximalModule(Jaximal):
+class JaximalModule(Jaximal, ABC):
+    @abstractmethod
     @staticmethod
-    def init_state(key: PRNGKeyArray, *args: Any, **kwargs: Any) -> 'JaximalModule':
-        raise NotImplementedError('This method must be implemented on each subclass')
+    def init_state(key: PRNGKeyArray, *args: Any, **kwargs: Any) -> 'JaximalModule': ...
 
-    def __call__(self, data: PyTree, **kwargs: Static[Any]) -> PyTree:
-        raise NotImplementedError('This method must be implemented on each subclass')
+    @abstractmethod
+    def __call__(self, data: PyTree) -> PyTree: ...
+
+
+class Activation(JaximalModule):
+    func: Static[Callable[[Array], Array]]
+
+    @staticmethod
+    def init_state(key: PRNGKeyArray, func: Callable[[Array], Array]) -> 'Activation':
+        return Activation(func)
+
+    def __call__(self, data: PyTree) -> PyTree:
+        return jax.tree.map(self.func, data)
 
 
 class Linear(JaximalModule):
@@ -95,7 +107,7 @@ def init_state(
 
         return Linear(in_dim, out_dim, weights, biases)
 
-    def __call__(self, data: PyTree, **kwargs: Static[Any]) -> PyTree:
+    def __call__(self, data: PyTree) -> PyTree:
         return data @ self.weights + self.biases
 
 
@@ -111,7 +123,7 @@ def init_state(
         modules = list(partial(key) for key, partial in zip(keys, partials))
         return Sequential(modules)
 
-    def __call__(self, data: PyTree, *args: dict[str, Any], **_: Any) -> PyTree:
+    def __call__(self, data: PyTree, *args: dict[str, Any]) -> PyTree:
         assert len(args) == len(self.modules), (
             'Expected `self.modules` and `args` to have the same length '
             f'but got {len(self.modules)} and {len(args)}, respectively.'
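
A minimal usage sketch of the new `Activation` module introduced by this patch (not part of the diff itself; it assumes the package is importable as `jaximal.nn` and uses `jax.nn.relu` purely as an example function):

```python
# Illustrative sketch only -- assumes `jaximal.nn` is the import path for this module.
import jax
import jax.numpy as jnp

from jaximal.nn import Activation

# The key is unused by Activation.init_state, but the shared init_state interface requires it.
key = jax.random.PRNGKey(0)

# Wrap a plain array function; `func` is a Static field, so it is treated as
# metadata rather than a learnable array leaf of the module pytree.
act = Activation.init_state(key, jax.nn.relu)

# __call__ maps the wrapped function over every leaf of the input pytree.
out = act({'hidden': jnp.array([-1.0, 0.5, 2.0])})
# out == {'hidden': Array([0. , 0.5, 2. ], dtype=float32)}
```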