From fc5c88c585820dcc44d22c7f1545f8438d1d9f34 Mon Sep 17 00:00:00 2001
From: Benjamin Warner
Date: Sat, 2 Mar 2024 14:03:51 -0600
Subject: [PATCH] SGD simplification & optimizer accumulation

---
 optimi/sgd.py     | 60 +++++++++++++++++++++++-------------------------------------
 tests/sgd_test.py | 18 ++++++++++++++++--
 2 files changed, 39 insertions(+), 39 deletions(-)

diff --git a/optimi/sgd.py b/optimi/sgd.py
index c9810eb..32c8951 100644
--- a/optimi/sgd.py
+++ b/optimi/sgd.py
@@ -159,6 +159,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None):
                     kahan_sum=group["kahan_sum"],
                     foreach=group["foreach"],
                     gradient_release=False,
+                    update_parameters=self._update_params,
                 )
         else:
             state = self.state[param]
@@ -180,6 +181,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None):
                 kahan_sum=group["kahan_sum"],
                 foreach=False,
                 gradient_release=True,
+                update_parameters=self._update_params,
             )
 
         return loss
@@ -201,6 +203,7 @@ def sgd(
     kahan_sum: bool = False,
     foreach: bool = False,
     gradient_release: bool = False,
+    update_parameters: bool = True,
 ):
     """Functional API to apply a SGD or SGDW optimization step.
 
@@ -221,6 +224,7 @@ def sgd(
         kahan_sum: Enables Kahan summation for low precision `params`
         foreach: Enables the faster foreach implementation
         gradient_release: Fuses optimizer step as part of the parameter's backward pass
+        update_parameters: Accumulate gradients into optimizer states during step when False
     """
     # calculate decoupled weight decay or fully decoupled weight decay
     if weight_decay != 0:
@@ -250,6 +254,7 @@ def sgd(
         dampening=dampening,
         decouple_wd=(decouple_wd or decouple_lr),
         kahan_sum=kahan_sum,
+        update_parameters=update_parameters,
     )
 
 
@@ -265,6 +270,7 @@ def _single_sgd(
     dampening: bool,
     decouple_wd: bool,
     kahan_sum: bool = False,
+    update_parameters: bool = True,
 ):
     for i, param in enumerate(params):
         grad = grads[i]
@@ -282,6 +288,7 @@ def _single_sgd(
             dampening=dampening,
             decouple_wd=decouple_wd,
             kahan_sum=kahan_sum,
+            update_parameters=update_parameters,
         )
 
 
@@ -297,23 +304,27 @@ def _single_param_sgd(
     dampening: bool,
     decouple_wd: bool,
     kahan_sum: bool = False,
+    update_parameters: bool = True,
 ):
     # decoupled weight decay, fully decoupled weight decay, or L2 weight decay
-    if weight_decay != 0:
+    if weight_decay != 0 and update_parameters:
         if decouple_wd:
             param.mul_(weight_decay)
         else:
             grad.add_(param, alpha=weight_decay)
 
     if momentum != 0:
-        # SGD Momentum
+        # SGD with Momentum
         if dampening:
             exp_avg.lerp_(grad, weight=1 - momentum)
         else:
             exp_avg.mul_(momentum).add_(grad)
+    else:
+        exp_avg = grad
 
+    if update_parameters:
         if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]:
-            # SGD with Momentum step
+            # SGD step (regular step exp_avg = grad)
             kahan_comp.add_(exp_avg, alpha=-lr)
 
             # update weights with kahan compensation using grad as temp buffer
@@ -323,22 +334,8 @@ def _single_param_sgd(
             # save error back to kahan compensation for next iteration
             kahan_comp.add_(grad.sub_(param))
         else:
-            # SGD with Momentum step
+            # SGD step (regular step exp_avg = grad)
             param.add_(exp_avg, alpha=-lr)
-    else:
-        if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]:
-            # SGD step
-            kahan_comp.add_(grad, alpha=-lr)
-
-            # update weights with kahan compensation using grad as temp buffer
-            grad.copy_(param.detach())
-            param.add_(kahan_comp)
-
-            # save error back to kahan compensation for next iteration
-            kahan_comp.add_(grad.sub_(param))
-        else:
-            # SGD step
-            param.add_(grad, alpha=-lr)
 
 
 def _foreach_sgd(
@@ -353,26 +350,30 @@ def _foreach_sgd(
     dampening: bool,
     decouple_wd: bool,
     kahan_sum: bool = False,
+    update_parameters: bool = True,
 ):
     grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, kahan_comps])
     for (_, dtype), ((dev_params, dev_grads, dev_exp_avgs, dev_kahan_comps), _) in grouped_tensors.items():
         # decoupled weight decay, fully decoupled weight decay, or L2 weight decay
-        if weight_decay != 0:
+        if weight_decay != 0 and update_parameters:
             if decouple_wd:
                 torch._foreach_mul_(dev_params, scalar=weight_decay)
             else:
                 torch._foreach_add_(dev_grads, dev_params, alpha=weight_decay)
 
         if momentum != 0:
-            # SGD Momentum
+            # SGD with Momentum
            if dampening:
                 torch._foreach_lerp_(dev_exp_avgs, dev_grads, weight=1 - momentum)
             else:
                 torch._foreach_mul_(dev_exp_avgs, scalar=momentum)
                 torch._foreach_add_(dev_exp_avgs, dev_grads, alpha=1)
+        else:
+            dev_exp_avgs = dev_grads
 
+        if update_parameters:
             if kahan_sum and dtype in [torch.float16, torch.bfloat16]:
-                # SGD with Momentum step
+                # SGD step (regular step exp_avg = grad)
                 torch._foreach_add_(dev_kahan_comps, dev_exp_avgs, alpha=-lr)
 
                 # update weights with kahan compensation using dev_grads as temp buffer
@@ -383,20 +384,5 @@ def _foreach_sgd(
                 torch._foreach_sub_(dev_grads, dev_params, alpha=1)
                 torch._foreach_add_(dev_kahan_comps, dev_grads, alpha=1)
             else:
-                # SGD with Momentum step
+                # SGD step (regular step exp_avg = grad)
                 torch._foreach_add_(dev_params, dev_exp_avgs, alpha=-lr)
-        else:
-            if kahan_sum and dtype in [torch.float16, torch.bfloat16]:
-                # SGD step
-                torch._foreach_add_(dev_kahan_comps, dev_grads, alpha=-lr)
-
-                # update weights with kahan compensation using dev_grads as temp buffer
-                torch._foreach_copy_(dev_grads, dev_params)
-                torch._foreach_add_(dev_params, dev_kahan_comps, alpha=1)
-
-                # save error back to kahan compensation for next iteration
-                torch._foreach_sub_(dev_grads, dev_params, alpha=1)
-                torch._foreach_add_(dev_kahan_comps, dev_grads, alpha=1)
-            else:
-                # SGD step
-                torch._foreach_add_(dev_params, dev_grads, alpha=-lr)
diff --git a/tests/sgd_test.py b/tests/sgd_test.py
index ece1e52..2500486 100644
--- a/tests/sgd_test.py
+++ b/tests/sgd_test.py
@@ -8,7 +8,7 @@
 
 from tests.optimizer_test import (buffer, run_optimizer, gradient_release, cpu_dim1, cpu_dim2,
                                   cpu_gtype, cpu_ftype, cuda_dim1, cuda_dim2, cuda_gtype, cuda_ftype, gr_dim1,
-                                  gr_dim2, gr_dtype, gr_ftype)
+                                  gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation)
 
 
 
@@ -37,6 +37,7 @@
 cpu_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cpu_values]
 
 @pytest.mark.cpu
+@pytest.mark.sgd
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cpu_values, ids=cpu_names)
 def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cpu'), buffer)
@@ -47,6 +48,7 @@ def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ft
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]
 
 @pytest.mark.cuda
+@pytest.mark.sgd
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'), buffer, iterations=80)
@@ -57,7 +59,19 @@ def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, f
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]
 
 @pytest.mark.cuda
+@pytest.mark.sgd
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_gradient_release(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     gradient_release(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
-                     iterations=80, framework_opt_step=torch.rand(1).item() > 0.5)
\ No newline at end of file
+                     framework_opt_step=torch.rand(1).item() > 0.5)
+
+
+@pytest.mark.cuda
+@pytest.mark.sgd
+@pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
+def test_optimizer_accumulation(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
+    if optim_name in ["sgd", "sgd_l2"]:
+        pytest.skip("Skip tests for SGD and SGD with L2 weight decay.")
+    # SGD will error out more often if iterations is the default of 80
+    optimizer_accumulation(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
+                           iterations=20, framework_opt_step=torch.rand(1).item() > 0.5)
\ No newline at end of file
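
A minimal usage sketch of the optimizer accumulation this patch wires into SGD, for context. It assumes optimi's prepare_for_gradient_release helper and an optimizer_accumulation property as the public switch behind self._update_params; those names and the toy training loop are assumptions for illustration, not a confirmed interface from this diff.

# Usage sketch: driving optimizer accumulation from a training loop.
# prepare_for_gradient_release and the optimizer_accumulation property are assumed
# from optimi's gradient release API; adjust to the actual library surface.
import torch
from torch import nn
from optimi import SGD, prepare_for_gradient_release

model = nn.Linear(16, 4)

# gradient release fuses the optimizer step into each parameter's backward pass;
# momentum is needed so accumulated gradients have a buffer (exp_avg) to land in,
# which is why the test above skips plain "sgd" and "sgd_l2"
opt = SGD(model.parameters(), lr=1e-3, momentum=0.9, gradient_release=True)
prepare_for_gradient_release(model, opt)

accumulation_steps = 4
for step in range(16):
    # while True, sgd() receives update_parameters=False and only accumulates the
    # gradient into the momentum buffer; on the final micro-batch weights update
    opt.optimizer_accumulation = (step + 1) % accumulation_steps != 0

    x, y = torch.randn(8, 16), torch.randn(8, 4)
    nn.functional.mse_loss(model(x), y).backward()

    # the tests randomize framework_opt_step, so leaving a framework's
    # step/zero_grad calls in place is expected to be safe with gradient release
    opt.step()
    opt.zero_grad()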