Skip to content

Commit

Permalink
Fixed Sequential to not have extra arguments. If extra arguments ar…
Browse files Browse the repository at this point in the history
…e necessary, they should be packaged into the data
  • Loading branch information
ArvinSKushwaha committed Jul 26, 2024
1 parent d990d42 commit 8cd8bf5
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions src/jaximal/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,9 @@ def init(key: PRNGKeyArray) -> Self:

return init

def __call__(self, data: Any, *args: dict[str, Any]) -> Any:
assert len(args) == len(self.modules), (
'Expected `self.modules` and `args` to have the same length '
f'but got {len(self.modules)} and {len(args)}, respectively.'
)
for kwargs, modules in zip(args, self.modules):
data = modules(data, **kwargs)
def __call__(self, data: Any) -> Any:
for modules in zip(args, self.modules):

Check failure on line 143 in src/jaximal/nn.py

View workflow job for this annotation

GitHub Actions / Build and Test (ubuntu-latest)

"args" is not defined (reportUndefinedVariable)
data = modules(data)

Check failure on line 144 in src/jaximal/nn.py

View workflow job for this annotation

GitHub Actions / Build and Test (ubuntu-latest)

Object of type "tuple[Unknown, JaximalModule]" is not callable   Attribute "__call__" is unknown (reportCallIssue)

return data

Expand Down

0 comments on commit 8cd8bf5

Please sign in to comment.