The pun is very intended.
pip install "git+https://github.com/jhn-nt/data-snax.git"
For the nightly install:
pip install "git+https://github.com/jhn-nt/data-snax.git@experimental"
`snax` (pronounced "snacks") is a lightweight Python library written in `jax` to accommodate the data ingestion needs of `jax`-based models. Developed with `jax` in mind, it offers a simple interface to quickly plug, transform, combine, and batch data. In fact, `snax` leverages `jax` primitives and thus inherits their lightning speed.
More than a necessity, `snax` aims at improving general quality of life when working with data loaders in `jax`. Citing the `jax` docs:
"There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything."
However, naively using `tf.data.Dataset` in conjunction with `jax`, as suggested by the `jax` devs, may raise some nuisances, particularly:
- Memory conflicts: both `tensorflow` and `jax` try to allocate all available VRAM, which may cause one or the other to not work properly.
- Type casting: all tensors need to be cast to the appropriate equivalent type in `jax`.
Again, these issues can be avoided with some research on Stack Overflow; moreover, `tf.data.Dataset` was developed by the `tensorflow` team and has been tested thoroughly in many scenarios.
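For reference, the usual workaround is to hide the GPU from `tensorflow` so it does not reserve VRAM, and to cast each batch to `jax` arrays by hand. A minimal sketch of that pattern (not part of `snax`; `X` and `y` stand for some random dataset as in the examples below):
import tensorflow as tf
import jax.numpy as jnp
# keep tensorflow off the GPU so jax can allocate the VRAM it needs
tf.config.set_visible_devices([], "GPU")
tf_dataset=tf.data.Dataset.from_tensor_slices((X,y)).batch(32)
for x_batch,y_batch in tf_dataset.as_numpy_iterator():
    # cast every batch to the appropriate jax types
    x_batch=jnp.asarray(x_batch,dtype=jnp.float32)
    y_batch=jnp.asarray(y_batch)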
But for tasks that can run on your machine, `snax` is a quick alternative to dive into `jax` without worrying about unnecessary details.
If you are familiar with how `tf.data` works, then there is not much more for you to learn: `snax` aims to replicate the API and behavior of `tf.data` as closely as possible in order to stay compatible with it. Instead of `tf.data.Dataset` there is `snax.data.Dataset`, instead of `tf.data.Dataset.from_tensor_slices` there is `snax.data.Dataset.from_tensor_slices`, et cetera.
import snax as sn
import tensorflow as tf
X_y=(X,y) # some random dataset
# creating a dataset from tensors
sn_dataset=sn.data.Dataset.from_tensor_slices(X_y) # snax
tf_dataset=tf.data.Dataset.from_tensor_slices(X_y) # tensorflow
# applying transformations
sn_dataset=sn_dataset.map(lambda x,y: (x**2,y)).skip(10000) # snax
tf_dataset=tf_dataset.map(lambda x,y: (x**2,y)).skip(10000) # tensorflow
# consuming the data
for batch in sn_dataset.batch(32): # snax
    model.fit(batch)
for batch in tf_dataset.batch(32): # tensorflow
    model.fit(batch)
If you are not familiar with these APIs, I suggest reading through the `tensorflow` tutorials first, which are exceptionally well documented.
Additionally, `snax` supports a limited portfolio of the datasets offered via `tensorflow_datasets` through its `snax.datasets` module, which I hope to extend in the future.
import snax.datasets as snds
import tensorflow_datasets as tfds
# loading a dataset
sn_dataset=snds.load("mnist") # snax
tf_dataset=tfds.load("mnist") # tensorflow
`jax` introduces several new features, amongst which the ability to jit-compile functions. `jax.jit()` particularly shines (more here) when applied to complex functions that need to be called many times, for example with complex models or with transformations applied to data batches.
`snax.data.Dataset` supports `jax.jit` via the `jit()` method, which jits all mapped transformations in a single encapsulation, increasing the rate of data ingestion by a significant margin. This feature is most useful when applied after one or more `map` calls.
import snax as sn
import snax.datasets as snds
import flax.linen as nn
import jax.numpy as jnp
ds=snds.load("mnist")["train"]
# defining a preprocessing function with some casting and matrix multiplication
W=jnp.ones((784,256))
def processing(input):
    image=input["image"]
    x=image.astype("float32")/255
    x=jnp.reshape(x,(-1,))
    return {"image":nn.relu(jnp.matmul(x,W)),"label":input["label"]}
# mapping the preprocessing function and jitting the result
sn_dataset=ds.map(processing)
jitted_sn_dataset=sn_dataset.jit() # boosting performance
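The jitted dataset is then consumed like any other; a short usage sketch, assuming `jit()` preserves the batching and iteration interface shown earlier:
for batch in jitted_sn_dataset.batch(32):
    # each batch keeps the {"image", "label"} structure returned by processing
    images,labels=batch["image"],batch["label"]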
In a small-scale testing experiment, the use of `snax.jit()` increased speed on average by 64% compared to `tf.data.Dataset` and by 75% compared to vanilla `snax.data.Dataset`.
`snax.data.Dataset` can iterate through batches with impressive speed, even more so when leveraging `snax.jit()`. Note, however, that `snax.data.Dataset` cannot replace `tf.data.Dataset`, since it does not scale as well and is not as stable. In fact, the stability of `tf.data.Dataset` is impressive: even when changing batch sizes by an order of magnitude, the time to iterate over a batch changes on average by only 0.10 seconds, wow.
In its current version it supports:
- `map`: behavior is similar to that of `tf.data.Dataset.map`
- `take`: behavior is similar to that of `tf.data.Dataset.take`
- `skip`: behavior is similar to that of `tf.data.Dataset.skip`
- `zip`: behavior is similar to that of `tf.data.Dataset.zip`
- `batch`: behavior is similar to that of `tf.data.Dataset.batch`
- `apply`: behavior is similar to that of `tf.data.Dataset.apply`
- `shuffle`: behavior is similar to that of `tf.data.Dataset.shuffle` but with a different input signature: instead of a `buffer_size` it requires a `jax.random.PRNGKey`, as shown below.
import snax as sn
import tensorflow as tf
from jax.random import PRNGKey
X_y=(X,y) # some random dataset
# creating a dataset from tensors
sn_dataset=sn.data.Dataset.from_tensor_slices(X_y) # snax
tf_dataset=tf.data.Dataset.from_tensor_slices(X_y) # tensorflow
# difference in `shuffle` behavior
sn_dataset=sn_dataset.shuffle(PRNGKey(0)) # snax
tf_dataset=tf_dataset.shuffle(1024) # tensorflow
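Since the supported methods mirror their `tf.data` counterparts, they can be chained freely. Continuing the snippet above, a quick sketch under the assumption that the listed methods follow the `tf.data`-like semantics described:
# chaining a few of the supported methods on the shuffled snax dataset
test_ds=sn_dataset.take(1000)
train_ds=sn_dataset.skip(1000).map(lambda x,y: (x/255,y))
for batch in train_ds.batch(64): # snax
    x_batch,y_batch=batch # feed the batch to a training step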
What it does not yet support:
- `from_generator`
- `concatenate`
- `bucket_by_sequence_length`
- `filter`
As soon as I came up with the name `snax`, while thinking about how ingeniously named DeepMind's `rlax` is, I realized I had to release a `jax` ingestion package with such a name. I figured out how to do it afterward.
`snax` was developed as an exercise to learn bits of the `jax` ecosystem and perhaps for the enjoyment of the community. It has no professional ambitions, nor does it claim any particular technical or scientific quality.
Feel free to contribute
Cheers