From 8cd8bf573f02d38791f8e7fbffd8cdf567cd5a0d Mon Sep 17 00:00:00 2001 From: Arvin Kushwaha Date: Fri, 26 Jul 2024 17:02:52 +0200 Subject: [PATCH] Fixed `Sequential` to not have extra arguments. If extra arguments are necessary, they should be packaged into the data --- src/jaximal/nn.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/jaximal/nn.py b/src/jaximal/nn.py index a5c91ac..8cbe37b 100644 --- a/src/jaximal/nn.py +++ b/src/jaximal/nn.py @@ -139,13 +139,9 @@ def init(key: PRNGKeyArray) -> Self: return init - def __call__(self, data: Any, *args: dict[str, Any]) -> Any: - assert len(args) == len(self.modules), ( - 'Expected `self.modules` and `args` to have the same length ' - f'but got {len(self.modules)} and {len(args)}, respectively.' - ) - for kwargs, modules in zip(args, self.modules): - data = modules(data, **kwargs) + def __call__(self, data: Any) -> Any: + for modules in zip(args, self.modules): + data = modules(data) return data