Skip to content

Commit

Permalink
Reorganized serialization and added basic NN features
Browse files Browse the repository at this point in the history
  • Loading branch information
ArvinSKushwaha committed Jul 26, 2024
1 parent ea32d38 commit 94ee121
Show file tree
Hide file tree
Showing 7 changed files with 433 additions and 144 deletions.
3 changes: 3 additions & 0 deletions src/jaximal/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from . import core as core
from .core import Jaximal as Jaximal
from .core import Static as Static
from .nn import JaximalModule as JaximalModule
140 changes: 1 addition & 139 deletions src/jaximal/core.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
import typing

from dataclasses import dataclass, fields
from itertools import chain
from json import dumps, loads
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Iterable,
Mapping,
Self,
Sequence,
cast,
dataclass_transform,
get_origin,
Expand Down Expand Up @@ -121,134 +113,4 @@ def cls_eq(self: Self, other: object) -> bool:
jax.tree_util.register_dataclass(cls, data_fields, meta_fields)


def dictify(
x: Any,
prefix: str = '',
typ: type | None = None,
) -> tuple[dict[str, Array], dict[str, str]]:
"""
Given an object, a prefix, and optionally a type for the object, attempt to
deconstruct the object into a `dict[str, jax.Array]` and a `dict[str, str]`
where all keys have the given prefix.
"""

typ = type(x) if typ is None else typ

data: dict[str, Array] = {}
metadata: dict[str, str] = {}

if get_origin(typ) == Static:
metadata |= {prefix.removesuffix('::'): dumps(x)}

elif isinstance(x, Array):
data |= {prefix.removesuffix('::'): x}

elif issubclass(typ, Jaximal):
for child_key, child_type in x.__annotations__.items():
child_data, child_metadata = dictify(
getattr(x, child_key), prefix + child_key + '::', typ=child_type
)

data |= child_data
metadata |= child_metadata

elif isinstance(x, Mapping):
for child_key, child_elem in x.items():
child_data, child_metadata = dictify(
child_elem, prefix + str(child_key) + '::'
)

data |= child_data
metadata |= child_metadata

elif isinstance(x, Sequence):
for child_idx, child_elem in enumerate(x):
child_data, child_metadata = dictify(
child_elem, prefix + str(child_idx) + '::'
)

data |= child_data
metadata |= child_metadata

else:
raise TypeError(
f'Unexpected type {typ} and prefix {prefix} recieved by `dictify`.'
)

return data, metadata


def dedictify[T](
typ: type[T],
data: dict[str, Array],
metadata: dict[str, str],
prefix: str = '',
) -> T:
"""
Given a type, a `dict[str, jax.Array]`, a `dict[str, str]`, and a prefix
for the dictionary keys, attempts to recreate an instance of the given
type.
"""

base_typ = get_origin(typ)
if base_typ is None:
base_typ = typ

if get_origin(typ) == Static:
return loads(metadata[prefix.removesuffix('::')])

elif typ == Array or issubclass(base_typ, AbstractArray):
return cast(T, data[prefix.removesuffix('::')])

elif issubclass(base_typ, Jaximal):
children = {}
for child_key, child_type in typ.__annotations__.items():
children[child_key] = dedictify(
child_type, data, metadata, prefix + child_key + '::'
)

return typ(**children)

elif issubclass(base_typ, Mapping):
children = {}
key_type, child_type = typing.get_args(typ)

for keys in filter(lambda x: x.startswith(prefix), data):
keys = keys[len(prefix) :]
child_key = key_type(keys.split('::', 1)[0])
child_prefix = prefix + str(child_key) + '::'

if child_key in children:
continue

children[child_key] = dedictify(child_type, data, metadata, child_prefix)

return cast(T, children)

elif issubclass(base_typ, list):
children = []
(child_type,) = typing.get_args(typ)

child_idx = 0
while True:
child_prefix = prefix + str(child_idx) + '::'
try:
next(
filter(
lambda x: x.startswith(child_prefix),
cast(Iterable[str], chain(data.keys(), metadata.keys())),
)
)
except StopIteration:
break
children.append(dedictify(child_type, data, metadata, child_prefix))
child_idx += 1

return cast(T, children)

raise TypeError(
f'Unexpected type {typ} and prefix {prefix} recieved by `dedictify`.'
)


__all__ = ['Jaximal', 'Static', 'dictify', 'dedictify']
__all__ = ['Jaximal', 'Static']
23 changes: 19 additions & 4 deletions src/jaximal/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from jaxtyping import Array

from jaximal.serialization import JSONEncoder, json_object_hook


