From ab13509a1db444318c6308bac0aac526891d31af Mon Sep 17 00:00:00 2001 From: SamTov Date: Wed, 9 Aug 2023 11:46:06 +0200 Subject: [PATCH] Play with optimizer options --- znnl/models/jax_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/znnl/models/jax_model.py b/znnl/models/jax_model.py index 5bde9cc..6b18cef 100644 --- a/znnl/models/jax_model.py +++ b/znnl/models/jax_model.py @@ -81,11 +81,11 @@ def __init__( self.init_model(seed) # Prepare NTK calculation - self.empirical_ntk = nt.empirical_ntk_fn( + self.empirical_ntk = nt.batch(nt.empirical_ntk_fn( f=self._ntk_apply_fn, trace_axes=trace_axes - ) + ), batch_size=ntk_batch_size) - self.empirical_ntk_jit = jax.jit(self.empirical_ntk) + self.empirical_ntk_jit = self.empirical_ntk def init_model( self,