Darren Wilkinson
JAX is a pure functional language embedded in python, but designed to
feel as much like python as practical. From a python prompt, first do
some imports.
import os
import pandas as pd
import numpy as np
import scipy as sp
import jax
from jax import grad, jit
import jax.numpy as jnp
import jax.scipy as jsp
import jax.lax as jl
If any of these imports fail, you probably don’t have JAX installed
correctly (in your current environment). For most numpy and scipy
functions, there is a JAX equivalent, so with the above imports,
translating from regular python to JAX often involves replacing a call
to np.X (for some X) with a call to jnp.X, and sp.X with jsp.X. But
there are other issues to confront, due to the fact that
JAX is a pure functional language and python most definitely isn’t!
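As a small illustration of the np.X to jnp.X translation described above, the sketch below computes the same two quantities with numpy/scipy and with their JAX counterparts. The choice of functions (exp, sum, gammaln) and the input values are purely illustrative assumptions, and the scipy submodules are imported explicitly rather than via the sp and jsp aliases.

```python
import numpy as np
import scipy.special
import jax.numpy as jnp
import jax.scipy.special

x = [0.5, 1.5, 2.5]  # illustrative input values

# Regular numpy/scipy version.
y_np = np.sum(np.exp(np.array(x)))
g_np = scipy.special.gammaln(np.array(x))

# JAX version: the same calls, with np -> jnp and scipy -> jax.scipy.
y_jnp = jnp.sum(jnp.exp(jnp.array(x)))
g_jnp = jax.scipy.special.gammaln(jnp.array(x))

print(y_np, y_jnp)
print(g_np, g_jnp)
```

The results should agree up to floating point precision, bearing in mind that JAX defaults to 32-bit floats whereas numpy defaults to 64-bit.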
v = jnp.array([2, 4, 6, 3]).astype(jnp.float32)
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Note that the type of the array has been set to float32, since these
are fast and efficient, especially on GPUs. JAX arrays are immutable.
v[2]
vu = v.at[2].set(7)
vu
Array([2., 4., 7., 3.], dtype=float32)
v
Array([2., 4., 6., 3.], dtype=float32)
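To make the immutability concrete, here is a minimal sketch showing that in-place assignment is rejected, alongside a few of the functional update methods available through .at. The particular updates chosen are just for illustration.

```python
import jax.numpy as jnp

v = jnp.array([2.0, 4.0, 6.0, 3.0])

# In-place assignment is rejected because JAX arrays are immutable.
try:
    v[2] = 7.0
    assigned = True
except TypeError:
    assigned = False

# The functional alternatives return new arrays and leave v untouched.
v_set = v.at[2].set(7.0)        # replace one element
v_add = v.at[2].add(1.0)        # increment one element
v_mul = v.at[0].multiply(10.0)  # scale one element

print(assigned, v_set, v_add, v_mul, v)
```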
We can map JAX arrays.
jl.map(lambda x: 2*x, v)
Array([ 4., 8., 12., 6.], dtype=float32)
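As far as I am aware, jax.lax.map applies the function one element at a time, and the JAX documentation suggests jax.vmap as the usual way of getting a vectorised map. The sketch below, using a hypothetical helper called double, shows that the two give the same answer for a simple elementwise function.

```python
import jax
import jax.numpy as jnp
import jax.lax as jl

v = jnp.array([2.0, 4.0, 6.0, 3.0])

def double(x):
    # Hypothetical helper: works on a single element.
    return 2 * x

mapped = jl.map(double, v)     # element-by-element map
vmapped = jax.vmap(double)(v)  # vectorised map over the leading axis

print(mapped, vmapped)
```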
Mapping can be parallelised, and JAX will do this automatically. We can also reduce them.
jl.reduce(v, 0.0, lambda x,y: x+y, [0])
Array(15., dtype=float32)
jnp.sum(v)
Array(15., dtype=float32)
The reduction must be monoidal (the operation must be associative, and the initial value must be an identity with respect to that operation), or the result is undefined. Since the reduction is monoidal, it can be parallelised via tree reduction, and JAX will do this automatically.
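To illustrate what "monoidal" means in practice, the sketch below reduces the same array with three different (operation, identity) pairs: (+, 0), (*, 1) and (max, -inf). The product and maximum cases are my own examples, following the same calling pattern as the sum above.

```python
import jax.numpy as jnp
import jax.lax as jl

v = jnp.array([2.0, 4.0, 6.0, 3.0])

# (+, 0): addition is associative and 0 is its identity.
total = jl.reduce(v, 0.0, lambda x, y: x + y, [0])

# (*, 1): multiplication is associative and 1 is its identity.
product = jl.reduce(v, 1.0, lambda x, y: x * y, [0])

# (max, -inf): max is associative and -inf is its identity.
largest = jl.reduce(v, -jnp.inf, lambda x, y: jnp.maximum(x, y), [0])

print(total, product, largest)
```

By contrast, something like subtraction is not associative, so it would not be a valid choice of reduction operation.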
Functions are written like regular python functions. But if they are to be part of a hot loop, they can be JIT-compiled.
@jit
def sumArray1d(v):
    return jl.reduce(v, 0.0, lambda x,y: x+y, [0])
float(sumArray1d(v))
15.0
Note that you can’t use float inside a JIT’d JAX function, since
float is a python function, not a JAX function.
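The following sketch shows what goes wrong. The function names bad_sum and good_sum are hypothetical, and the assumption that the failure surfaces as a TypeError (or a subclass of it) is based on the JAX versions I am familiar with.

```python
import jax
import jax.numpy as jnp

@jax.jit
def bad_sum(v):
    # float() needs a concrete number, but during tracing the
    # result of jnp.sum(v) is an abstract tracer.
    return float(jnp.sum(v))

@jax.jit
def good_sum(v):
    # Stay within JAX inside the compiled function.
    return jnp.sum(v)

v = jnp.array([2.0, 4.0, 6.0, 3.0])

try:
    bad_sum(v)
    failed = False
except TypeError:
    failed = True

# Convert to a python float outside the JIT'd function instead.
result = float(good_sum(v))
print(failed, result)
```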
We have seen that functional languages often exploit recursion, either
explicitly or implicitly, for the implementation of “looping”
constructs. However, allowing general recursion turns out to be
problematic for reverse-mode automatic differentiation. Consequently,
some differentiable functional languages (such as JAX and Dex) disallow
recursive functions. But without any mutable variables or recursion, how
can we loop?! In this case the language must provide us with some
built-in constructs. In JAX, the two most commonly used constructs (in
addition to map, reduce and scan) are jax.lax.while_loop and
jax.lax.fori_loop.
Note that you cannot reverse-mode differentiate through a while_loop
(this is problematic for the same reason that recursive functions are
problematic - you cannot know statically what the memory requirements
will be). A for loop is relatively straightforward.
def logFactF(n):
    return float(jl.fori_loop(1, n+1,
        lambda i,acc: acc + jnp.log(i), 0.0))
logFactF(3)
1.7917594909667969
logFactF(100000)
1051299.625
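In contrast to a while_loop, a fori_loop whose bounds are known when the function is traced can be differentiated in reverse mode. Here is a minimal sketch, using a hypothetical function cube that computes x^3 by repeated multiplication, so that the gradient can be checked against 3x^2. It assumes the loop bounds are given as python integer literals.

```python
import jax
import jax.numpy as jnp
import jax.lax as jl

def cube(x):
    # Three multiplications by x. The bounds 0 and 3 are python
    # literals, so the trip count is known at trace time.
    return jl.fori_loop(0, 3, lambda i, acc: acc * x, jnp.ones_like(x))

x = jnp.float32(2.0)
value = cube(x)            # x^3 at x = 2
slope = jax.grad(cube)(x)  # d/dx x^3 = 3 x^2 at x = 2

print(value, slope)
```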
Note that the upper bound on the loop is exclusive. A while loop is slightly more involved, due to the need to propagate two items of state (the counter and the accumulator). However, the while loop can be used when the number of iterations is not known statically.
def logFactW(n):
    def cont(state):
        [i, acc] = state
        return i <= n
    def advance(state):
        [i, acc] = state
        return [i + 1, acc + jnp.log(i)]
    return float(jl.while_loop(cont, advance, [1, 0.0])[1])
logFactW(3)
1.7917594909667969
logFactW(100000)
1051299.625
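The example above has a known trip count, so it doesn't really need a while loop. As a sketch of a case where the number of iterations genuinely depends on the data, the hypothetical function below finds the smallest n for which log(n!) exceeds a given threshold, using the same state-passing pattern.

```python
import jax.numpy as jnp
import jax.lax as jl

def firstLogFactAbove(threshold):
    # State is [i, acc], where acc holds log(i!).
    def cont(state):
        [i, acc] = state
        return acc <= threshold
    def advance(state):
        [i, acc] = state
        return [i + 1, acc + jnp.log(i + 1.0)]
    # Start from i = 1, for which log(1!) = 0.
    return int(jl.while_loop(cont, advance, [1, 0.0])[0])

print(firstLogFactAbove(1.0))
print(firstLogFactAbove(10.0))
```

Here the loop stops as soon as the accumulator passes the threshold, so the number of iterations cannot be determined without running the computation.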