Skip to content

Commit

Permalink
Made JaximalModule an abc.ABC (abstract class) and implemented th…
Browse files Browse the repository at this point in the history
…e `Activation` module.
  • Loading branch information
ArvinSKushwaha committed Jul 26, 2024
1 parent 94ee121 commit a2b3a23
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions src/jaximal/nn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import Any, Callable

Expand Down Expand Up @@ -65,13 +66,24 @@ def init(
return jax.random.normal(key, shape, dtype=dtype) * scaling


class JaximalModule(Jaximal):
class JaximalModule(Jaximal, ABC):
@abstractmethod
@staticmethod
def init_state(key: PRNGKeyArray, *args: Any, **kwargs: Any) -> 'JaximalModule':
raise NotImplementedError('This method must be implemented on each subclass')
def init_state(key: PRNGKeyArray, *args: Any, **kwargs: Any) -> 'JaximalModule': ...

def __call__(self, data: PyTree, **kwargs: Static[Any]) -> PyTree:
raise NotImplementedError('This method must be implemented on each subclass')
@abstractmethod
def __call__(self, data: PyTree) -> PyTree: ...


class Activation(JaximalModule):
func: Static[Callable[[Array], Array]]

@staticmethod
def init_state(key: PRNGKeyArray, func: Callable[[Array], Array]) -> 'Activation':
return Activation(func)

def __call__(self, data: PyTree) -> PyTree:
return jax.tree.map(self.func, data)


class Linear(JaximalModule):
Expand All @@ -95,7 +107,7 @@ def init_state(

return Linear(in_dim, out_dim, weights, biases)

def __call__(self, data: PyTree, **kwargs: Static[Any]) -> PyTree:
def __call__(self, data: PyTree) -> PyTree:
return data @ self.weights + self.biases


Expand All @@ -111,7 +123,7 @@ def init_state(
modules = list(partial(key) for key, partial in zip(keys, partials))
return Sequential(modules)

def __call__(self, data: PyTree, *args: dict[str, Any], **_: Any) -> PyTree:
def __call__(self, data: PyTree, *args: dict[str, Any]) -> PyTree:
assert len(args) == len(self.modules), (
'Expected `self.modules` and `args` to have the same length '
f'but got {len(self.modules)} and {len(args)}, respectively.'
Expand Down

0 comments on commit a2b3a23

Please sign in to comment.