Commit 7bb461a: StableAdamW optimizer accumulation
warner-benjamin committed Mar 2, 2024
1 parent ad0eaeb

Showing 2 changed files with 23 additions and 2 deletions.
9 changes: 9 additions & 0 deletions optimi/stableadamw.py
@@ -170,6 +170,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None):
kahan_sum=group["kahan_sum"],
foreach=group["foreach"],
gradient_release=False,
update_parameters=self._update_params,
)
else:
state = self.state[param]
@@ -194,6 +195,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None):
kahan_sum=group["kahan_sum"],
foreach=False,
gradient_release=True,
update_parameters=self._update_params,
)

return loss
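
Both code paths in `step` above now receive the optimizer's internal `self._update_params` flag: the regular (optionally foreach) path and the per-parameter gradient-release path. A minimal sketch of how that flag might be driven from a plain training loop with gradient accumulation; the public `optimizer_accumulation` attribute used here to flip `_update_params` is an assumption for illustration, not something shown in this diff:

```python
import torch
from torch import nn
from optimi import StableAdamW

model = nn.Linear(16, 4)
optimizer = StableAdamW(model.parameters(), lr=1e-3)

micro_batches = torch.randn(4, 8, 16)
for i, x in enumerate(micro_batches):
    # Assumed toggle: accumulate gradients into the optimizer states for all but
    # the last micro-batch, then apply the real parameter update.
    optimizer.optimizer_accumulation = i < len(micro_batches) - 1

    loss = model(x).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```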
@@ -218,6 +220,7 @@ def stableadamw(
kahan_sum: bool = False,
foreach: bool = False,
gradient_release: bool = False,
update_parameters: bool = True,
):
"""Functional API to apply a StableAdamW optimization step.
@@ -241,6 +244,7 @@ def stableadamw(
kahan_sum: Enables Kahan summation for low precision parameters
foreach: Enables the faster foreach implementation
gradient_release: Fuses optimizer step as part of the parameter's backward pass
update_parameters: Accumulate gradients into optimizer states during step when False
"""
# calculate debiased beta hat & complement terms
step.add_(1)
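
The docstring above pins down the new flag's semantics: with `update_parameters=False`, the step folds the incoming gradient into the optimizer's moving averages but leaves the parameter itself untouched. A simplified, self-contained sketch of that behavior for a single tensor, using a plain AdamW-style update and omitting StableAdamW's RMS-based learning-rate clipping and Kahan summation; this illustrates the idea and is not optimi's actual kernel:

```python
import torch

def toy_step(param, grad, exp_avg, exp_avg_sq, step, lr=1e-3, beta1=0.9,
             beta2=0.99, eps=1e-6, weight_decay=0.0, update_parameters=True):
    # Gradient statistics are always accumulated into the moving averages.
    step += 1
    exp_avg.lerp_(grad, weight=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    if not update_parameters:
        # Optimizer accumulation: the micro-batch gradient has been absorbed
        # into the optimizer states, but the parameter is left unchanged.
        return step

    # Debiased AdamW-style update (simplified relative to StableAdamW).
    denom = (exp_avg_sq / (1 - beta2**step)).sqrt().add_(eps)
    param.mul_(1 - lr * weight_decay)
    param.addcdiv_(exp_avg / (1 - beta1**step), denom, value=-lr)
    return step
```

During accumulation, this toy step would be called with `update_parameters=False` for every micro-batch except the last.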
@@ -272,6 +276,7 @@ def stableadamw(
decouple_lr=decouple_lr,
max_lr=max_lr,
kahan_sum=kahan_sum,
update_parameters=update_parameters,
)


@@ -291,6 +296,7 @@ def _single_stableadamw(
decouple_lr: bool,
max_lr: float | None = None,
kahan_sum: bool = False,
update_parameters: bool = True,
):
for i, param in enumerate(params):
grad = grads[i]
@@ -314,6 +320,7 @@ def _single_stableadamw(
decouple_lr=decouple_lr,
max_lr=max_lr,
kahan_sum=kahan_sum,
update_parameters=update_parameters,
)


@@ -333,6 +340,7 @@ def _single_param_stableadamw(
decouple_lr: bool,
max_lr: float | None = None,
kahan_sum: bool = False,
update_parameters: bool = True,
):
# update gradient moving averages with debiased betas
exp_avg.lerp_(grad, weight=beta1_comp)
@@ -383,6 +391,7 @@ def _foreach_stableadamw(
decouple_lr: bool,
max_lr: float | None = None,
kahan_sum: bool = False,
update_parameters: bool = True,
):
grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, eps_sqs, kahan_comps])
for (_, dtype), ((dev_params, dev_grads, dev_exp_avgs, dev_exp_avg_sqs, dev_eps_sqs, dev_kahan_comps), _) in grouped_tensors.items():
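
Taken with the gradient-release branch of `step`, the same flag also covers the case where each parameter is updated from its backward hook. A hedged sketch of that combination; `gradient_release=True`, `prepare_for_gradient_release`, and the `optimizer_accumulation` toggle follow optimi's documented gradient-release workflow, but none of them appear in this diff, so treat the exact names as assumptions:

```python
import torch
from torch import nn
from optimi import StableAdamW, prepare_for_gradient_release

model = nn.Linear(16, 4)
# Fuse the optimizer step into each parameter's backward pass.
opt = StableAdamW(model.parameters(), lr=1e-3, gradient_release=True)
prepare_for_gradient_release(model, opt)

micro_batches = torch.randn(4, 8, 16)
for i, x in enumerate(micro_batches):
    # While True, the per-parameter steps fired during backward only accumulate
    # gradients into the optimizer states; the final micro-batch updates weights.
    opt.optimizer_accumulation = i < len(micro_batches) - 1
    model(x).sum().backward()
    opt.step()       # kept for framework compatibility; the real work happens
    opt.zero_grad()  # during backward when gradient release is enabled
```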
16 changes: 14 additions & 2 deletions tests/stableadam_test.py
@@ -8,7 +8,7 @@

from tests.optimizer_test import (buffer, run_optimizer, gradient_release, cpu_dim1, cpu_dim2, cpu_gtype,
cpu_ftype, cuda_dim1, cuda_dim2, cuda_gtype, cuda_ftype, gr_dim1,
gr_dim2, gr_dtype, gr_ftype)
gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation)



@@ -31,6 +31,7 @@
cpu_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cpu_values]

@pytest.mark.cpu
@pytest.mark.stableadam
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cpu_values, ids=cpu_names)
def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cpu'), buffer)
@@ -41,16 +42,27 @@ def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ft
cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

@pytest.mark.cuda
@pytest.mark.stableadam
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'), buffer, iterations=80)



cuda_values = list(product(gr_dim1, gr_dim2, gr_dtype, optimizer_names, gr_ftype))
cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

@pytest.mark.cuda
@pytest.mark.stableadam
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
def test_gradient_release(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
gradient_release(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
iterations=80, framework_opt_step=torch.rand(1).item() > 0.5)
framework_opt_step=torch.rand(1).item() > 0.5)


@pytest.mark.cuda
@pytest.mark.stableadam
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
def test_optimizer_accumulation(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
optimizer_accumulation(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
framework_opt_step=torch.rand(1).item() > 0.5)
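
The new `test_optimizer_accumulation` case delegates to an `optimizer_accumulation` helper imported from `tests/optimizer_test.py`, which is not part of this diff. As a rough guess at the kind of property such a helper could verify, with the `optimizer_accumulation` attribute again assumed rather than shown here:

```python
import torch
from torch import nn
from optimi import StableAdamW

def check_optimizer_accumulation(dim1=32, dim2=64, device="cpu", micro_batches=4):
    torch.manual_seed(42)
    model = nn.Linear(dim1, dim2, device=device)
    opt = StableAdamW(model.parameters(), lr=1e-3)
    start = [p.detach().clone() for p in model.parameters()]

    inputs = torch.randn(micro_batches, 8, dim1, device=device)
    for i, x in enumerate(inputs):
        # Assumed public toggle for the optimizer's internal _update_params flag.
        opt.optimizer_accumulation = i < micro_batches - 1
        model(x).sum().backward()
        opt.step()
        opt.zero_grad()

        changed = any(not torch.equal(p, s) for p, s in zip(model.parameters(), start))
        if i < micro_batches - 1:
            assert not changed, "parameters should be frozen while accumulating"
        else:
            assert changed, "the final step should apply the accumulated update"
```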
