Make evaluation of Binning corrections JAX-traceable #42

Open
eguiraud opened this issue Nov 4, 2023 · 2 comments

Comments

eguiraud commented Nov 4, 2023

Evaluating a scipy.interpolate.CubicSpline is not traceable by JAX, because the input is explicitly cast to np.array somewhere in SciPy's evaluation path.

This in turn makes it impossible to pass CorrectionWithGrad.evaluate to jax.jit or jax.vmap if a Binning correction is involved.

  • For simple Binning (1D histograms with scalar bin contents), see below.
  • For compound Binning (1D histograms with Formulas or FormulaRefs as bin contents), there is the additional problem that JAX cannot trace through the bin look-up, and I'm not sure how to fix this.
  • For MultiBinning (N-dimensional histograms), I don't know of a JAX-friendly differentiable relaxation of the bin look-up, so we might have to come up with one. MultiBinning is not supported at the moment anyway; that is tracked in "Add support for MultiBinning" (#15).
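A minimal reproduction of the tracing failure, as a sketch: the bin centers and contents are made-up values, and `evaluate` is a hypothetical stand-in for what a simple Binning evaluation does internally, not the actual CorrectionWithGrad code.

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.interpolate import CubicSpline

# Hypothetical bin centers and scalar bin contents of a 1D histogram.
centers = np.array([0.5, 1.5, 2.5, 3.5])
contents = np.array([1.0, 1.2, 0.9, 1.1])
spline = CubicSpline(centers, contents)


def evaluate(x):
    # SciPy converts x to a NumPy array, which is not possible for a JAX tracer.
    return spline(x)


# Eager evaluation on a concrete value works.
eager_value = evaluate(2.0)

# Under jax.jit the argument is a tracer, so the call is expected to raise.
jit_failed = False
error_name = None
try:
    jax.jit(evaluate)(jnp.float32(2.0))
except Exception as err:
    jit_failed = True
    error_name = type(err).__name__

print(eager_value, jit_failed, error_name)
```

The same kind of failure should be expected with jax.vmap, since it also passes tracers to the function.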
eguiraud changed the title from "jax.jit does not work with Binning corrections" to "Make evaluation of Binning corrections JAX-traceable" on Nov 4, 2023

eguiraud commented Nov 4, 2023

eguiraud commented Nov 6, 2023

Options I see:

  • re-implement the evaluation of scipy.interpolate.CubicSpline using XLA primitives or jax.numpy functions
  • switch to sigmoids, as per Lukas' suggestion
  • switch to a sum of Gaussians
