Skip to content

Commit

Permalink
Refactored CI and fixed types in tests/test_core.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ArvinSKushwaha committed Jun 5, 2024
1 parent 741f302 commit 50bdfe1
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ jobs:
run: rye test

- name: Type-Check
run: rye run basedpyright -p . .
run: rye run ci
15 changes: 8 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "jaximal"
version = "0.1.1"
description = "A JAX-based PyTree manipulation library "
authors = [
{ name = "Arvin Kushwaha", email = "[email protected]" }
{ name = "Arvin Kushwaha", email = "[email protected]" },
]
dependencies = [
"safetensors>=0.4.3",
Expand All @@ -22,12 +22,13 @@ build-backend = "hatchling.build"

[tool.rye]
managed = true
dev-dependencies = [
"pytest>=8.2.1",
"basedpyright>=1.12.4",
]
excluded-dependencies = [
]
dev-dependencies = ["pytest>=8.2.1", "basedpyright>=1.12.4"]
excluded-dependencies = []

[tool.rye.scripts]
ci = { chain = ["ci:verifytypes", "ci:basedpyright"] }
"ci:verifytypes" = "rye run basedpyright --verifytypes jaximal"
"ci:basedpyright" = "rye run basedpyright -p . ."

[tool.hatch.metadata]
allow-direct-references = true
Expand Down
4 changes: 2 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def forward(
optimizer = optax.chain(optax.adam(1e-1), optax.contrib.reduce_on_plateau())

mlp = MLP.init_state(3, 4, [16, 16], mlp_key)
opt_state: optax.OptState = optimizer.init(cast(optax.OptState, mlp))
opt_state: optax.OptState = optimizer.init(cast(optax.Params, mlp))

def loss(
mlp: MLP,
Expand All @@ -104,7 +104,7 @@ def update(
value=cost,
)

mlp: MLP = optax.apply_updates(mlp, updates) # type: ignore
mlp = cast(MLP, optax.apply_updates(cast(optax.Params, mlp), updates))

jax.debug.print('{} {}', i, cost)

Expand Down

0 comments on commit 50bdfe1

Please sign in to comment.