Commit
Resolve conflict with PyTorch >= 2.2
HuFY-dev authored Mar 29, 2024
1 parent b6ba6cb commit 48ec5fe
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions sparse_autoencoder/optimizer/adam_with_reset.py
@@ -2,17 +2,22 @@
This reset method is useful when resampling dead neurons during training.
"""
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
+from typing import Any
+from typing_extensions import TypeAlias

from jaxtyping import Float, Int
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.optim import Adam
-from torch.optim.optimizer import params_t

from sparse_autoencoder.tensor_types import Axis


+# params_t was renamed to ParamsT in PyTorch 2.2, which caused import errors
+# Copied from PyTorch 2.2 with modifications for better style
+ParamsT: TypeAlias = Iterable[Tensor] | Iterable[dict[str, Any]]
+
class AdamWithReset(Adam):
    """Adam Optimizer with a reset method.

Check failure on line 17 in sparse_autoencoder/optimizer/adam_with_reset.py (GitHub Actions / Checks (3.11)): Ruff (I001) sparse_autoencoder/optimizer/adam_with_reset.py:5:1: Import block is un-sorted or un-formatted
@@ -35,7 +40,7 @@ class AdamWithReset(Adam):

    def __init__(  # (extending existing implementation)
        self,
-        params: params_t,
+        params: ParamsT,
        lr: float | Float[Tensor, Axis.names(Axis.SINGLE_ITEM)] = 1e-3,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
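
Note: the commit vendors the alias rather than importing it conditionally. A version-agnostic alternative, sketched below under the assumption stated in the commit comment (params_t exists before PyTorch 2.2, ParamsT from 2.2 onward), would try the new name first and fall back to the old one:

    # Sketch only, not what this commit does: import whichever name the
    # installed PyTorch provides and normalize it to ParamsT.
    try:
        from torch.optim.optimizer import ParamsT  # PyTorch >= 2.2
    except ImportError:
        from torch.optim.optimizer import params_t as ParamsT  # PyTorch < 2.2

The vendored copy avoids the try/except at the cost of drifting from upstream if the alias changes again.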
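The alias itself mirrors what torch.optim.Adam already accepts: either a plain iterable of parameter tensors or an iterable of param-group dicts. A minimal sketch with stock Adam (the Linear module is a stand-in; AdamWithReset may require arguments not visible in this diff):

    import torch
    from torch.optim import Adam

    model = torch.nn.Linear(4, 4)  # stand-in module for illustration

    # Iterable[Tensor]: the generator returned by Module.parameters()
    opt_plain = Adam(model.parameters(), lr=1e-3)

    # Iterable[dict[str, Any]]: param-group dicts with per-group options
    opt_groups = Adam([{"params": model.parameters(), "lr": 1e-4}], lr=1e-3)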

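On the CI annotation: Ruff's I001 rule flags the import block as unsorted, most likely because the new typing_extensions import sits in the standard-library group while Ruff's isort rules classify typing_extensions as third-party by default. One ordering that would satisfy the default grouping (a sketch; the exact grouping depends on the project's Ruff configuration):

    # Standard library
    from collections.abc import Iterable, Iterator
    from typing import Any

    # Third-party (typing_extensions is a PyPI package, not stdlib)
    from jaxtyping import Float, Int
    from torch import Tensor
    from torch.nn.parameter import Parameter
    from torch.optim import Adam
    from typing_extensions import TypeAlias

    # Local
    from sparse_autoencoder.tensor_types import Axis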