From 17657d28d5a250f57028ecfc545c12423ec5864a Mon Sep 17 00:00:00 2001 From: Yusha Arif Date: Wed, 18 Sep 2024 16:41:09 +0000 Subject: [PATCH] fix (backends)(jax)(module.py): fixing the implementation for `train` and `eval` methods --- ivy/functional/backends/jax/module.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/ivy/functional/backends/jax/module.py b/ivy/functional/backends/jax/module.py index 27b005fb7547..31d63f416574 100644 --- a/ivy/functional/backends/jax/module.py +++ b/ivy/functional/backends/jax/module.py @@ -448,15 +448,24 @@ def register_parameter(self, name: str, value: jax.Array): def train(self, mode: bool = True): self._training = mode for module in self.children(): - if isinstance(module, nn.Module) and not hasattr(module, "train"): + if isinstance(module, Module): module.trainable = mode - continue - module.train(mode) + + super().train() self.trainable = mode return self - def eval(self): - return self.train(mode=False) + def eval( + self, + ): + self._training = False + for module in self.children(): + if isinstance(module, Module): + module.trainable = False + + super().eval() + self.trainable = False + return self def call(self, inputs, training=None, mask=None): raise NotImplementedError(