diff --git a/oryx/core/primitive.py b/oryx/core/primitive.py index 2ae0b2a..d841b2f 100644 --- a/oryx/core/primitive.py +++ b/oryx/core/primitive.py @@ -142,8 +142,6 @@ def _jvp(primals, tangents, **params): primals_out, tangents_out = ad.jvp(lu.wrap_init(self.impl, params)).call_wrapped( primals, tangents) - tangents_out = jax_util.safe_map(ad.recast_to_float0, primals_out, - tangents_out) return primals_out, tangents_out ad.primitive_jvps[self] = _jvp