LapJAX is a JAX-based Python package that automatically accelerates laplacian computation using a technique called "Forward Laplacian". It exploits the properties of forward-mode computation and the sparsity of various operators, and significantly accelerates computation, with the speedup growing with system size. Notably, LapJAX accelerates the computation without any loss of precision. For more details, including mathematical derivations, please refer to the Forward Laplacian paper, which uses LapJAX to accelerate a Quantum Monte Carlo algorithm.
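For orientation, forward-mode laplacian propagation rests on the second-order chain rule. Suppose an intermediate quantity $h = g(u_1, \dots, u_m)$ depends on the input $x \in \mathbb{R}^n$ through $u_1(x), \dots, u_m(x)$. Then its gradient and laplacian follow from the gradients and laplacians of the $u_j$ alone. This is a standard calculus identity, stated here as background; the paper gives the authoritative derivation.

$$
\nabla h = \sum_{j=1}^{m} \frac{\partial g}{\partial u_j}\,\nabla u_j,
\qquad
\nabla^2 h = \sum_{j=1}^{m} \frac{\partial g}{\partial u_j}\,\nabla^2 u_j
+ \sum_{j=1}^{m}\sum_{k=1}^{m} \frac{\partial^2 g}{\partial u_j\,\partial u_k}\,\nabla u_j \cdot \nabla u_k .
$$

Applying this rule operation by operation never forms a full Hessian and introduces no approximation, which is consistent with the no-precision-loss claim above. For element-wise operations, $\partial^2 g / \partial u_j \partial u_k$ vanishes whenever $j \neq k$, so the double sum collapses to a single sum. This is one simple instance of the operator sparsity mentioned above.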
WIP Repo: This repo is still under development. We thank you in advance for your understanding and appreciate any feedback, suggestions, and contributions. See the contributing guidelines for more.
To install LapJAX together with all its dependencies (excluding JAX), go to the root directory and run:

```shell
pip install .
```
Note that you must install an appropriate version of JAX before running the command above. Editable installation (the `-e` option) is currently not supported.
GPU Support: To use LapJAX on GPU, you only need to ensure that the installed JAX supports GPU computation. For example, you can run:

```shell
pip install --upgrade "jax[cuda]==0.3.24" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

Please see the JAX documentation for more details.
Post-installation Check: To check whether LapJAX is installed successfully, you can run:

```shell
python test/test_funcs.py
```
We highly recommend going through the Jupyter notebook `lapjax-tutorial.ipynb` to get a better understanding of LapJAX. Here we provide a brief introduction to using LapJAX. Assume you have written a function `f(x)` using JAX, where `x` is a `jax.Array`. To compute the laplacian of `f(x)` w.r.t. `x`, i.e., $\nabla_x^2 f(x) = \sum_i \partial^2 f(x) / \partial x_i^2$, take the following two steps:
- Replace `jax` with `lapjax`, e.g.,

```python
# import jax.numpy as jnp
# from jax import vmap
import lapjax.numpy as jnp
from lapjax import vmap

def f(x):
    # Your code here.
    pass
```
- Create a LapTuple from the input, pass it to the function directly, and obtain the laplacian, e.g.,

```python
from lapjax import LapTuple, TupType

def lap_of_f(input_x):
    # input_x is the input of f(x), a standard jax.Array
    x = LapTuple(input_x, is_input=True)
    output = f(x)
    return output.get(TupType.LAP)
```
A Quick Look: LapJAX explicitly wraps various JAX functions so that they can take a data structure called LapTuple as input. Here, LapTuple is a tuple-like structure containing:
- the value,
- the gradient w.r.t. the input, and
- the laplacian w.r.t. the input.
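To make the idea concrete, here is a minimal NumPy sketch of a LapTuple-like object. `MiniLapTuple`, `elementwise`, and `total` are hypothetical names, not LapJAX's API, and the sketch assumes a 1-D input. An input starts with an identity gradient and a zero laplacian. Each operation then updates all three fields in a single forward pass. The sketch shows this for an element-wise function (via the chain rule) and a sum (a linear map).

```python
import numpy as np


class MiniLapTuple:
    """Hypothetical stand-in for a LapTuple: value, gradient, laplacian.

    For an input x of shape (n,), `grad` has shape (n, *value.shape) and
    grad[i] holds d(value)/dx_i; `lap` has the same shape as `value`.
    """

    def __init__(self, value, grad, lap):
        self.value = np.asarray(value, dtype=float)
        self.grad = np.asarray(grad, dtype=float)
        self.lap = np.asarray(lap, dtype=float)

    @classmethod
    def from_input(cls, x):
        # d x_j / d x_i = delta_ij, and the laplacian of x_j itself is 0.
        x = np.asarray(x, dtype=float)
        return cls(x, np.eye(x.size), np.zeros_like(x))


def elementwise(t, f, df, d2f):
    # Chain rule for y_j = f(u_j), applied entry by entry.
    v = t.value
    grad = df(v) * t.grad  # broadcasts over the leading input axis
    lap = d2f(v) * np.sum(t.grad ** 2, axis=0) + df(v) * t.lap
    return MiniLapTuple(f(v), grad, lap)


def total(t):
    # Summation is linear, so value, gradient, and laplacian are all summed.
    return MiniLapTuple(t.value.sum(), t.grad.sum(axis=-1), t.lap.sum())


x = np.array([0.3, -1.2, 2.0])
out = total(elementwise(MiniLapTuple.from_input(x), np.sin, np.cos, lambda u: -np.sin(u)))
print(out.lap)            # forward laplacian of sum(sin(x))
print(-np.sin(x).sum())   # analytic laplacian, for comparison
```

Because every field is updated in the same forward pass, no Hessian is ever formed. LapJAX additionally exploits operator sparsity, which this sketch does not attempt.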
Attention: When the input is a standard JAX Array, LapJAX operators behave the same as JAX operators. In addition, JAX and LapJAX are compatible. That said, to avoid unexpected problems, we still recommend replacing the import statements only in Python files that contain laplacian computation functions.
JAX-based Packages: Many popular packages rely on JAX, and we understand that they are widely used. Currently, we cannot wrap the functions in these packages directly, but a replacement trick makes LapJAX compatible with them. Before importing any JAX-based packages, you can run:

```python
import lapjax
import sys
sys.modules['jax'] = lapjax
```
This replaces the system module `jax` with `lapjax`. Other packages that import `jax` then end up relying on `lapjax`, which handles LapTuple inputs. Note that this trick may cause additional unexpected issues, for example if a package applies non-JAX functions to JAX arrays. Please raise an issue if you run into any problems.
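The trick works because Python's import system first looks up a module name in `sys.modules` and returns whatever object is registered there. The stdlib-only sketch below uses made-up module names (`demo_backend` and friends) purely for illustration. It also shows why the swap has to happen before other packages are imported: a name bound by an earlier import keeps pointing at the original object.

```python
import importlib
import sys
import types

# Two stand-in modules; the names are hypothetical and exist only for this demo.
original = types.ModuleType("original_backend")
original.label = "original"
replacement = types.ModuleType("replacement_backend")
replacement.label = "replacement"

sys.modules["demo_backend"] = original
import demo_backend as early_binding  # imported before the swap

# The swap: from now on, importing "demo_backend" yields the replacement.
sys.modules["demo_backend"] = replacement
late_binding = importlib.import_module("demo_backend")

print(early_binding.label)  # original
print(late_binding.label)   # replacement
```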
LapJAX wraps many commonly used JAX operators (most of them in `jax.numpy`) so that they support LapTuple inputs. However, we acknowledge that the supported functions are still a small subset. If you use an unwrapped function with LapTuple inputs (say `isnan`), you will get a `NotImplementedError`:

```
NotImplementedError: Lapjax encounters unwrapped function 'isnan'.
Please consider using other functions or wrap it yourself.
You can refer to README for more information about customized wrap.
```
You can always use `lapjax.is_wrapped(f)` to check whether a JAX function `f` is supported, or get all supported functions via `lapjax.get_all_wrapped_names()`. Below is the list of supported functions (`jnp` stands for `jax.numpy`, and `jlax` stands for `jax.lax`). To understand the classification of these functions, please refer to the tutorial.
```
FType.CONSTRUCTION: [
    jnp.shape, jnp.size, jnp.eye, jnp.array, jnp.ones, jnp.ones_like, jnp.zeros, jnp.zeros_like, jnp.asarray, jnp.sign, jnp.argmax, jnp.argmin, jnp.argsort, jnp.argwhere, jlax.stop_gradient,
],
FType.LINEAR: [
    jnp.reshape, jnp.transpose, jnp.swapaxes, jnp.split, jnp.array_split, jnp.concatenate, jnp.stack, jnp.squeeze, jnp.expand_dims, jnp.repeat, jnp.tile, jnp.where, jnp.triu, jnp.tril, jnp.sum, jnp.mean, jnp.broadcast_to,
],
FType.ELEMENT: [
    jnp.sin, jnp.cos, jnp.tan, jnp.arcsin, jnp.arccos, jnp.arctan, jnp.arcsinh, jnp.arccosh, jnp.arctanh, jnp.sinh, jnp.cosh, jnp.tanh, jnp.exp, jnp.log, jnp.exp2, jnp.log2, jnp.square, jnp.sqrt, jnp.power, jnp.abs, jnp.absolute, jlax.sin, jlax.cos, jlax.tan, jlax.asin, jlax.acos, jlax.atan, jlax.asinh, jlax.acosh, jlax.atanh, jlax.exp, jlax.log, jlax.square, jlax.sqrt, jlax.rsqrt, jlax.pow, jlax.integer_pow, jlax.abs,
],
FType.OVERLOAD: [
    jnp.add, jnp.subtract, jnp.multiply, jnp.divide, jnp.true_divide
],
FType.MERGING: [
    jnp.linalg.norm, jnp.prod,
],
FType.CUSTOMIZED: [
    jnp.matmul, jnp.dot, jnp.max, jnp.min, jnp.amax, jnp.amin, jnp.linalg.slogdet, jax.nn.logsumexp, jax.nn.softmax,
],
FType.EMPTY: [
    jax.vmap,
],
```
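The OVERLOAD category above covers arithmetic such as `jnp.multiply`. The NumPy sketch below illustrates, under our own assumptions, how such a wrapper might dispatch. Plain arrays fall through to the ordinary function, matching the behavior described for standard JAX arrays. Value/gradient/laplacian triples use the product rule $\nabla^2(uv) = u\,\nabla^2 v + v\,\nabla^2 u + 2\,\nabla u \cdot \nabla v$. `Triple` and `lap_multiply` are hypothetical names, not LapJAX internals.

```python
from typing import NamedTuple

import numpy as np


class Triple(NamedTuple):
    """Hypothetical (value, gradient, laplacian) container.

    For an input x of shape (n,), grad has shape (n, *value.shape) and
    grad[i] holds d(value)/dx_i; lap has the same shape as value.
    """

    value: np.ndarray
    grad: np.ndarray
    lap: np.ndarray


def lap_multiply(a, b):
    """Sketch of an OVERLOAD-style wrapper around np.multiply (hypothetical)."""
    a_is_triple, b_is_triple = isinstance(a, Triple), isinstance(b, Triple)
    if not a_is_triple and not b_is_triple:
        return np.multiply(a, b)  # plain arrays: behave exactly like NumPy
    if not a_is_triple:
        # A constant factor is a linear map: scale value, gradient, laplacian.
        return Triple(a * b.value, a * b.grad, a * b.lap)
    if not b_is_triple:
        return lap_multiply(b, a)  # multiplication commutes
    # Product rule: lap(uv) = u*lap(v) + v*lap(u) + 2 * grad(u) . grad(v)
    value = a.value * b.value
    grad = a.grad * b.value + a.value * b.grad
    lap = a.value * b.lap + b.value * a.lap + 2.0 * np.sum(a.grad * b.grad, axis=0)
    return Triple(value, grad, lap)


x = np.array([0.5, -2.0, 1.5])
xt = Triple(x, np.eye(x.size), np.zeros_like(x))  # the input itself
sq = lap_multiply(xt, xt)  # entry-wise x_j ** 2
print(sq.lap)  # [2. 2. 2.]
print(lap_multiply(np.array([1.0, 2.0]), np.array([3.0, 4.0])))  # [3. 8.]
```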
When you want to use an unwrapped function, a good starting point is to consider replacing it with wrapped functions. That said, in many cases you still need to wrap new functions yourself. The tutorial `lapjax-tutorial.ipynb` gives a few examples of how to support other JAX functions with LapTuple inputs. Here we provide a quick introduction:
- Check whether the function can be classified as one of the existing FTypes. For construction functions, linear functions, element-wise functions, and merging functions, LapJAX provides a general handler, and you can directly run, e.g.,

```python
from lapjax import numpy, custom_wrap, FType
custom_wrap(numpy.isnan, FType.CONSTRUCTION)
```
- If the function does not belong to any of these four FTypes, you need to write the derivative and laplacian computation yourself and use `custom_wrap` to bind it to the function.
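As a hedged example of such a hand-written rule, consider `jax.nn.logsumexp`, which appears in the CUSTOMIZED list above. The NumPy sketch below shows the kind of derivative and laplacian computation one would supply for it. It relies on the fact that the Hessian of logsumexp is diag(p) - p p^T with p = softmax(u). `logsumexp_fwdlap` is a hypothetical name, and the sketch does not reflect how `custom_wrap` expects the rule to be packaged; see the tutorial for that.

```python
import numpy as np


def logsumexp_fwdlap(value, grad, lap):
    """Hypothetical forward-laplacian rule for s = log(sum_j exp(u_j)).

    value: (m,) entries u_j; grad: (n, m) with grad[i, j] = du_j/dx_i;
    lap: (m,) laplacians of the u_j. Returns (s, grad of s, laplacian of s).
    """
    shift = value.max()  # subtract the max for numerical stability
    weights = np.exp(value - shift)
    p = weights / weights.sum()  # softmax(u); Hessian of logsumexp is diag(p) - p p^T
    s = shift + np.log(weights.sum())
    s_grad = grad @ p  # sum_j p_j * grad(u_j)
    s_lap = p @ lap + p @ np.sum(grad ** 2, axis=0) - s_grad @ s_grad
    return s, s_grad, s_lap


# Example: u_j = x_j ** 2, so grad(u_j) = 2 x_j e_j and lap(u_j) = 2.
x = np.array([0.2, -0.7, 1.1])
s, s_grad, s_lap = logsumexp_fwdlap(x ** 2, np.diag(2 * x), np.full(3, 2.0))
print(s, s_grad, s_lap)
```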
Checking Correctness: To automatically check the correctness of functions you wrap, please use `lapjax.create_check_function`. For example, to check the correctness of `lapjax.numpy.sin`, you can run:

```python
import lapjax

wrap_func = lapjax.numpy.sin  # assume you have wrapped it
test_input = lapjax.numpy.ones([3, 5])
checker = lapjax.create_check_function(
    wrap_func, derivative_args=(0,), seed=42)
grad_diff, lap_diff = checker(test_input)
print(f'gradient difference: {grad_diff}, laplacian difference: {lap_diff}')
```
Here, `checker` takes standard `jax.Array` inputs (same as `jax.numpy.sin`) and converts them into LapTuples according to `derivative_args`, which specifies which input arguments should be converted. Keyword arguments currently cannot be converted. Similarly, you can specify which outputs to check using `derivative_outputs` (default `(0,)`). For example, `lapjax.numpy.linalg.slogdet` outputs `(sign, logdet)`, so you should set `derivative_outputs=(1,)`.
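How `create_check_function` builds its reference values is not described here. As a rough, framework-free analogue of such a check, the NumPy sketch below computes a central finite-difference gradient and laplacian of a scalar function and compares them with the known derivatives of `sum(sin(x))` on a `[3, 5]` input. The helper `finite_diff_grad_lap` is hypothetical. Unlike the forward laplacian, finite differences are approximate, so the reported differences are small but nonzero.

```python
import numpy as np


def finite_diff_grad_lap(f, x, h=1e-4):
    """Central finite-difference gradient and laplacian of a scalar f at x.

    An approximate reference only: the truncation error is O(h**2).
    """
    x = np.asarray(x, dtype=float)
    f0 = f(x)
    grad = np.empty_like(x)
    lap = 0.0
    for i in range(x.size):
        e = np.zeros_like(x)
        e.flat[i] = h  # perturb one coordinate at a time
        f_plus, f_minus = f(x + e), f(x - e)
        grad.flat[i] = (f_plus - f_minus) / (2 * h)
        lap += (f_plus - 2 * f0 + f_minus) / h ** 2
    return grad, lap


# f(x) = sum(sin(x)) has gradient cos(x) and laplacian -sum(sin(x)).
x = np.linspace(-1.0, 1.0, 15).reshape(3, 5)
fd_grad, fd_lap = finite_diff_grad_lap(lambda z: np.sum(np.sin(z)), x)
grad_diff = np.max(np.abs(fd_grad - np.cos(x)))
lap_diff = abs(fd_lap - (-np.sum(np.sin(x))))
print(f'gradient difference: {grad_diff}, laplacian difference: {lap_diff}')
```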
We sincerely thank you for any feedback, suggestions, and contributions. The following contributions are especially welcome:
- Support more JAX operators with LapTuples as inputs, especially functions in `jax.numpy` and `jax.lax`.
- Support JAX built-in functions with LapTuples as inputs, e.g., `jax.grad` and `jax.jacobian`. Note that although `jax.grad` cannot be applied to functions that take LapTuples as inputs, you are free to create a function that maps the input to its laplacian and apply `jax.grad` to that function. This means you can use `jax.grad` to compute the gradient of the laplacian of a function.
- Better cooperation with JAX and other JAX-based packages like `haiku`. For now, you have to use the `sys.modules` replacement trick to support JAX-based operators in `haiku`. We look forward to a better solution.
- Any other contributions that can improve the performance or usability of LapJAX!
If you find this repo useful, please cite our paper:
@article{li2023forward,
title={Forward Laplacian: A New Computational Framework for Neural Network-based Variational Monte Carlo},
author={Li, Ruichen and Ye, Haotian and Jiang, Du and Wen, Xuelan and Wang, Chuwei and Li, Zhe and Li, Xiang and He, Di and Chen, Ji and Ren, Weiluo and Wang, Liwei},
journal={arXiv preprint arXiv:2307.08214},
year={2023}
}