diff --git a/sparse_autoencoder/optimizer/adam_with_reset.py b/sparse_autoencoder/optimizer/adam_with_reset.py index dd1921c7..2dee9b15 100644 --- a/sparse_autoencoder/optimizer/adam_with_reset.py +++ b/sparse_autoencoder/optimizer/adam_with_reset.py @@ -4,15 +4,16 @@ """ from collections.abc import Iterable, Iterator from typing import Any -from typing_extensions import TypeAlias from jaxtyping import Float, Int from torch import Tensor from torch.nn.parameter import Parameter from torch.optim import Adam +from typing_extensions import TypeAlias from sparse_autoencoder.tensor_types import Axis + # params_t was renamed to ParamsT in PyTorch 2.2, which caused import errors # Copied from PyTorch 2.2 with modifications for better style ParamsT: TypeAlias = Iterable[Tensor] | Iterable[dict[str, Any]]