def save_file(
filename: str,
Expand All @@ -16,7 +18,16 @@ def save_file(
str]` called `meta`, uses `safetensors.flax.save_file` to save both to
the given `filename`.
"""
safflax.save_file(data, filename, metadata)

if data:
safflax.save_file(data, filename, metadata)
else:
with open(filename, 'wb') as f:
metadata_ser = json.dumps(
{'__metadata__': metadata, '__no_data__': True}, cls=JSONEncoder
)
f.write(struct.pack('<Q', len(metadata_ser)))
f.write(metadata_ser.encode('utf-8'))


def load_file(filename: str) -> tuple[dict[str, Array], dict[str, str]]:
Expand All @@ -25,12 +36,14 @@ def load_file(filename: str) -> tuple[dict[str, Array], dict[str, str]]:
from the given `filename` and then manually retrieves the `dict[str, str]`
metadata from the file.
"""
data = safflax.load_file(filename)

with open(filename, 'rb') as f:
header_len = struct.unpack('<Q', f.read(8))[0]
metadata = json.loads(f.read(header_len))['__metadata__']
json_data = f.read(header_len)
deser_data = json.loads(json_data, object_hook=json_object_hook)
metadata = deser_data['__metadata__']

data = {} if '__no_data__' in deser_data else safflax.load_file(filename)
return data, metadata


Expand All @@ -52,7 +65,9 @@ def load(raw_data: bytes) -> tuple[dict[str, Array], dict[str, str]]:
data = safflax.load(raw_data)

header_len = struct.unpack('<Q', raw_data[:8])[0]
metadata = json.loads(raw_data[8 : 8 + header_len])['__metadata__']
metadata = json.loads(raw_data[8 : 8 + header_len], object_hook=json_object_hook)[
'__metadata__'
]

return data, metadata

Expand Down
125 changes: 125 additions & 0 deletions src/jaximal/nn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from enum import Enum, auto
from typing import Any, Callable

import jax

from jax import numpy as np
from jaxtyping import Array, Float, PRNGKeyArray, PyTree

from jaximal.core import Jaximal, Static


class WeightInitialization(Enum):
Zero = auto()
RandomUniform = auto()
RandomNormal = auto()
GlorotUniform = auto()
GlorotNormal = auto()
HeUniform = auto()
HeNormal = auto()

def init(
self,
key: PRNGKeyArray,
shape: tuple[int, ...],
fan_in: int,
fan_out: int,
dtype: np.dtype = np.float_,
) -> Float[Array, '*']:
match self:
case WeightInitialization.Zero:
return np.zeros(shape, dtype=dtype)
case WeightInitialization.RandomUniform:
return jax.random.uniform(
key,
shape,
dtype=dtype,
minval=-1.0,
maxval=1.0,
)
case WeightInitialization.RandomNormal:
return jax.random.normal(key, shape, dtype=dtype)
case WeightInitialization.GlorotUniform:
scaling = (6 / (fan_in + fan_out)) ** 0.5
return jax.random.uniform(
key,
shape,
dtype=dtype,
minval=-scaling,
maxval=scaling,
)
case WeightInitialization.GlorotNormal:
scaling = (2 / (fan_in + fan_out)) ** 0.5
return jax.random.normal(key, shape, dtype=dtype) * scaling
case WeightInitialization.HeUniform:
scaling = (6 / fan_in) ** 0.5
return jax.random.uniform(
key,
shape,
dtype=dtype,
minval=-scaling,
maxval=scaling,
)
case WeightInitialization.HeNormal:
scaling = (2 / fan_in) ** 0.5
return jax.random.normal(key, shape, dtype=dtype) * scaling


class JaximalModule(Jaximal):
@staticmethod
def init_state(key: PRNGKeyArray, *args: Any, **kwargs: Any) -> 'JaximalModule':
raise NotImplementedError('This method must be implemented on each subclass')

def __call__(self, data: PyTree, **kwargs: Static[Any]) -> PyTree:
raise NotImplementedError('This method must be implemented on each subclass')


class Linear(JaximalModule):
in_dim: Static[int]
out_dim: Static[int]

weights: Float[Array, 'in_dim out_dim']
biases: Float[Array, 'out_dim']

@staticmethod
def init_state(
key: PRNGKeyArray,
in_dim: int,
out_dim: int,
weight_initialization: WeightInitialization = WeightInitialization.GlorotUniform,
bias_initialization: WeightInitialization = WeightInitialization.Zero,
) -> 'Linear':
w_key, b_key = jax.random.split(key)
weights = weight_initialization.init(w_key, (in_dim, out_dim), in_dim, out_dim)
biases = weight_initialization.init(b_key, (out_dim,), 1, out_dim)

return Linear(in_dim, out_dim, weights, biases)

def __call__(self, data: PyTree, **kwargs: Static[Any]) -> PyTree:
return data @ self.weights + self.biases


class Sequential(JaximalModule):
modules: list[JaximalModule]

@staticmethod
def init_state(
key: PRNGKeyArray, partials: list[Callable[[PRNGKeyArray], JaximalModule]]
) -> 'Sequential':
keys = jax.random.split(key, len(partials))

modules = list(partial(key) for key, partial in zip(keys, partials))
return Sequential(modules)

def __call__(self, data: PyTree, *args: dict[str, Any], **_: Any) -> PyTree:
assert len(args) == len(self.modules), (
'Expected `self.modules` and `args` to have the same length '
f'but got {len(self.modules)} and {len(args)}, respectively.'
)
for kwargs, modules in zip(args, self.modules):
data = modules(data, **kwargs)

return data


__all__ = ['JaximalModule', 'WeightInitialization', 'Linear', 'Sequential']
Loading

0 comments on commit 94ee121

Please sign in to comment.