SGD simplification & optimizer accumulation
warner-benjamin committed Mar 2, 2024
1 parent 7bb461a commit fc5c88c
Showing 2 changed files with 39 additions and 39 deletions.
optimi/sgd.py (60 changes: 23 additions & 37 deletions)
@@ -159,6 +159,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None):
                     kahan_sum=group["kahan_sum"],
                     foreach=group["foreach"],
                     gradient_release=False,
+                    update_parameters=self._update_params,
                 )
         else:
             state = self.state[param]
@@ -180,6 +181,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None):
                 kahan_sum=group["kahan_sum"],
                 foreach=False,
                 gradient_release=True,
+                update_parameters=self._update_params,
             )

         return loss
@@ -201,6 +203,7 @@ def sgd(
     kahan_sum: bool = False,
     foreach: bool = False,
     gradient_release: bool = False,
+    update_parameters: bool = True,
 ):
     """Functional API to apply a SGD or SGDW optimization step.

@@ -221,6 +224,7 @@ def sgd(
         kahan_sum: Enables Kahan summation for low precision `params`
         foreach: Enables the faster foreach implementation
         gradient_release: Fuses optimizer step as part of the parameter's backward pass
+        update_parameters: Accumulate gradients into optimizer states during step when False
     """
     # calculate decoupled weight decay or fully decoupled weight decay
     if weight_decay != 0:
@@ -250,6 +254,7 @@ def sgd(
         dampening=dampening,
         decouple_wd=(decouple_wd or decouple_lr),
         kahan_sum=kahan_sum,
+        update_parameters=update_parameters,
     )


@@ -265,6 +270,7 @@ def _single_sgd(
     dampening: bool,
     decouple_wd: bool,
     kahan_sum: bool = False,
+    update_parameters: bool = True,
 ):
     for i, param in enumerate(params):
         grad = grads[i]
@@ -282,6 +288,7 @@ def _single_sgd(
             dampening=dampening,
             decouple_wd=decouple_wd,
             kahan_sum=kahan_sum,
+            update_parameters=update_parameters,
         )


@@ -297,23 +304,27 @@ def _single_param_sgd(
     dampening: bool,
     decouple_wd: bool,
     kahan_sum: bool = False,
+    update_parameters: bool = True,
 ):
     # decoupled weight decay, fully decoupled weight decay, or L2 weight decay
-    if weight_decay != 0:
+    if weight_decay != 0 and update_parameters:
         if decouple_wd:
             param.mul_(weight_decay)
         else:
             grad.add_(param, alpha=weight_decay)

     if momentum != 0:
-        # SGD Momentum
+        # SGD with Momentum
         if dampening:
             exp_avg.lerp_(grad, weight=1 - momentum)
         else:
             exp_avg.mul_(momentum).add_(grad)
+    else:
+        exp_avg = grad

+    if update_parameters:
         if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]:
-            # SGD with Momentum step
+            # SGD step (regular step exp_avg = grad)
             kahan_comp.add_(exp_avg, alpha=-lr)

             # update weights with kahan compensation using grad as temp buffer
@@ -323,22 +334,8 @@ def _single_param_sgd(
             # save error back to kahan compensation for next iteration
             kahan_comp.add_(grad.sub_(param))
         else:
-            # SGD with Momentum step
+            # SGD step (regular step exp_avg = grad)
             param.add_(exp_avg, alpha=-lr)
-    else:
-        if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]:
-            # SGD step
-            kahan_comp.add_(grad, alpha=-lr)
-
-            # update weights with kahan compensation using grad as temp buffer
-            grad.copy_(param.detach())
-            param.add_(kahan_comp)
-
-            # save error back to kahan compensation for next iteration
-            kahan_comp.add_(grad.sub_(param))
-        else:
-            # SGD step
-            param.add_(grad, alpha=-lr)


 def _foreach_sgd(
@@ -353,26 +350,30 @@ def _foreach_sgd(
     dampening: bool,
     decouple_wd: bool,
     kahan_sum: bool = False,
+    update_parameters: bool = True,
 ):
     grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, kahan_comps])
     for (_, dtype), ((dev_params, dev_grads, dev_exp_avgs, dev_kahan_comps), _) in grouped_tensors.items():
         # decoupled weight decay, fully decoupled weight decay, or L2 weight decay
-        if weight_decay != 0:
+        if weight_decay != 0 and update_parameters:
             if decouple_wd:
                 torch._foreach_mul_(dev_params, scalar=weight_decay)
             else:
                 torch._foreach_add_(dev_grads, dev_params, alpha=weight_decay)

         if momentum != 0:
-            # SGD Momentum
+            # SGD with Momentum
             if dampening:
                 torch._foreach_lerp_(dev_exp_avgs, dev_grads, weight=1 - momentum)
             else:
                 torch._foreach_mul_(dev_exp_avgs, scalar=momentum)
                 torch._foreach_add_(dev_exp_avgs, dev_grads, alpha=1)
+        else:
+            dev_exp_avgs = dev_grads

+        if update_parameters:
             if kahan_sum and dtype in [torch.float16, torch.bfloat16]:
-                # SGD with Momentum step
+                # SGD step (regular step exp_avg = grad)
                 torch._foreach_add_(dev_kahan_comps, dev_exp_avgs, alpha=-lr)

                 # update weights with kahan compensation using dev_grads as temp buffer
@@ -383,20 +384,5 @@ def _foreach_sgd(
                 torch._foreach_sub_(dev_grads, dev_params, alpha=1)
                 torch._foreach_add_(dev_kahan_comps, dev_grads, alpha=1)
             else:
-                # SGD with Momentum step
+                # SGD step (regular step exp_avg = grad)
                 torch._foreach_add_(dev_params, dev_exp_avgs, alpha=-lr)
-        else:
-            if kahan_sum and dtype in [torch.float16, torch.bfloat16]:
-                # SGD step
-                torch._foreach_add_(dev_kahan_comps, dev_grads, alpha=-lr)
-
-                # update weights with kahan compensation using dev_grads as temp buffer
-                torch._foreach_copy_(dev_grads, dev_params)
-                torch._foreach_add_(dev_params, dev_kahan_comps, alpha=1)
-
-                # save error back to kahan compensation for next iteration
-                torch._foreach_sub_(dev_grads, dev_params, alpha=1)
-                torch._foreach_add_(dev_kahan_comps, dev_grads, alpha=1)
-            else:
-                # SGD step
-                torch._foreach_add_(dev_params, dev_grads, alpha=-lr)
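
The diff above both simplifies the step and adds optimizer accumulation. The new "else: exp_avg = grad" branch lets the momentum-free path reuse the same update code as the momentum path, which is what removes the duplicated Kahan-summation blocks, while the new update_parameters flag gates weight decay and the parameter update, so a step with update_parameters=False only folds the incoming gradient into the momentum buffer. The following is a minimal standalone sketch of that logic; the name sgd_step_sketch and its signature are illustrative only (not optimi's API), and the Kahan-summation and foreach paths are omitted for brevity.

import torch


def sgd_step_sketch(param, grad, exp_avg, *, lr, momentum, dampening=False,
                    weight_decay=0.0, update_parameters=True):
    # L2 weight decay is skipped while only accumulating into optimizer state
    if weight_decay != 0 and update_parameters:
        grad = grad.add(param, alpha=weight_decay)

    if momentum != 0:
        # SGD with momentum: fold the incoming gradient into the buffer
        if dampening:
            exp_avg.lerp_(grad, weight=1 - momentum)
        else:
            exp_avg.mul_(momentum).add_(grad)
    else:
        # no momentum: the "buffer" is just the current gradient
        exp_avg = grad

    # only move the parameters when asked to; otherwise this call has merely
    # accumulated the gradient into exp_avg
    if update_parameters:
        param.add_(exp_avg, alpha=-lr)


# accumulate three micro-batch gradients, updating the weights only on the last one
param, exp_avg = torch.zeros(4), torch.zeros(4)
micro_grads = [torch.ones(4), torch.ones(4), torch.ones(4)]
for i, g in enumerate(micro_grads):
    sgd_step_sketch(param, g, exp_avg, lr=0.1, momentum=0.9,
                    update_parameters=(i == len(micro_grads) - 1))
print(param)  # parameters changed once, after the final micro-batch
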
tests/sgd_test.py (18 changes: 16 additions & 2 deletions)
@@ -8,7 +8,7 @@

 from tests.optimizer_test import (buffer, run_optimizer, gradient_release, cpu_dim1, cpu_dim2, cpu_gtype,
                                   cpu_ftype, cuda_dim1, cuda_dim2, cuda_gtype, cuda_ftype, gr_dim1,
-                                  gr_dim2, gr_dtype, gr_ftype)
+                                  gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation)



@@ -37,6 +37,7 @@
 cpu_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cpu_values]

 @pytest.mark.cpu
+@pytest.mark.sgd
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cpu_values, ids=cpu_names)
 def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cpu'), buffer)
@@ -47,6 +48,7 @@ def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

 @pytest.mark.cuda
+@pytest.mark.sgd
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'), buffer, iterations=80)
@@ -57,7 +59,19 @@ def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

 @pytest.mark.cuda
+@pytest.mark.sgd
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_gradient_release(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     gradient_release(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
-                     iterations=80, framework_opt_step=torch.rand(1).item() > 0.5)
+                     framework_opt_step=torch.rand(1).item() > 0.5)
+
+
+@pytest.mark.cuda
+@pytest.mark.sgd
+@pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
+def test_optimizer_accumulation(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
+    if optim_name in ["sgd", "sgd_l2"]:
+        pytest.skip("Skip tests for SGD and SGD with L2 weight decay.")
+    # SGD will error out more often if iterations is the default of 80
+    optimizer_accumulation(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
+                           iterations=20, framework_opt_step=torch.rand(1).item() > 0.5)
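
The new test_optimizer_accumulation relies on the optimizer_accumulation helper from tests/optimizer_test.py, which is not part of this diff. As a rough, self-contained illustration of what such a check compares (the reference construction below is an assumption, not the actual helper): folding each micro-batch gradient into the momentum buffer as it arrives approximates, but does not exactly match, summing the gradients first and taking a single SGD-with-momentum step, because earlier micro-batch gradients are decayed by the extra momentum multiplications. This may also be why the momentum-free configurations ("sgd", "sgd_l2") are skipped: with momentum=0 the step reduces to exp_avg = grad, leaving no state to accumulate into.

import torch

torch.manual_seed(42)
lr, momentum, micro_batches = 0.1, 0.9, 4
grads = [torch.randn(8) for _ in range(micro_batches)]

# reference: standard gradient accumulation -- sum the micro-batch gradients,
# then take one SGD-with-momentum step
p_ref, m_ref = torch.zeros(8), torch.zeros(8)
m_ref.mul_(momentum).add_(sum(grads))
p_ref.add_(m_ref, alpha=-lr)

# optimizer accumulation: fold each gradient into the momentum buffer as it
# arrives (update_parameters=False), update parameters on the last micro-batch
p_acc, m_acc = torch.zeros(8), torch.zeros(8)
for i, g in enumerate(grads):
    m_acc.mul_(momentum).add_(g)
    if i == len(grads) - 1:
        p_acc.add_(m_acc, alpha=-lr)

# close but not identical: earlier gradients were decayed by momentum
print((p_ref - p_acc).abs().max())
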
