Reorganized serialization and added basic NN features
Commit 94ee121, 1 parent (ea32d38)
Showing 7 changed files with 433 additions and 144 deletions.
The first visible diff updates the package re-exports (evidently the top-level `__init__.py`, given the relative imports):

```diff
@@ -1 +1,4 @@
 from . import core as core
+from .core import Jaximal as Jaximal
+from .core import Static as Static
+from .nn import JaximalModule as JaximalModule
```
The second visible diff adds a new 125-line module (the `nn` module re-exported above):

@@ -0,0 +1,125 @@

```python
from enum import Enum, auto
from typing import Any, Callable

import jax

from jax import numpy as np
from jaxtyping import Array, Float, PRNGKeyArray, PyTree

from jaximal.core import Jaximal, Static
```
```python
class WeightInitialization(Enum):
    Zero = auto()
    RandomUniform = auto()
    RandomNormal = auto()
    GlorotUniform = auto()
    GlorotNormal = auto()
    HeUniform = auto()
    HeNormal = auto()

    def init(
        self,
        key: PRNGKeyArray,
        shape: tuple[int, ...],
        fan_in: int,
        fan_out: int,
        dtype: np.dtype = np.float_,
    ) -> Float[Array, '*']:
        match self:
            case WeightInitialization.Zero:
                return np.zeros(shape, dtype=dtype)
            case WeightInitialization.RandomUniform:
                return jax.random.uniform(
                    key,
                    shape,
                    dtype=dtype,
                    minval=-1.0,
                    maxval=1.0,
                )
            case WeightInitialization.RandomNormal:
                return jax.random.normal(key, shape, dtype=dtype)
            case WeightInitialization.GlorotUniform:
                # Glorot/Xavier uniform: bound = sqrt(6 / (fan_in + fan_out))
                scaling = (6 / (fan_in + fan_out)) ** 0.5
                return jax.random.uniform(
                    key,
                    shape,
                    dtype=dtype,
                    minval=-scaling,
                    maxval=scaling,
                )
            case WeightInitialization.GlorotNormal:
                # Glorot/Xavier normal: stddev = sqrt(2 / (fan_in + fan_out))
                scaling = (2 / (fan_in + fan_out)) ** 0.5
                return jax.random.normal(key, shape, dtype=dtype) * scaling
            case WeightInitialization.HeUniform:
                # He/Kaiming uniform: bound = sqrt(6 / fan_in)
                scaling = (6 / fan_in) ** 0.5
                return jax.random.uniform(
                    key,
                    shape,
                    dtype=dtype,
                    minval=-scaling,
                    maxval=scaling,
                )
            case WeightInitialization.HeNormal:
                # He/Kaiming normal: stddev = sqrt(2 / fan_in)
                scaling = (2 / fan_in) ** 0.5
                return jax.random.normal(key, shape, dtype=dtype) * scaling
```
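For orientation, a minimal usage sketch of the initializer enum. This assumes the module is importable as `jaximal.nn` (as the `__init__.py` hunk suggests); the shapes are illustrative:

```python
import jax

from jaximal.nn import WeightInitialization

key = jax.random.PRNGKey(0)
# Draw a (784, 256) Glorot-uniform weight matrix; fan_in/fan_out set the bound.
w = WeightInitialization.GlorotUniform.init(key, (784, 256), fan_in=784, fan_out=256)
assert w.shape == (784, 256)
```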
```python
class JaximalModule(Jaximal):
    @staticmethod
    def init_state(key: PRNGKeyArray, *args: Any, **kwargs: Any) -> 'JaximalModule':
        raise NotImplementedError('This method must be implemented on each subclass')

    def __call__(self, data: PyTree, **kwargs: Static[Any]) -> PyTree:
        raise NotImplementedError('This method must be implemented on each subclass')
```
```python
class Linear(JaximalModule):
    in_dim: Static[int]
    out_dim: Static[int]

    weights: Float[Array, 'in_dim out_dim']
    biases: Float[Array, 'out_dim']

    @staticmethod
    def init_state(
        key: PRNGKeyArray,
        in_dim: int,
        out_dim: int,
        weight_initialization: WeightInitialization = WeightInitialization.GlorotUniform,
        bias_initialization: WeightInitialization = WeightInitialization.Zero,
    ) -> 'Linear':
        w_key, b_key = jax.random.split(key)
        weights = weight_initialization.init(w_key, (in_dim, out_dim), in_dim, out_dim)
        # Biases get their own initializer (Zero by default).
        biases = bias_initialization.init(b_key, (out_dim,), 1, out_dim)

        return Linear(in_dim, out_dim, weights, biases)

    def __call__(self, data: PyTree, **kwargs: Static[Any]) -> PyTree:
        return data @ self.weights + self.biases
```
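A quick usage sketch for `Linear`, under the same `jaximal.nn` import assumption; the batch and layer sizes are illustrative:

```python
import jax
from jax import numpy as jnp

from jaximal.nn import Linear

key = jax.random.PRNGKey(42)
layer = Linear.init_state(key, in_dim=3, out_dim=5)

x = jnp.ones((8, 3))  # batch of 8 three-dimensional inputs
y = layer(x)          # computes data @ weights + biases
assert y.shape == (8, 5)
```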
```python
class Sequential(JaximalModule):
    modules: list[JaximalModule]

    @staticmethod
    def init_state(
        key: PRNGKeyArray, partials: list[Callable[[PRNGKeyArray], JaximalModule]]
    ) -> 'Sequential':
        keys = jax.random.split(key, len(partials))

        modules = list(partial(key) for key, partial in zip(keys, partials))
        return Sequential(modules)

    def __call__(self, data: PyTree, *args: dict[str, Any], **_: Any) -> PyTree:
        assert len(args) == len(self.modules), (
            'Expected `self.modules` and `args` to have the same length '
            f'but got {len(self.modules)} and {len(args)}, respectively.'
        )
        # One kwargs dict per module, applied in order.
        for kwargs, module in zip(args, self.modules):
            data = module(data, **kwargs)

        return data
```
```python
__all__ = ['JaximalModule', 'WeightInitialization', 'Linear', 'Sequential']
```
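Putting it together, a hedged end-to-end sketch: `Sequential.init_state` takes a list of single-argument constructors (here built with `functools.partial` around `Linear.init_state`), and `__call__` expects one positional kwargs dict per module; the network shape below is illustrative:

```python
from functools import partial

import jax
from jax import numpy as jnp

from jaximal.nn import Linear, Sequential

key = jax.random.PRNGKey(0)
model = Sequential.init_state(
    key,
    [
        partial(Linear.init_state, in_dim=4, out_dim=16),
        partial(Linear.init_state, in_dim=16, out_dim=2),
    ],
)

x = jnp.ones((32, 4))
y = model(x, {}, {})  # one (empty) kwargs dict per module
assert y.shape == (32, 2)
```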