The evaluation of a `scipy.interpolate.CubicSpline` is not traceable by JAX (because of an explicit cast to `np.array` somewhere in its implementation).
This in turn makes it impossible to pass `CorrectionWithGrad.evaluate` to `jax.jit` or `jax.vmap` if a Binning correction is involved.
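One possible workaround, sketched here under assumptions: `CubicSpline` is a subclass of `scipy.interpolate.PPoly`, so its breakpoints (`spline.x`) and per-interval coefficients (`spline.c`) can be read once, outside the traced function, and the piecewise cubic can then be evaluated with `jax.numpy` alone, so scipy never runs inside the trace. `make_jax_spline` is a hypothetical helper (not part of correctionlib or this project), the sketch assumes 1D `y` values, and the histogram data is made up.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.interpolate import CubicSpline


def make_jax_spline(spline: CubicSpline):
    """Hypothetical helper: wrap a fitted 1D CubicSpline in a JAX-traceable function."""
    # CubicSpline is a PPoly: on interval i the value is
    # sum_m c[m, i] * (xp - x[i]) ** (3 - m).
    # Convert breakpoints and coefficients once, outside any JAX trace.
    x = jnp.asarray(spline.x)  # breakpoints, shape (n,)
    c = jnp.asarray(spline.c)  # coefficients, shape (4, n - 1) for 1D y
    n_pieces = c.shape[1]

    def evaluate(xp):
        # Find the polynomial piece for each point. Clipping sends points outside
        # [x[0], x[-1]] to the first/last piece, mimicking extrapolate=True.
        idx = jnp.clip(jnp.searchsorted(x, xp, side="right") - 1, 0, n_pieces - 1)
        dx = xp - x[idx]
        # Horner's scheme, starting from the highest-degree coefficient c[0].
        result = c[0, idx]
        for m in range(1, c.shape[0]):
            result = result * dx + c[m, idx]
        return result

    return evaluate


# Made-up example: a 4-bin histogram, with a spline through the bin centers.
edges = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
centers = 0.5 * (edges[:-1] + edges[1:])
contents = np.array([1.0, 3.0, 2.0, 5.0])
f = make_jax_spline(CubicSpline(centers, contents))

xs = jnp.linspace(0.0, 4.0, 9)
print(jax.jit(f)(xs))             # values, evaluated under jit
print(jax.vmap(jax.grad(f))(xs))  # per-point derivatives via grad + vmap
```

Unless `jax_enable_x64` is enabled, the coefficients are converted to float32, so results agree with scipy only to single precision.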
- For simple Binning (1D histograms with scalar bin contents), see below.
- For compound Binning (1D histograms with Formulas or FormulaRefs as bin contents), there is the additional problem that JAX cannot trace through the bin look-up, and I'm not sure how to fix this.
- For MultiBinning (N-dimensional histograms), I don't know of a JAX-friendly implementation of a differentiable relaxation of the bin look-up, so we might have to come up with one. MultiBinning is not supported at the moment anyway; this is tracked in "Add support for MultiBinning" #15.
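For the compound case, one approach that is at least traceable (it is not a differentiable relaxation of the look-up itself) is to compute the bin index with `jnp.searchsorted` and dispatch on it with `jax.lax.switch`, which accepts a traced integer index instead of Python control flow. This is a minimal sketch under assumptions: the per-bin lambdas are hypothetical stand-ins for the Formula contents, `evaluate_compound_binning` is a hypothetical name, and clamping under/overflow into the edge bins is an assumed flow behavior.

```python
import jax
import jax.numpy as jnp

# Hypothetical compound Binning: 3 bins over the binning variable, each bin
# holding a different formula of a second input x (stand-ins for Formula contents).
edges = jnp.array([0.0, 1.0, 2.0, 3.0])
formulas = [
    lambda x: 2.0 * x,
    lambda x: x ** 2,
    lambda x: jnp.sin(x),
]


def evaluate_compound_binning(binvar, x):
    # Bin look-up on a traced value: searchsorted is traceable, unlike a Python
    # loop or if-chain over the bin edges.
    idx = jnp.searchsorted(edges, binvar, side="right") - 1
    # Assumption: under/overflow is clamped into the first/last bin.
    idx = jnp.clip(idx, 0, len(formulas) - 1)
    # lax.switch picks the branch from a traced integer index. Under vmap it
    # evaluates every branch and selects the result per element.
    return jax.lax.switch(idx, formulas, x)


binvars = jnp.array([0.5, 1.5, 2.5])
xs = jnp.array([1.0, 2.0, 3.0])
print(jax.jit(jax.vmap(evaluate_compound_binning))(binvars, xs))
# The derivative w.r.t. x goes through the selected formula; the bin index is
# piecewise constant, so it contributes no gradient w.r.t. binvar.
print(jax.vmap(jax.grad(evaluate_compound_binning, argnums=1))(binvars, xs))
```

The trade-off is that batched evaluation pays for every bin's formula, which is usually acceptable for a handful of bins.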
eguiraud changed the title from "jax.jit does not work with Binning corrections" to "Make evaluation of Binning corrections JAX-traceable" on Nov 4, 2023.