From 94ee121c7802226648b563269fe6ae3a14c037bc Mon Sep 17 00:00:00 2001 From: Arvin Kushwaha Date: Fri, 26 Jul 2024 15:34:48 +0200 Subject: [PATCH] Reorganized serialization and added basic NN features --- src/jaximal/__init__.py | 3 + src/jaximal/core.py | 140 +--------------------- src/jaximal/io.py | 23 +++- src/jaximal/nn.py | 125 ++++++++++++++++++++ src/jaximal/serialization.py | 220 +++++++++++++++++++++++++++++++++++ tests/test_core.py | 3 +- tests/test_serialization.py | 63 ++++++++++ 7 files changed, 433 insertions(+), 144 deletions(-) create mode 100644 src/jaximal/nn.py create mode 100644 src/jaximal/serialization.py create mode 100644 tests/test_serialization.py diff --git a/src/jaximal/__init__.py b/src/jaximal/__init__.py index c0263c8..5c327c8 100644 --- a/src/jaximal/__init__.py +++ b/src/jaximal/__init__.py @@ -1 +1,4 @@ from . import core as core +from .core import Jaximal as Jaximal +from .core import Static as Static +from .nn import JaximalModule as JaximalModule diff --git a/src/jaximal/core.py b/src/jaximal/core.py index 520aaa2..70f5482 100644 --- a/src/jaximal/core.py +++ b/src/jaximal/core.py @@ -1,16 +1,8 @@ -import typing - from dataclasses import dataclass, fields -from itertools import chain -from json import dumps, loads from typing import ( TYPE_CHECKING, Annotated, - Any, - Iterable, - Mapping, Self, - Sequence, cast, dataclass_transform, get_origin, @@ -121,134 +113,4 @@ def cls_eq(self: Self, other: object) -> bool: jax.tree_util.register_dataclass(cls, data_fields, meta_fields) -def dictify( - x: Any, - prefix: str = '', - typ: type | None = None, -) -> tuple[dict[str, Array], dict[str, str]]: - """ - Given an object, a prefix, and optionally a type for the object, attempt to - deconstruct the object into a `dict[str, jax.Array]` and a `dict[str, str]` - where all keys have the given prefix. - """ - - typ = type(x) if typ is None else typ - - data: dict[str, Array] = {} - metadata: dict[str, str] = {} - - if get_origin(typ) == Static: - metadata |= {prefix.removesuffix('::'): dumps(x)} - - elif isinstance(x, Array): - data |= {prefix.removesuffix('::'): x} - - elif issubclass(typ, Jaximal): - for child_key, child_type in x.__annotations__.items(): - child_data, child_metadata = dictify( - getattr(x, child_key), prefix + child_key + '::', typ=child_type - ) - - data |= child_data - metadata |= child_metadata - - elif isinstance(x, Mapping): - for child_key, child_elem in x.items(): - child_data, child_metadata = dictify( - child_elem, prefix + str(child_key) + '::' - ) - - data |= child_data - metadata |= child_metadata - - elif isinstance(x, Sequence): - for child_idx, child_elem in enumerate(x): - child_data, child_metadata = dictify( - child_elem, prefix + str(child_idx) + '::' - ) - - data |= child_data - metadata |= child_metadata - - else: - raise TypeError( - f'Unexpected type {typ} and prefix {prefix} recieved by `dictify`.' - ) - - return data, metadata - - -def dedictify[T]( - typ: type[T], - data: dict[str, Array], - metadata: dict[str, str], - prefix: str = '', -) -> T: - """ - Given a type, a `dict[str, jax.Array]`, a `dict[str, str]`, and a prefix - for the dictionary keys, attempts to recreate an instance of the given - type. - """ - - base_typ = get_origin(typ) - if base_typ is None: - base_typ = typ - - if get_origin(typ) == Static: - return loads(metadata[prefix.removesuffix('::')]) - - elif typ == Array or issubclass(base_typ, AbstractArray): - return cast(T, data[prefix.removesuffix('::')]) - - elif issubclass(base_typ, Jaximal): - children = {} - for child_key, child_type in typ.__annotations__.items(): - children[child_key] = dedictify( - child_type, data, metadata, prefix + child_key + '::' - ) - - return typ(**children) - - elif issubclass(base_typ, Mapping): - children = {} - key_type, child_type = typing.get_args(typ) - - for keys in filter(lambda x: x.startswith(prefix), data): - keys = keys[len(prefix) :] - child_key = key_type(keys.split('::', 1)[0]) - child_prefix = prefix + str(child_key) + '::' - - if child_key in children: - continue - - children[child_key] = dedictify(child_type, data, metadata, child_prefix) - - return cast(T, children) - - elif issubclass(base_typ, list): - children = [] - (child_type,) = typing.get_args(typ) - - child_idx = 0 - while True: - child_prefix = prefix + str(child_idx) + '::' - try: - next( - filter( - lambda x: x.startswith(child_prefix), - cast(Iterable[str], chain(data.keys(), metadata.keys())), - ) - ) - except StopIteration: - break - children.append(dedictify(child_type, data, metadata, child_prefix)) - child_idx += 1 - - return cast(T, children) - - raise TypeError( - f'Unexpected type {typ} and prefix {prefix} recieved by `dedictify`.' - ) - - -__all__ = ['Jaximal', 'Static', 'dictify', 'dedictify'] +__all__ = ['Jaximal', 'Static'] diff --git a/src/jaximal/io.py b/src/jaximal/io.py index 9cbc224..1473d67 100644 --- a/src/jaximal/io.py +++ b/src/jaximal/io.py @@ -5,6 +5,8 @@ from jaxtyping import Array +from jaximal.serialization import JSONEncoder, json_object_hook + def save_file( filename: str, @@ -16,7 +18,16 @@ def save_file( str]` called `meta`, uses `safetensors.flax.save_file` to save both to the given `filename`. """ - safflax.save_file(data, filename, metadata) + + if data: + safflax.save_file(data, filename, metadata) + else: + with open(filename, 'wb') as f: + metadata_ser = json.dumps( + {'__metadata__': metadata, '__no_data__': True}, cls=JSONEncoder + ) + f.write(struct.pack(' tuple[dict[str, Array], dict[str, str]]: @@ -25,12 +36,14 @@ def load_file(filename: str) -> tuple[dict[str, Array], dict[str, str]]: from the given `filename` and then manually retrieves the `dict[str, str]` metadata from the file. """ - data = safflax.load_file(filename) with open(filename, 'rb') as f: header_len = struct.unpack(' tuple[dict[str, Array], dict[str, str]]: data = safflax.load(raw_data) header_len = struct.unpack(' Float[Array, '*']: + match self: + case WeightInitialization.Zero: + return np.zeros(shape, dtype=dtype) + case WeightInitialization.RandomUniform: + return jax.random.uniform( + key, + shape, + dtype=dtype, + minval=-1.0, + maxval=1.0, + ) + case WeightInitialization.RandomNormal: + return jax.random.normal(key, shape, dtype=dtype) + case WeightInitialization.GlorotUniform: + scaling = (6 / (fan_in + fan_out)) ** 0.5 + return jax.random.uniform( + key, + shape, + dtype=dtype, + minval=-scaling, + maxval=scaling, + ) + case WeightInitialization.GlorotNormal: + scaling = (2 / (fan_in + fan_out)) ** 0.5 + return jax.random.normal(key, shape, dtype=dtype) * scaling + case WeightInitialization.HeUniform: + scaling = (6 / fan_in) ** 0.5 + return jax.random.uniform( + key, + shape, + dtype=dtype, + minval=-scaling, + maxval=scaling, + ) + case WeightInitialization.HeNormal: + scaling = (2 / fan_in) ** 0.5 + return jax.random.normal(key, shape, dtype=dtype) * scaling + + +class JaximalModule(Jaximal): + @staticmethod + def init_state(key: PRNGKeyArray, *args: Any, **kwargs: Any) -> 'JaximalModule': + raise NotImplementedError('This method must be implemented on each subclass') + + def __call__(self, data: PyTree, **kwargs: Static[Any]) -> PyTree: + raise NotImplementedError('This method must be implemented on each subclass') + + +class Linear(JaximalModule): + in_dim: Static[int] + out_dim: Static[int] + + weights: Float[Array, 'in_dim out_dim'] + biases: Float[Array, 'out_dim'] + + @staticmethod + def init_state( + key: PRNGKeyArray, + in_dim: int, + out_dim: int, + weight_initialization: WeightInitialization = WeightInitialization.GlorotUniform, + bias_initialization: WeightInitialization = WeightInitialization.Zero, + ) -> 'Linear': + w_key, b_key = jax.random.split(key) + weights = weight_initialization.init(w_key, (in_dim, out_dim), in_dim, out_dim) + biases = weight_initialization.init(b_key, (out_dim,), 1, out_dim) + + return Linear(in_dim, out_dim, weights, biases) + + def __call__(self, data: PyTree, **kwargs: Static[Any]) -> PyTree: + return data @ self.weights + self.biases + + +class Sequential(JaximalModule): + modules: list[JaximalModule] + + @staticmethod + def init_state( + key: PRNGKeyArray, partials: list[Callable[[PRNGKeyArray], JaximalModule]] + ) -> 'Sequential': + keys = jax.random.split(key, len(partials)) + + modules = list(partial(key) for key, partial in zip(keys, partials)) + return Sequential(modules) + + def __call__(self, data: PyTree, *args: dict[str, Any], **_: Any) -> PyTree: + assert len(args) == len(self.modules), ( + 'Expected `self.modules` and `args` to have the same length ' + f'but got {len(self.modules)} and {len(args)}, respectively.' + ) + for kwargs, modules in zip(args, self.modules): + data = modules(data, **kwargs) + + return data + + +__all__ = ['JaximalModule', 'WeightInitialization', 'Linear', 'Sequential'] diff --git a/src/jaximal/serialization.py b/src/jaximal/serialization.py new file mode 100644 index 0000000..0cbf03c --- /dev/null +++ b/src/jaximal/serialization.py @@ -0,0 +1,220 @@ +import base64 +import json +import pickle + +from itertools import chain +from typing import ( + Any, + Callable, + Iterable, + Mapping, + Sequence, + cast, + get_args, + get_origin, +) + +import jax + +from jaxtyping import AbstractArray, Array + +from jaximal.core import Jaximal, Static + + +class FnRegistry: + functions = {} + inv_functions = {} + + def add(self, function: Callable[..., Any], name: str): + if function in self.functions or name in self.inv_functions: + raise ValueError( + f'Function {function} with name {name} already in registry.' + ) + + self.functions[function] = name + self.inv_functions[name] = function + + def lookup_name(self, name: str) -> Callable[..., Any] | None: + return self.inv_functions.get(name) + + def lookup_function(self, function: Callable[..., Any]) -> str | None: + return self.functions.get(function) + + +global_fn_registry = FnRegistry() + + +global_fn_registry.add(jax.numpy.sin, 'jax.numpy.sin') +global_fn_registry.add(jax.numpy.cos, 'jax.numpy.cos') +global_fn_registry.add(jax.numpy.tan, 'jax.numpy.tan') +global_fn_registry.add(jax.numpy.log, 'jax.numpy.log') +global_fn_registry.add(jax.numpy.exp, 'jax.numpy.exp') +global_fn_registry.add(jax.numpy.tanh, 'jax.numpy.tanh') + + +class JSONEncoder(json.JSONEncoder): + def default(self, o: Any): + if isinstance(o, Callable): + if o_str := global_fn_registry.lookup_function(o): + return { + 'callable': True, + 'jax_map': o_str, + } + + else: + return { + 'callable': True, + 'code': base64.b64encode(pickle.dumps(o)).decode('utf-8'), + } + + +def json_object_hook(dct: Any): + if 'callable' in dct: + if 'jax_map' in dct: + return global_fn_registry.lookup_name(dct['jax_map']) + elif 'code' in dct: + return pickle.loads(base64.b64decode(dct['code'])) + else: + return dct + return dct + + +def dictify( + x: Any, + prefix: str = '', + typ: type | None = None, +) -> tuple[dict[str, Array], dict[str, str]]: + """ + Given an object, a prefix, and optionally a type for the object, attempt to + deconstruct the object into a `dict[str, jax.Array]` and a `dict[str, str]` + where all keys have the given prefix. + """ + + typ = type(x) if typ is None else typ + + data: dict[str, Array] = {} + metadata: dict[str, str] = {} + + if get_origin(typ) == Static: + metadata |= {prefix.removesuffix('::'): json.dumps(x, cls=JSONEncoder)} + + elif isinstance(x, Array): + data |= {prefix.removesuffix('::'): x} + + elif issubclass(typ, Jaximal): + for child_key, child_type in x.__annotations__.items(): + child_data, child_metadata = dictify( + getattr(x, child_key), prefix + child_key + '::', typ=child_type + ) + + data |= child_data + metadata |= child_metadata + + elif isinstance(x, Mapping): + for child_key, child_elem in x.items(): + child_data, child_metadata = dictify( + child_elem, prefix + str(child_key) + '::' + ) + + data |= child_data + metadata |= child_metadata + + elif isinstance(x, Sequence): + for child_idx, child_elem in enumerate(x): + child_data, child_metadata = dictify( + child_elem, prefix + str(child_idx) + '::' + ) + + data |= child_data + metadata |= child_metadata + + else: + raise TypeError( + f'Unexpected type {typ} and prefix {prefix} recieved by `dictify`.' + ) + + return data, metadata + + +def dedictify[T]( + typ: type[T], + data: dict[str, Array], + metadata: dict[str, str], + prefix: str = '', +) -> T: + """ + Given a type, a `dict[str, jax.Array]`, a `dict[str, str]`, and a prefix + for the dictionary keys, attempts to recreate an instance of the given + type. + """ + + base_typ = get_origin(typ) + if base_typ is None: + base_typ = typ + + if get_origin(typ) == Static: + return json.loads( + metadata[prefix.removesuffix('::')], object_hook=json_object_hook + ) + + elif typ == Array or issubclass(base_typ, AbstractArray): + return cast(T, data[prefix.removesuffix('::')]) + + elif issubclass(base_typ, Jaximal): + children = {} + for child_key, child_type in typ.__annotations__.items(): + children[child_key] = dedictify( + child_type, data, metadata, prefix + child_key + '::' + ) + + return typ(**children) + + elif issubclass(base_typ, Mapping): + children = {} + key_type, child_type = get_args(typ) + + for keys in filter(lambda x: x.startswith(prefix), data): + keys = keys[len(prefix) :] + child_key = key_type(keys.split('::', 1)[0]) + child_prefix = prefix + str(child_key) + '::' + + if child_key in children: + continue + + children[child_key] = dedictify(child_type, data, metadata, child_prefix) + + return cast(T, children) + + elif issubclass(base_typ, list): + children = [] + (child_type,) = get_args(typ) + + child_idx = 0 + while True: + child_prefix = prefix + str(child_idx) + '::' + try: + next( + filter( + lambda x: x.startswith(child_prefix), + cast(Iterable[str], chain(data.keys(), metadata.keys())), + ) + ) + except StopIteration: + break + children.append(dedictify(child_type, data, metadata, child_prefix)) + child_idx += 1 + + return cast(T, children) + + raise TypeError( + f'Unexpected type {typ} and prefix {prefix} recieved by `dedictify`.' + ) + + +__all__ = [ + 'dictify', + 'dedictify', + 'json_object_hook', + 'JSONEncoder', + 'global_fn_registry', +] diff --git a/tests/test_core.py b/tests/test_core.py index bd6e116..4deb91f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -6,7 +6,8 @@ import optax from jax import numpy as np -from jaximal.core import Jaximal, Static, dedictify, dictify +from jaximal.core import Jaximal, Static +from jaximal.serialization import dedictify, dictify from jaximal.io import load_file, save_file from jaxtyping import Array, Float, PRNGKeyArray, Scalar diff --git a/tests/test_serialization.py b/tests/test_serialization.py new file mode 100644 index 0000000..825d0fc --- /dev/null +++ b/tests/test_serialization.py @@ -0,0 +1,63 @@ +import os.path + +from typing import Callable + +import jax + +from jax import numpy as np +from jaximal.core import Jaximal, Static +from jaximal.io import load_file, save_file +from jaximal.serialization import dedictify, dictify +from jaxtyping import Array + + +def activation_function(x: Array, data: Array) -> Array: + return np.sin(x + data) + + +def test_serialization(tmp_path: str): + class Activation(Jaximal): + function: Static[Callable[[Array], Array]] + + def forward(self, x: Array) -> Array: + return self.function(x) + + class ActivationWithData(Jaximal): + function: Static[Callable[[Array, Array], Array]] + data: Array + + def forward(self, x: Array) -> Array: + return self.function(x, self.data) + + key = jax.random.key(0) + x = jax.random.uniform(key, (1024,)) + + activation = Activation(np.sin) + save_file(os.path.join(tmp_path, 'test_mlp.safetensors'), *dictify(activation)) + + activation_restored = dedictify( + Activation, + *load_file(os.path.join(tmp_path, 'test_mlp.safetensors')), + ) + + assert activation_restored == activation + assert np.allclose(activation.forward(x), activation_restored.forward(x)) + + activation_w_data = ActivationWithData(activation_function, np.array(1.0)) + save_file( + os.path.join(tmp_path, 'test_mlp.safetensors'), *dictify(activation_w_data) + ) + + activation_w_data_restored = dedictify( + ActivationWithData, + *load_file(os.path.join(tmp_path, 'test_mlp.safetensors')), + ) + + assert activation_w_data_restored == activation_w_data + assert np.allclose( + activation_w_data.forward(x), activation_w_data_restored.forward(x) + ) + + +if __name__ == '__main__': + test_serialization('.')