diff --git a/README.md b/README.md index c426434..be0ef27 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,25 @@ # optimī -### Fast, Modern, and Low Precision PyTorch Optimizers +### Fast, Modern, Memory Efficient, and Low Precision PyTorch Optimizers -optimi enables accurate low precision training via Kahan summation, supports fully decoupled weight decay, and features fast implementations of modern optimizers. +optimi enables accurate low precision training via Kahan summation, integrates gradient release and optimizer accumulation for additional memory efficiency, supports fully decoupled weight decay, and features fast implementations of modern optimizers. ## Low Precision Training with Kahan Summation -optimi optimizers can match the performance of mixed precision when [training in BFloat16 by using Kahan summation](https://optimi.benjaminwarner.dev/kahan_summation). +optimi optimizers can nearly reach or match the performance of mixed precision when [training in BFloat16 by using Kahan summation](https://optimi.benjaminwarner.dev/kahan_summation). Training in BFloat16 with Kahan summation can reduce non-activation training memory usage by [37.5 to 45.5 percent](https://optimi.benjaminwarner.dev/kahan_summation/#memory-savings) when using an Adam optimizer. BFloat16 training increases single GPU [training speed by ~10 percent](https://optimi.benjaminwarner.dev/kahan_summation/#training-speedup) at the same batch size. +## Gradient Release: Fused Backward and Optimizer Step + +optimi optimizers can perform the [optimization step layer-by-layer during the backward pass](https://optimi.benjaminwarner.dev/gradient_release), immediately freeing gradient memory. + +Unlike the current PyTorch implementation, optimi’s gradient release optimizers are a drop-in replacement for standard optimizers and seamlessly work with existing hyperparameter schedulers. + +## Optimizer Accumulation: Gradient Release and Accumulation + +optimi optimizers can approximate gradient accumulation with gradient release by [accumulating gradients into the optimizer states](https://optimi.benjaminwarner.dev/optimizer_accumulation). + ## Fully Decoupled Weight Decay In addition to supporting PyTorch-style decoupled weight decay, optimi optimizers also support [fully decoupled weight decay](https://optimi.benjaminwarner.dev/fully_decoupled_weight_decay). 
@@ -44,7 +54,7 @@ from optimi import AdamW # create or cast model in low precision (bfloat16) model = nn.Linear(20, 1, dtype=torch.bfloat16) -# instantiate AdamW with parameters and fully decoupled weight decay +# initialize any optimi optimizer with parameters & fully decoupled weight decay # Kahan summation is automatically enabled since model & inputs are bfloat16 opt = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5, decouple_lr=True) @@ -63,10 +73,64 @@ To use with PyTorch-style weight decay with float32 or mixed precision: # create model model = nn.Linear(20, 1) -# initialize AdamW with parameters +# initialize any optimi optimizer with parameters opt = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2) ``` +To use with gradient release: + +```python +# initialize any optimi optimizer with `gradient_release=True` +# and call `prepare_for_gradient_release` on model and optimizer +opt = AdamW(model.parameters(), lr=1e-3, gradient_release=True) +prepare_for_gradient_release(model, opt) + +# calling backward on the model will perform the optimizer step +loss = model(torch.randn(20, dtype=torch.bfloat16)) +loss.backward() + +# optimizer step and zero_grad are no longer needed, and will +# harmlessly no-op if called by an existing training framework +# opt.step() +# opt.zero_grad() + +# optionally remove gradient release hooks when done training +remove_gradient_release(model) +``` + +To use with optimizer accumulation: + +```python +# initialize any optimi optimizer with `gradient_release=True` +# and call `prepare_for_gradient_release` on model and optimizer +opt = AdamW(model.parameters(), lr=1e-3, gradient_release=True) +prepare_for_gradient_release(model, opt) + +# update model parameters every four steps after accumulating +# gradients directly into the optimizer states +accumulation_steps = 4 + +# use existing PyTorch dataloader +for idx, batch in enumerate(dataloader): + # `optimizer_accumulation=True` accumulates gradients into + # optimizer states. set `optimizer_accumulation=False` to + # update parameters by performing a full gradient release step + opt.optimizer_accumulation = (idx+1) % accumulation_steps != 0 + + # calling backward on the model will perform the optimizer step + # either accumulating gradients or updating model parameters + loss = model(batch) + loss.backward() + + # optimizer step and zero_grad are no longer needed, and will + # harmlessly no-op if called by an existing training framework + # opt.step() + # opt.zero_grad() + +# optionally remove gradient release hooks when done training +remove_gradient_release(model) +``` + ## Differences from PyTorch optimi optimizers do not support compilation, differentiation, complex numbers, or have capturable versions. 
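The optimizer accumulation snippet above omits the surrounding setup. Below is a minimal, self-contained sketch of the same loop — it assumes `prepare_for_gradient_release` and `remove_gradient_release` are importable from the top-level `optimi` package and substitutes a toy `TensorDataset` and a placeholder loss for a real dataloader and training objective:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# assumption: both helpers are exported from the top-level optimi package
from optimi import AdamW, prepare_for_gradient_release, remove_gradient_release

# toy bfloat16 model and data standing in for a real training setup
model = nn.Linear(20, 1, dtype=torch.bfloat16)
dataloader = DataLoader(TensorDataset(torch.randn(64, 20, dtype=torch.bfloat16)), batch_size=8)

# initialize the optimizer with gradient release and register the backward hooks
opt = AdamW(model.parameters(), lr=1e-3, gradient_release=True)
prepare_for_gradient_release(model, opt)

# accumulate gradients into the optimizer states for three micro-batches,
# then perform a full gradient release step on every fourth backward pass
accumulation_steps = 4

for idx, (batch,) in enumerate(dataloader):
    opt.optimizer_accumulation = (idx + 1) % accumulation_steps != 0
    loss = model(batch).mean()  # placeholder loss: mean of the model output
    loss.backward()

# remove the gradient release hooks once training is finished
remove_gradient_release(model)
```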
diff --git a/docs/css/extra.css b/docs/css/extra.css index d5eb86e..2f1220b 100644 --- a/docs/css/extra.css +++ b/docs/css/extra.css @@ -46,8 +46,8 @@ --md-typeset-table-color: rgba(24, 24, 24, 0.05); /* Code highlighting color shades */ - /* --md-code-hl-color: #9a3fe4; - --md-code-hl-color--light: #9a3fe4; */ + --md-code-hl-color: #9a3fe4; + --md-code-hl-color--light: #9a3fe43c; --md-code-hl-number-color: #db5f00; --md-code-hl-special-color: #d32300; --md-code-hl-function-color: #cc9901; @@ -161,7 +161,7 @@ /* Links */ -.md-content a:not(.headerlink):not(.footnote-ref):not(.footnote-backref) { +.md-content a:not(.headerlink):not(.footnote-ref):not(.footnote-backref):not(:has(> code)) { box-shadow: inset 0 -0.115rem 0 var(--light-purple); text-decoration: none; transition: all .15s cubic-bezier(.33,.66,.66,1); @@ -171,6 +171,18 @@ color: var(--black) } } +.md-content a code { + box-shadow: inset 0 -0.115rem 0 var(--light-purple); + text-decoration: none; + transition: all .15s cubic-bezier(.33,.66,.66,1); + z-index: 10; + border-bottom-left-radius: 0; + border-bottom-right-radius: 0; + + &:hover { box-shadow: inset 0 -2rem 0 var(--dark-purple); + color: var(--black) } +} + /* Katex */ .katex-display { margin-top: 0 !important; diff --git a/docs/foreach.md b/docs/foreach.md index f93aa36..5956fd3 100644 --- a/docs/foreach.md +++ b/docs/foreach.md @@ -8,7 +8,7 @@ Like PyTorch, optimi supports foreach implementations of all optimizers. Foreach Foreach implementations can increase optimizer peak memory usage. optimi attempts to reduce this extra overhead by reusing the gradient buffer for temporary variables. If the gradients are required between the optimization step and [gradient reset step](https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html#torch.optim.Optimizer.zero_grad), set `foreach=False` to use the for-loop implementation. -??? warning "Important: Foreach Requires PyTorch 2.1+" +??? note "Note: Foreach Requires PyTorch 2.1+" optimi’s foreach implementations require PyTorch 2.1 or newer. diff --git a/docs/gradient_release.md b/docs/gradient_release.md index ed176bb..b02381c 100644 --- a/docs/gradient_release.md +++ b/docs/gradient_release.md @@ -1,16 +1,19 @@ --- -title: "Gradient Release: Fused Backward and Optimizer Step" +title: "Gradient Release" +description: "Fused Backward Pass and Optimizer Step" --- -# Gradient Release: Fused Backward and Optimizer Step +# Gradient Release -Gradient release reduces training memory by limiting gradients to one layer at any given time. Unlike [PyTorch’s implementation](https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html), optimi’s gradient release is fully compatible with existing learning rate and optimizer schedulers and training frameworks. +**Fused Backward Pass and Optimizer Step** + +Gradient release reduces training memory by limiting gradients to one layer at any given time. Unlike [PyTorch’s implementation](https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html), optimi’s gradient release is fully compatible with both existing learning rate and optimizer schedulers and existing training frameworks. During the backward pass, each model layer calculates its gradients, performs the optimizer step, and clears the gradients before proceeding to the backward pass for the next layer. This fused backward and optimizer step can reduce non-activation memory usage by ~25 percent for an Adam optimizer. 
-Gradient release can be combined with other techniques such as [Kahan summation](kahan_summation.md) or activation checkpointing for further memory savings. +Gradient release can also be combined with other techniques such as [Kahan summation](kahan_summation.md) or [activation checkpointing](https://pytorch.org/docs/stable/checkpoint.html) for further memory savings. -??? warning "Important: Gradient Release Requires PyTorch 2.1+" +??? note "Note: Gradient Release Requires PyTorch 2.1+" Gradient release requires PyTorch 2.1 or newer. @@ -18,10 +21,20 @@ Gradient release was proposed by Pudipeddi et al in [*Training Large Neural Netw ## Limitations and Workarounds -Since gradient release immediately frees the gradient during the backward pass, features which rely on persistent gradients like gradient clipping or gradient accumulation won’t work. +Since gradient release immediately frees the gradient during the backward pass, features which rely on persistent gradients like AMP's `GradScaler`, gradient clipping, or gradient accumulation won’t work. + +!!! warning "Important: Gradient Release is Incompatible with FP16 Mixed Precision" + + Gradient release is incompatible with Float16 Automatic Mixed Precision since PyTorch's `GradScaler` requires access to the entire model's gradients for the optimizer step. + + Use BFloat16 Automatic Mixed Precision instead. The recommended workaround for gradient clipping is to use [StableAdamW](optimizers/stableadamw.md) instead of Adam or AdamW, as StableAdamW removes the need for gradient clipping by porting Adafactor’s update clipping into AdamW. +??? tip "Tip: Use Optimizer Accumulation to Approximate Gradient Accumulation" + + optimi's [optimizer accumulation](optimizer_accumulation.md) approximates [gradient accumulation](https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation) by deferring parameter updates while accumulating gradients directly into the optimizer states. + One potential workaround for gradient accumulation is to increase the optimizer’s momentum or $\beta_1$ to approximate accumulating gradients across multiple batches. ## Example @@ -45,10 +58,10 @@ prepare_for_gradient_release(model, opt) loss = model(torch.randn(20, dtype=torch.bfloat16)) loss.backward() -# optimizer step and sero_grad is no longer needed, and -# will no-op if called by an existing training framework -opt.step() -opt.zero_grad() +# optimizer step and zero_grad are no longer needed, and will +# harmlessly no-op if called by an existing training framework +# opt.step() +# opt.zero_grad() # optionally remove gradient release hooks when done training remove_gradient_release(model) diff --git a/docs/index.md b/docs/index.md index 1d31255..29541c8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,17 +1,17 @@ --- title: "optimi" -description: "Fast, Modern, and Low Precision PyTorch Optimizers" +description: "Fast, Modern, Memory Efficient, and Low Precision PyTorch Optimizers" --- # optimī -**Fast, Modern, and Low Precision PyTorch Optimizers** +**Fast, Modern, Memory Efficient, and Low Precision PyTorch Optimizers** -optimi enables accurate low precision training via Kahan summation, supports fully decoupled weight decay, and features fast implementations of modern optimizers. +optimi enables accurate low precision training via Kahan summation, integrates gradient release and optimizer accumulation for additional memory efficiency, supports fully decoupled weight decay, and features fast implementations of modern optimizers. 
## Low Precision Training with Kahan Summation -optimi optimizers can match the performance of mixed precision when [training in BFloat16 by using Kahan summation](kahan_summation.md). +optimi optimizers can nearly reach or match the performance of mixed precision when [training in BFloat16 by using Kahan summation](kahan_summation.md). Training in BFloat16 with Kahan summation can reduce non-activation training memory usage by [37.5 to 45.5 percent](kahan_summation.md/#memory-savings) when using an Adam optimizer. BFloat16 training increases single GPU [training speed by ~10 percent](kahan_summation.md/#training-speedup) at the same batch size. @@ -21,6 +21,10 @@ optimi optimizers can perform the [optimization step layer-by-layer during the b Unlike the current PyTorch implementation, optimi’s gradient release optimizers are a drop-in replacement for standard optimizers and seamlessly work with existing hyperparameter schedulers. +## Optimizer Accumulation: Gradient Release and Accumulation + +optimi optimizers can approximate gradient accumulation with gradient release by [accumulating gradients into the optimizer states](optimizer_accumulation.md). + ## Fully Decoupled Weight Decay In addition to supporting PyTorch-style decoupled weight decay, optimi optimizers also support [fully decoupled weight decay](fully_decoupled_weight_decay.md). @@ -51,7 +55,7 @@ from optimi import AdamW # create or cast model in low precision (bfloat16) model = nn.Linear(20, 1, dtype=torch.bfloat16) -# initialize AdamW with parameters and fully decoupled weight decay +# initialize any optimi optimizer with parameters & fully decoupled weight decay # Kahan summation is automatically enabled since model & inputs are bfloat16 opt = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5, decouple_lr=True) @@ -70,18 +74,15 @@ To use with PyTorch-style weight decay with float32 or mixed precision: # create model model = nn.Linear(20, 1) -# initialize AdamW with parameters +# initialize any optimi optimizer with parameters opt = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2) ``` To use with gradient release: ```python -# create model -model = nn.Linear(20, 1) - -# initialize AdamW with `gradient_release=True` and call -# `prepare_for_gradient_release` on model and optimizer +# initialize any optimi optimizer with `gradient_release=True` +# and call `prepare_for_gradient_release` on model and optimizer opt = AdamW(model.parameters(), lr=1e-3, gradient_release=True) prepare_for_gradient_release(model, opt) @@ -89,10 +90,43 @@ prepare_for_gradient_release(model, opt) loss = model(torch.randn(20, dtype=torch.bfloat16)) loss.backward() -# optimizer step and sero_grad is no longer needed, and -# will no-op if called by an existing training framework -opt.step() -opt.zero_grad() +# optimizer step and zero_grad are no longer needed, and will +# harmlessly no-op if called by an existing training framework +# opt.step() +# opt.zero_grad() + +# optionally remove gradient release hooks when done training +remove_gradient_release(model) +``` + +To use with optimizer accumulation: + +```python +# initialize any optimi optimizer with `gradient_release=True` +# and call `prepare_for_gradient_release` on model and optimizer +opt = AdamW(model.parameters(), lr=1e-3, gradient_release=True) +prepare_for_gradient_release(model, opt) + +# update model parameters every four steps after accumulating +# gradients directly into the optimizer states +accumulation_steps = 4 + +# use existing PyTorch dataloader +for idx, batch in 
enumerate(dataloader): + # `optimizer_accumulation=True` accumulates gradients into + # optimizer states. set `optimizer_accumulation=False` to + # update parameters by performing a full gradient release step + opt.optimizer_accumulation = (idx+1) % accumulation_steps != 0 + + # calling backward on the model will perform the optimizer step + # either accumulating gradients or updating model parameters + loss = model(batch) + loss.backward() + + # optimizer step and zero_grad are no longer needed, and will + # harmlessly no-op if called by an existing training framework + # opt.step() + # opt.zero_grad() # optionally remove gradient release hooks when done training remove_gradient_release(model) diff --git a/docs/kahan_summation.md b/docs/kahan_summation.md index 6937a76..5f25208 100644 --- a/docs/kahan_summation.md +++ b/docs/kahan_summation.md @@ -4,7 +4,7 @@ title: Low Precision Training with Kahan Summation # Low Precision Training with Kahan Summation -While training models in low precision (Float16 or BFloat16) usually does not match training in full precision (Float32) or [mixed precision](https://pytorch.org/blog/what-every-user-should-know-about-mixed-precision-training-in-pytorch), optimi optimizers match the performance of mixed precision when training in BFloat16 by using Kahan summation[^1]. +While training models in low precision (Float16 or BFloat16) usually differs from training in full precision (Float32) or [mixed precision](https://pytorch.org/blog/what-every-user-should-know-about-mixed-precision-training-in-pytorch), optimi optimizers nearly reach or match the performance of mixed precision when training in BFloat16 by using Kahan summation[^1]. Training in low precision [reduces memory usage](#memory-savings) and increases [training speed](#training-speedup) relative to mixed precision training. @@ -108,6 +108,6 @@ $$ This shows the optimi implementation of Kahan summation optimizers, which is equivalent to the *Revisiting BFloat16 Training* formulation. -[^1]: Current testing on small models shows no degradation in training performance. +[^1]: Current testing on small models shows little to no degradation in model performance. [^2]: Also known as Kahan–Babuška summation or compensated summation. \ No newline at end of file diff --git a/docs/optimizer_accumulation.md b/docs/optimizer_accumulation.md new file mode 100644 index 0000000..815cc14 --- /dev/null +++ b/docs/optimizer_accumulation.md @@ -0,0 +1,88 @@ +--- +title: "Optimizer Accumulation" +description: "Gradient Release with Approximate Gradient Accumulation" +--- + +# Optimizer Accumulation + +**Gradient Release with Approximate Gradient Accumulation** + +[Gradient accumulation](https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation) reduces training memory by splitting a batch into micro-batches and accumulating micro-batch gradients into the larger batch. [Gradient release](gradient_release.md) reduces training memory by limiting gradients to one layer at any given time. Optimizer accumulation unifies these two disparate approaches by accumulating gradients directly into optimizer states while performing gradient release. + +During the backward pass, each model layer calculates its gradients, performs a partial optimizer step, and clears the gradients before proceeding to the backward pass for the next layer. The partial optimizer step accumulates gradients by updating the optimizer state but not modifying the model weights. 
After multiple gradients have been accumulated into optimizer states, a normal optimizer step is run, updating the model parameters with the accumulated states. + +Optimizer accumulation can reduce non-activation memory usage by ~40 percent compared to an Adam optimizer with gradient accumulation. Optimizer accumulation can also be combined with other techniques such as [Kahan summation](kahan_summation.md) or [activation checkpointing](https://pytorch.org/docs/stable/checkpoint.html) for further memory savings. + +??? note "Note: Optimizer Accumulation Requires PyTorch 2.1+" + + Optimizer accumulation requires PyTorch 2.1 or newer. + +Optimizer accumulation was proposed by Zhang et al in [*AdamAccumulation to Reduce Memory Footprints of both Activations and Gradients for Large-scale DNN Training*](https://arxiv.org/abs/2305.19982). optimi’s implementation enables AdamAccumulation for all optimi optimizers[^1]. + +Zhang et al report that models trained with AdamAccumulation over eight micro-batches match models trained via Adam with gradient accumulation over eight micro-batches. + +## Limitations and Workarounds + +Since optimizer accumulation immediately frees the gradient during the backward pass, features which rely on persistent gradients like AMP's `GradScaler`, gradient clipping, or gradient accumulation won’t work. L2 weight decay also shouldn’t be used with optimizer accumulation. + +!!! warning "Important: Optimizer Accumulation is Incompatible with FP16 Mixed Precision" + + Optimizer accumulation is incompatible with Float16 Automatic Mixed Precision since PyTorch's `GradScaler` requires access to the entire model's gradients for the optimizer step. + + Use BFloat16 Automatic Mixed Precision instead. + +The recommended workaround for gradient clipping is to use [StableAdamW](optimizers/stableadamw.md) instead of Adam or AdamW, as StableAdamW removes the need for gradient clipping by porting Adafactor’s update clipping into AdamW. + +!!! warning "Important: Don't use L2 Weight Decay with Optimizer Accumulation" + + optimi applies weight decay on the full optimization step. Since L2 weight decay operates on the gradients, it would only be applied on the last gradient instead of all gradients. + + Use decoupled or [fully decoupled weight decay](fully_decoupled_weight_decay.md) instead. + +Because the gradients are accumulated into the optimizer states, which applies the beta and momentum terms, optimizer accumulation only approximates gradient accumulation. + +## Example + +Using optimi’s optimizer accumulation requires three steps: initializing the optimizer with `gradient_release=True`, calling `prepare_for_gradient_release` on both the model and optimizer, and setting `optimizer.optimizer_accumulation` to True or False to accumulate gradients or perform a full optimizer step, respectively. + +Like gradient accumulation, set `optimizer_accumulation=True` before the backward step while accumulating gradients and `optimizer_accumulation=False` when model parameters are to be updated by the full optimizer step. 
+ +```python +import torch +from torch import nn +from optimi import AdamW + +# create or cast model in low precision (bfloat16) +model = nn.Linear(20, 1, dtype=torch.bfloat16) + +# initialize any optimi optimizer with `gradient_release=True` +# and call `prepare_for_gradient_release` on model and optimizer +opt = AdamW(model.parameters(), lr=1e-3, gradient_release=True) +prepare_for_gradient_release(model, opt) + +# update model parameters every four steps after accumulating +# gradients directly into the optimizer states +accumulation_steps = 4 + +# use existing PyTorch dataloader +for idx, batch in enumerate(dataloader): + # `optimizer_accumulation=True` accumulates gradients into + # optimizer states. set `optimizer_accumulation=False` to + # update parameters by performing a full gradient release step + opt.optimizer_accumulation = (idx+1) % accumulation_steps != 0 + + # calling backward on the model will perform the optimizer step + # either accumulating gradients or updating model parameters + loss = model(batch) + loss.backward() + + # optimizer step and zero_grad are no longer needed, and will + # harmlessly no-op if called by an existing training framework + # opt.step() + # opt.zero_grad() + +# optionally remove gradient release hooks when done training +remove_gradient_release(model) +``` + +[^1]: While optimizer accumulation is noisy compared to gradient accumulation, SGD's optimizer accumulation results are significantly noisier than those of all other optimizers. \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 1b100a7..0b40a5a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -88,6 +88,7 @@ extra_css: nav: - Low Precision Training: kahan_summation.md - Gradient Release: gradient_release.md + - Optimizer Accumulation: optimizer_accumulation.md - ForEach Optimizers: foreach.md - Fully Decoupled Weight Decay: fully_decoupled_weight_decay.md - Which Optimizer?: which_optimizer.md diff --git a/optimi/adam.py b/optimi/adam.py index 4b92e51..15ade4f 100644 --- a/optimi/adam.py +++ b/optimi/adam.py @@ -169,6 +169,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None): kahan_sum=group["kahan_sum"], foreach=group["foreach"], gradient_release=False, + optimizer_accumulation=False, ) else: state = self.state[param] @@ -193,6 +194,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None): kahan_sum=group["kahan_sum"], foreach=False, gradient_release=True, + optimizer_accumulation=self._optimizer_accumulation, ) return loss @@ -217,6 +219,7 @@ def adam( kahan_sum: bool = False, foreach: bool = False, gradient_release: bool = False, + optimizer_accumulation: bool = False, ): """Functional API to apply an Adam or AdamW optimization step. 
@@ -240,6 +243,7 @@ def adam( kahan_sum: Enables Kahan summation for low precision parameters foreach: Enables the faster foreach implementation gradient_release: Fuses optimizer step as part of the parameter's backward pass + optimizer_accumulation: Accumulate gradients into state during gradient release step """ # calculate debiased beta hat & complement terms step.add_(1) @@ -276,6 +280,7 @@ def adam( eps=eps, decouple_wd=(decouple_wd or decouple_lr), kahan_sum=kahan_sum, + update_parameters=(not optimizer_accumulation), ) @@ -293,6 +298,7 @@ def _single_adam( eps: float, decouple_wd: bool, kahan_sum: bool = False, + update_parameters: bool = True, ): for i, param in enumerate(params): grad = grads[i] @@ -313,6 +319,7 @@ def _single_adam( eps=eps, decouple_wd=decouple_wd, kahan_sum=kahan_sum, + update_parameters=update_parameters, ) @@ -330,9 +337,10 @@ def _single_param_adam( eps: float, decouple_wd: bool, kahan_sum: bool = False, + update_parameters: bool = True, ): # decoupled weight decay, fully decoupled weight decay, or L2 weight decay - if weight_decay != 0: + if weight_decay != 0 and update_parameters: if decouple_wd: param.mul_(weight_decay) else: @@ -342,19 +350,20 @@ def _single_param_adam( exp_avg.lerp_(grad, weight=beta1_comp) exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1 - beta2_hat) - if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]: - # Adam step - kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr) + if update_parameters: + if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]: + # Adam step + kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr) - # update weights with kahan compensation using grad as temp buffer - grad.copy_(param.detach()) - param.add_(kahan_comp) + # update weights with kahan compensation using grad as temp buffer + grad.copy_(param.detach()) + param.add_(kahan_comp) - # save error back to kahan compensation for next iteration - kahan_comp.add_(grad.sub_(param)) - else: - # Adam step - param.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr) + # save error back to kahan compensation for next iteration + kahan_comp.add_(grad.sub_(param)) + else: + # Adam step + param.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr) def _foreach_adam( @@ -371,6 +380,7 @@ def _foreach_adam( eps: float, decouple_wd: bool, kahan_sum: bool = False, + **kwargs, ): grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, kahan_comps]) for (_, dtype), ((dev_params, dev_grads, dev_exp_avgs, dev_exp_avg_sqs, dev_kahan_comps), _) in grouped_tensors.items(): diff --git a/optimi/adamw.py b/optimi/adamw.py index 740ab34..1f580f5 100644 --- a/optimi/adamw.py +++ b/optimi/adamw.py @@ -91,6 +91,7 @@ def adamw( kahan_sum: bool = False, foreach: bool = False, gradient_release: bool = False, + optimizer_accumulation: bool = False, ): """Functional API to apply an AdamW optimization step. 
@@ -113,6 +114,7 @@ def adamw( kahan_sum: Enables Kahan summation for low precision `params` foreach: Enables the faster foreach implementation gradient_release: Fuses optimizer step as part of the parameter's backward pass + optimizer_accumulation: Accumulate gradients into state during gradient release step """ adam( params=params, @@ -132,4 +134,5 @@ def adamw( kahan_sum=kahan_sum, foreach=foreach, gradient_release=gradient_release, + optimizer_accumulation=optimizer_accumulation, ) diff --git a/optimi/adan.py b/optimi/adan.py index 996823c..2acae95 100644 --- a/optimi/adan.py +++ b/optimi/adan.py @@ -190,6 +190,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None): kahan_sum=group["kahan_sum"], foreach=group["foreach"], gradient_release=False, + optimizer_accumulation=False, ) else: state = self.state[param] @@ -216,6 +217,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None): kahan_sum=group["kahan_sum"], foreach=False, gradient_release=True, + optimizer_accumulation=self._optimizer_accumulation, ) return loss @@ -243,6 +245,7 @@ def adan( kahan_sum: bool = False, foreach: bool = False, gradient_release: bool = False, + optimizer_accumulation: bool = False, ): """Functional API to apply a Adan optimization step. @@ -269,6 +272,7 @@ def adan( kahan_sum: Enables Kahan summation for low precision parameters foreach: Enables the faster foreach implementation gradient_release: Fuses optimizer step as part of the parameter's backward pass + optimizer_accumulation: Accumulate gradients into state during gradient release step """ # calculate debiased beta hat & complement terms step.add_(1) @@ -315,6 +319,7 @@ def adan( weight_decay=weight_decay, adam_wd=adam_wd, kahan_sum=kahan_sum, + update_parameters=(not optimizer_accumulation), ) @@ -336,6 +341,7 @@ def _single_adan( weight_decay: float, adam_wd: bool, kahan_sum: bool = False, + update_parameters: bool = True, ): for i, param in enumerate(params): grad = grads[i] @@ -362,6 +368,7 @@ def _single_adan( weight_decay=weight_decay, adam_wd=adam_wd, kahan_sum=kahan_sum, + update_parameters=update_parameters, ) @@ -383,6 +390,7 @@ def _single_param_adan( weight_decay: float, adam_wd: bool, kahan_sum: bool = False, + update_parameters: bool = True, ): # difference between current & previous gradients, prev_grad is negated in last step prev_grad.add_(grad) @@ -400,32 +408,33 @@ def _single_param_adan( # set next step's prior_grad as negated current grad prev_grad.copy_(grad).mul_(-1) - # calculate 1/η_k using prev_grad as buffer. LR is multiplied in Adan step - denom = exp_avg_sq.sqrt().add_(eps) + if update_parameters: + # calculate 1/η_k using prev_grad as buffer. 
LR is multiplied in Adan step + denom = exp_avg_sq.sqrt().add_(eps) - # Adam-style weight decay - if adam_wd and weight_decay != 0: - param.mul_(weight_decay) + # Adam-style weight decay + if adam_wd and weight_decay != 0: + param.mul_(weight_decay) - if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]: - # Adan step - kahan_comp.addcdiv_(exp_avg, denom, value=-lr) - kahan_comp.addcdiv_(exp_avg_diff, denom, value=-lr * beta2) + if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]: + # Adan step + kahan_comp.addcdiv_(exp_avg, denom, value=-lr) + kahan_comp.addcdiv_(exp_avg_diff, denom, value=-lr * beta2) - # update weights with kahan compensation using grad as temp buffer - grad.copy_(param.detach()) - param.add_(kahan_comp) + # update weights with kahan compensation using grad as temp buffer + grad.copy_(param.detach()) + param.add_(kahan_comp) - # save error back to kahan compensation for next iteration - kahan_comp.add_(grad.sub_(param)) - else: - # Adan step - param.addcdiv_(exp_avg, denom, value=-lr) - param.addcdiv_(exp_avg_diff, denom, value=-lr * beta2) + # save error back to kahan compensation for next iteration + kahan_comp.add_(grad.sub_(param)) + else: + # Adan step + param.addcdiv_(exp_avg, denom, value=-lr) + param.addcdiv_(exp_avg_diff, denom, value=-lr * beta2) - # Adan-style weight decay - if not adam_wd and weight_decay != 0: - param.div_(weight_decay) + # Adan-style weight decay + if not adam_wd and weight_decay != 0: + param.div_(weight_decay) def _foreach_adan( @@ -446,6 +455,7 @@ def _foreach_adan( weight_decay: float, adam_wd: bool, kahan_sum: bool = False, + **kwargs, ): grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, exp_avg_diffs, prev_grads, kahan_comps]) for (_, dtype), ( diff --git a/optimi/lion.py b/optimi/lion.py index 3e92c33..164f35a 100644 --- a/optimi/lion.py +++ b/optimi/lion.py @@ -146,6 +146,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None): kahan_sum=group["kahan_sum"], foreach=group["foreach"], gradient_release=False, + optimizer_accumulation=False, ) else: state = self.state[param] @@ -166,6 +167,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None): kahan_sum=group["kahan_sum"], foreach=False, gradient_release=True, + optimizer_accumulation=self._optimizer_accumulation, ) return loss @@ -186,6 +188,7 @@ def lion( kahan_sum: bool = False, foreach: bool = False, gradient_release: bool = False, + optimizer_accumulation: bool = False, ): """Functional API to apply a Lion optimization step. 
@@ -205,6 +208,7 @@ def lion( kahan_sum: Enables Kahan summation for low precision `params` foreach: Enables the faster foreach implementation gradient_release: Fuses optimizer step as part of the parameter's backward pass + optimizer_accumulation: Accumulate gradients into state during gradient release step """ # calculate decoupled weight decay or fully decoupled weight decay if weight_decay != 0: @@ -237,6 +241,7 @@ def lion( beta2_comp=beta2_comp, weight_decay=weight_decay, kahan_sum=kahan_sum, + update_parameters=(not optimizer_accumulation), ) @@ -251,6 +256,7 @@ def _single_lion( beta2_comp: float, weight_decay: float, kahan_sum: bool = False, + update_parameters: bool = True, ): for i, param in enumerate(params): grad = grads[i] @@ -267,6 +273,7 @@ def _single_lion( beta2_comp=beta2_comp, weight_decay=weight_decay, kahan_sum=kahan_sum, + update_parameters=update_parameters, ) @@ -281,30 +288,33 @@ def _single_param_lion( beta2_comp: float, weight_decay: float, kahan_sum: bool = False, + update_parameters: bool = True, ): # decoupled weight decay or fully decoupled weight decay - if weight_decay != 0: + if weight_decay != 0 and update_parameters: param.mul_(weight_decay) # parameter update value - update = exp_avg.lerp(grad, weight=beta1_comp).sign_() + if update_parameters: + update = exp_avg.lerp(grad, weight=beta1_comp).sign_() # update gradient moving average exp_avg.lerp_(grad, weight=beta2_comp) - if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]: - # Lion step - kahan_comp.add_(update, alpha=-lr) + if update_parameters: + if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]: + # Lion step + kahan_comp.add_(update, alpha=-lr) - # update weights with kahan compensation using grad as temp buffer - grad.copy_(param.detach()) - param.add_(kahan_comp) + # update weights with kahan compensation using grad as temp buffer + grad.copy_(param.detach()) + param.add_(kahan_comp) - # save error back to kahan compensation for next iteration - kahan_comp.add_(grad.sub_(param)) - else: - # Lion step - param.add_(update, alpha=-lr) + # save error back to kahan compensation for next iteration + kahan_comp.add_(grad.sub_(param)) + else: + # Lion step + param.add_(update, alpha=-lr) def _foreach_lion( @@ -318,6 +328,7 @@ def _foreach_lion( beta2_comp: float, weight_decay: float, kahan_sum: bool = False, + **kwargs, ): grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, kahan_comps]) for (_, dtype), ((dev_params, dev_grads, dev_exp_avgs, dev_kahan_comps), _) in grouped_tensors.items(): diff --git a/optimi/optimizer.py b/optimi/optimizer.py index 7f488a8..783fc22 100644 --- a/optimi/optimizer.py +++ b/optimi/optimizer.py @@ -40,6 +40,9 @@ def __init__(self, params: Iterable[Tensor] | Iterable[dict], defaults: dict[str super().__init__(params, defaults) + # by default perform the normal parameter update step + self._optimizer_accumulation = False + # if gradient_release is enabled, disable foreach step so normal optimizer step won't error if self.defaults["gradient_release"]: self.defaults["foreach"] = False @@ -48,6 +51,16 @@ def __init__(self, params: Iterable[Tensor] | Iterable[dict], defaults: dict[str for p in group["params"]: self.state[p]["group"] = group + @property + def optimizer_accumulation(self) -> bool: + "Accumulate gradients in optimizer states during gradient release instead of a full step." 
+ return self._optimizer_accumulation + + @optimizer_accumulation.setter + def optimizer_accumulation(self, optimizer_accumulation: bool): + "Accumulate gradients in optimizer states during gradient release instead of a full step." + self._optimizer_accumulation = optimizer_accumulation + def step(self, closure: Callable | None = None, param: Tensor | None = None): """Performs a single optimization step on the whole model or individual parameter. diff --git a/optimi/radam.py b/optimi/radam.py index 9b99baa..285a23a 100644 --- a/optimi/radam.py +++ b/optimi/radam.py @@ -170,6 +170,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None): kahan_sum=group["kahan_sum"], foreach=group["foreach"], gradient_release=False, + optimizer_accumulation=False, ) else: state = self.state[param] @@ -194,6 +195,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None): kahan_sum=group["kahan_sum"], foreach=False, gradient_release=True, + optimizer_accumulation=self._optimizer_accumulation, ) return loss @@ -218,6 +220,7 @@ def radam( kahan_sum: bool = False, foreach: bool = False, gradient_release: bool = False, + optimizer_accumulation: bool = False, ): """Functional API to apply an RAdam optimization step. @@ -241,6 +244,7 @@ def radam( kahan_sum: Enables Kahan summation for low precision parameters foreach: Enables the faster foreach implementation gradient_release: Fuses optimizer step as part of the parameter's backward pass + optimizer_accumulation: Accumulate gradients into state during gradient release step """ # calculate debiased beta hat & complement terms step.add_(1) @@ -288,6 +292,7 @@ def radam( rect=rect, decouple_wd=(decouple_wd or decouple_lr), kahan_sum=kahan_sum, + update_parameters=(not optimizer_accumulation), ) @@ -306,6 +311,7 @@ def _single_radam( rect: float | None, decouple_wd: bool, kahan_sum: bool = False, + update_parameters: bool = True, ): for i, param in enumerate(params): grad = grads[i] @@ -327,6 +333,7 @@ def _single_radam( rect=rect, decouple_wd=decouple_wd, kahan_sum=kahan_sum, + update_parameters=update_parameters, ) @@ -345,9 +352,10 @@ def _single_param_radam( rect: float | None, decouple_wd: bool, kahan_sum: bool = False, + update_parameters: bool = True, ): # decoupled weight decay, fully decoupled weight decay, or L2 weight decay - if weight_decay != 0: + if weight_decay != 0 and update_parameters: if decouple_wd: param.mul_(weight_decay) else: @@ -357,25 +365,26 @@ def _single_param_radam( exp_avg.lerp_(grad, weight=beta1_comp) exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1 - beta2_hat) - if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]: - # RAdam step - if rect is not None: - kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr * rect) - else: - kahan_comp.add_(exp_avg, alpha=-lr) + if update_parameters: + if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]: + # RAdam step + if rect is not None: + kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr * rect) + else: + kahan_comp.add_(exp_avg, alpha=-lr) - # update weights with kahan compensation using grad as temp buffer - grad.copy_(param.detach()) - param.add_(kahan_comp) + # update weights with kahan compensation using grad as temp buffer + grad.copy_(param.detach()) + param.add_(kahan_comp) - # save error back to kahan compensation for next iteration - kahan_comp.add_(grad.sub_(param)) - else: - # RAdam step - if rect is not None: - param.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr 
* rect) + # save error back to kahan compensation for next iteration + kahan_comp.add_(grad.sub_(param)) else: - param.add_(exp_avg, alpha=-lr) + # RAdam step + if rect is not None: + param.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr * rect) + else: + param.add_(exp_avg, alpha=-lr) def _foreach_radam( @@ -393,6 +402,7 @@ def _foreach_radam( rect: float | None, decouple_wd: bool, kahan_sum: bool = False, + **kwargs, ): grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, kahan_comps]) for (_, dtype), ((dev_params, dev_grads, dev_exp_avgs, dev_exp_avg_sqs, dev_kahan_comps), _) in grouped_tensors.items(): diff --git a/optimi/ranger.py b/optimi/ranger.py index c5a003d..b8f8f75 100644 --- a/optimi/ranger.py +++ b/optimi/ranger.py @@ -182,6 +182,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None): kahan_sum=group["kahan_sum"], foreach=group["foreach"], gradient_release=False, + optimizer_accumulation=False, ) else: state = self.state[param] @@ -209,6 +210,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None): kahan_sum=group["kahan_sum"], foreach=False, gradient_release=True, + optimizer_accumulation=self._optimizer_accumulation, ) return loss @@ -236,6 +238,7 @@ def ranger( kahan_sum: bool = False, foreach: bool = False, gradient_release: bool = False, + optimizer_accumulation: bool = False, ): """Functional API to apply a Ranger optimization step. @@ -262,6 +265,7 @@ def ranger( kahan_sum: Enables Kahan summation for low precision parameters foreach: Enables the faster foreach implementation gradient_release: Fuses optimizer step as part of the parameter's backward pass + optimizer_accumulation: Accumulate gradients into state during gradient release step """ # calculate debiased beta hat & complement terms step.add_(1) @@ -313,6 +317,7 @@ def ranger( step=step, decouple_wd=(decouple_wd or decouple_lr), kahan_sum=kahan_sum, + update_parameters=(not optimizer_accumulation), ) @@ -335,6 +340,7 @@ def _single_ranger( step: int, decouple_wd: bool, kahan_sum: bool = False, + update_parameters: bool = True, ): for i, param in enumerate(params): grad = grads[i] @@ -361,6 +367,7 @@ def _single_ranger( step=step, decouple_wd=decouple_wd, kahan_sum=kahan_sum, + update_parameters=update_parameters, ) @@ -383,9 +390,10 @@ def _single_param_ranger( step: int, decouple_wd: bool, kahan_sum: bool = False, + update_parameters: bool = True, ): # decoupled weight decay, fully decoupled weight decay, or L2 weight decay - if weight_decay != 0: + if weight_decay != 0 and update_parameters: if decouple_wd: param.mul_(weight_decay) else: @@ -395,43 +403,44 @@ def _single_param_ranger( exp_avg.lerp_(grad, weight=beta1_comp) exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1 - beta2_hat) - if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]: - # RAdam step - if rect is not None: - kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr * rect) - else: - kahan_comp.add_(exp_avg, alpha=-lr) + if update_parameters: + if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]: + # RAdam step + if rect is not None: + kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr * rect) + else: + kahan_comp.add_(exp_avg, alpha=-lr) - # update weights with kahan compensation using grad as temp buffer - grad.copy_(param.detach()) - param.add_(kahan_comp) + # update weights with kahan compensation using grad as temp buffer + grad.copy_(param.detach()) + param.add_(kahan_comp) - # 
save error back to kahan compensation for next iteration - kahan_comp.add_(grad.sub_(param)) + # save error back to kahan compensation for next iteration + kahan_comp.add_(grad.sub_(param)) - # Lookahead step - if step % k == 0: - kahan_comp.add_(param.sub_(la_param), alpha=alpha) + # Lookahead step + if step % k == 0: + kahan_comp.add_(param.sub_(la_param), alpha=alpha) - # update weights with kahan compensation using grad as temp buffer - grad.copy_(la_param.detach()) - la_param.add_(kahan_comp) + # update weights with kahan compensation using grad as temp buffer + grad.copy_(la_param.detach()) + la_param.add_(kahan_comp) - # save error back to kahan compensation for next iteration - kahan_comp.add_(grad.sub_(la_param), alpha=alpha) + # save error back to kahan compensation for next iteration + kahan_comp.add_(grad.sub_(la_param), alpha=alpha) - param.copy_(la_param) - else: - # RAdam step - if rect is not None: - param.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr * rect) + param.copy_(la_param) else: - param.add_(exp_avg, alpha=-lr) + # RAdam step + if rect is not None: + param.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr * rect) + else: + param.add_(exp_avg, alpha=-lr) - # Lookahead step - if step % k == 0: - la_param.add_(param.sub(la_param), alpha=alpha) - param.copy_(la_param) + # Lookahead step + if step % k == 0: + la_param.add_(param.sub(la_param), alpha=alpha) + param.copy_(la_param) def _foreach_ranger( @@ -453,6 +462,7 @@ def _foreach_ranger( step: int, decouple_wd: bool, kahan_sum: bool = False, + **kwargs, ): grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, la_params, kahan_comps]) for (_, dtype), ((dev_params, dev_grads, dev_exp_avgs, dev_exp_avg_sqs, dev_la_params, dev_kahan_comps), _) in grouped_tensors.items(): diff --git a/optimi/sgd.py b/optimi/sgd.py index c9810eb..b56291b 100644 --- a/optimi/sgd.py +++ b/optimi/sgd.py @@ -159,6 +159,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None): kahan_sum=group["kahan_sum"], foreach=group["foreach"], gradient_release=False, + optimizer_accumulation=False, ) else: state = self.state[param] @@ -180,6 +181,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None): kahan_sum=group["kahan_sum"], foreach=False, gradient_release=True, + optimizer_accumulation=self._optimizer_accumulation, ) return loss @@ -201,6 +203,7 @@ def sgd( kahan_sum: bool = False, foreach: bool = False, gradient_release: bool = False, + optimizer_accumulation: bool = False, ): """Functional API to apply a SGD or SGDW optimization step. 
@@ -221,6 +224,7 @@ def sgd( kahan_sum: Enables Kahan summation for low precision `params` foreach: Enables the faster foreach implementation gradient_release: Fuses optimizer step as part of the parameter's backward pass + optimizer_accumulation: Accumulate gradients into state during gradient release step """ # calculate decoupled weight decay or fully decoupled weight decay if weight_decay != 0: @@ -250,6 +254,7 @@ def sgd( dampening=dampening, decouple_wd=(decouple_wd or decouple_lr), kahan_sum=kahan_sum, + update_parameters=(not optimizer_accumulation), ) @@ -265,6 +270,7 @@ def _single_sgd( dampening: bool, decouple_wd: bool, kahan_sum: bool = False, + update_parameters: bool = True, ): for i, param in enumerate(params): grad = grads[i] @@ -282,6 +288,7 @@ def _single_sgd( dampening=dampening, decouple_wd=decouple_wd, kahan_sum=kahan_sum, + update_parameters=update_parameters, ) @@ -297,23 +304,27 @@ def _single_param_sgd( dampening: bool, decouple_wd: bool, kahan_sum: bool = False, + update_parameters: bool = True, ): # decoupled weight decay, fully decoupled weight decay, or L2 weight decay - if weight_decay != 0: + if weight_decay != 0 and update_parameters: if decouple_wd: param.mul_(weight_decay) else: grad.add_(param, alpha=weight_decay) if momentum != 0: - # SGD Momentum + # SGD with Momentum if dampening: exp_avg.lerp_(grad, weight=1 - momentum) else: exp_avg.mul_(momentum).add_(grad) + else: + exp_avg = grad + if update_parameters: if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]: - # SGD with Momentum step + # SGD step (regular step exp_agv = grad) kahan_comp.add_(exp_avg, alpha=-lr) # update weights with kahan compensation using grad as temp buffer @@ -323,22 +334,8 @@ def _single_param_sgd( # save error back to kahan compensation for next iteration kahan_comp.add_(grad.sub_(param)) else: - # SGD with Momentum step + # SGD step (regular step exp_agv = grad) param.add_(exp_avg, alpha=-lr) - else: - if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]: - # SGD step - kahan_comp.add_(grad, alpha=-lr) - - # update weights with kahan compensation using grad as temp buffer - grad.copy_(param.detach()) - param.add_(kahan_comp) - - # save error back to kahan compensation for next iteration - kahan_comp.add_(grad.sub_(param)) - else: - # SGD step - param.add_(grad, alpha=-lr) def _foreach_sgd( @@ -353,6 +350,7 @@ def _foreach_sgd( dampening: bool, decouple_wd: bool, kahan_sum: bool = False, + **kwargs, ): grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, kahan_comps]) for (_, dtype), ((dev_params, dev_grads, dev_exp_avgs, dev_kahan_comps), _) in grouped_tensors.items(): @@ -364,39 +362,26 @@ def _foreach_sgd( torch._foreach_add_(dev_grads, dev_params, alpha=weight_decay) if momentum != 0: - # SGD Momentum + # SGD with Momentum if dampening: torch._foreach_lerp_(dev_exp_avgs, dev_grads, weight=1 - momentum) else: torch._foreach_mul_(dev_exp_avgs, scalar=momentum) torch._foreach_add_(dev_exp_avgs, dev_grads, alpha=1) + else: + dev_exp_avgs = dev_grads - if kahan_sum and dtype in [torch.float16, torch.bfloat16]: - # SGD with Momentum step - torch._foreach_add_(dev_kahan_comps, dev_exp_avgs, alpha=-lr) + if kahan_sum and dtype in [torch.float16, torch.bfloat16]: + # SGD step (regular step exp_agv = grad) + torch._foreach_add_(dev_kahan_comps, dev_exp_avgs, alpha=-lr) - # update weights with kahan compensation using dev_grads as temp buffer - torch._foreach_copy_(dev_grads, dev_params) - torch._foreach_add_(dev_params, 
dev_kahan_comps, alpha=1) + # update weights with kahan compensation using dev_grads as temp buffer + torch._foreach_copy_(dev_grads, dev_params) + torch._foreach_add_(dev_params, dev_kahan_comps, alpha=1) - # save error back to kahan compensation for next iteration - torch._foreach_sub_(dev_grads, dev_params, alpha=1) - torch._foreach_add_(dev_kahan_comps, dev_grads, alpha=1) - else: - # SGD with Momentum step - torch._foreach_add_(dev_params, dev_exp_avgs, alpha=-lr) + # save error back to kahan compensation for next iteration + torch._foreach_sub_(dev_grads, dev_params, alpha=1) + torch._foreach_add_(dev_kahan_comps, dev_grads, alpha=1) else: - if kahan_sum and dtype in [torch.float16, torch.bfloat16]: - # SGD step - torch._foreach_add_(dev_kahan_comps, dev_grads, alpha=-lr) - - # update weights with kahan compensation using dev_grads as temp buffer - torch._foreach_copy_(dev_grads, dev_params) - torch._foreach_add_(dev_params, dev_kahan_comps, alpha=1) - - # save error back to kahan compensation for next iteration - torch._foreach_sub_(dev_grads, dev_params, alpha=1) - torch._foreach_add_(dev_kahan_comps, dev_grads, alpha=1) - else: - # SGD step - torch._foreach_add_(dev_params, dev_grads, alpha=-lr) + # SGD step (regular step exp_agv = grad) + torch._foreach_add_(dev_params, dev_exp_avgs, alpha=-lr) diff --git a/optimi/stableadamw.py b/optimi/stableadamw.py index 902137d..1dd3f06 100644 --- a/optimi/stableadamw.py +++ b/optimi/stableadamw.py @@ -170,6 +170,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None): kahan_sum=group["kahan_sum"], foreach=group["foreach"], gradient_release=False, + optimizer_accumulation=False, ) else: state = self.state[param] @@ -194,6 +195,7 @@ def step(self, closure: Callable | None = None, param: Tensor | None = None): kahan_sum=group["kahan_sum"], foreach=False, gradient_release=True, + optimizer_accumulation=self._optimizer_accumulation, ) return loss @@ -218,6 +220,7 @@ def stableadamw( kahan_sum: bool = False, foreach: bool = False, gradient_release: bool = False, + optimizer_accumulation: bool = False, ): """Functional API to apply a StableAdamW optimization step. 
@@ -241,6 +244,7 @@ def stableadamw( kahan_sum: Enables Kahan summation for low precision parameters foreach: Enables the faster foreach implementation gradient_release: Fuses optimizer step as part of the parameter's backward pass + optimizer_accumulation: Accumulate gradients into state during gradient release step """ # calculate debiased beta hat & complement terms step.add_(1) @@ -272,6 +276,7 @@ def stableadamw( decouple_lr=decouple_lr, max_lr=max_lr, kahan_sum=kahan_sum, + update_parameters=(not optimizer_accumulation), ) @@ -291,6 +296,7 @@ def _single_stableadamw( decouple_lr: bool, max_lr: float | None = None, kahan_sum: bool = False, + update_parameters: bool = True, ): for i, param in enumerate(params): grad = grads[i] @@ -314,6 +320,7 @@ def _single_stableadamw( decouple_lr=decouple_lr, max_lr=max_lr, kahan_sum=kahan_sum, + update_parameters=update_parameters, ) @@ -333,38 +340,40 @@ def _single_param_stableadamw( decouple_lr: bool, max_lr: float | None = None, kahan_sum: bool = False, + update_parameters: bool = True, ): # update gradient moving averages with debiased betas exp_avg.lerp_(grad, weight=beta1_comp) exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1 - beta2_hat) - # compute per tensor RMS stabilization term - rms = grad.pow(2).div_(exp_avg_sq.maximum(eps_sq)).mean().sqrt() + if update_parameters: + # compute per tensor RMS stabilization term + rms = grad.pow(2).div_(exp_avg_sq.maximum(eps_sq)).mean().sqrt() - # calculate RMS stabilized learning rate - lr = lr / max(1, rms.item()) + # calculate RMS stabilized learning rate + lr = lr / max(1, rms.item()) - # decoupled weight decay or fully decoupled weight decay - if weight_decay != 0: - if decouple_lr: - weight_decay = 1 - (lr / max_lr) * weight_decay - else: - weight_decay = 1 - lr * weight_decay - param.mul_(weight_decay) + # decoupled weight decay or fully decoupled weight decay + if weight_decay != 0: + if decouple_lr: + weight_decay = 1 - (lr / max_lr) * weight_decay + else: + weight_decay = 1 - lr * weight_decay + param.mul_(weight_decay) - if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]: - # Adam step - kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr) + if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]: + # Adam step + kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr) - # update weights with kahan compensation using grad as temp buffer - grad.copy_(param.detach()) - param.add_(kahan_comp) + # update weights with kahan compensation using grad as temp buffer + grad.copy_(param.detach()) + param.add_(kahan_comp) - # save error back to kahan compensation for next iteration - kahan_comp.add_(grad.sub_(param)) - else: - # Adam step - param.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr) + # save error back to kahan compensation for next iteration + kahan_comp.add_(grad.sub_(param)) + else: + # Adam step + param.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr) def _foreach_stableadamw( @@ -383,6 +392,7 @@ def _foreach_stableadamw( decouple_lr: bool, max_lr: float | None = None, kahan_sum: bool = False, + **kwargs, ): grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, eps_sqs, kahan_comps]) for (_, dtype), ((dev_params, dev_grads, dev_exp_avgs, dev_exp_avg_sqs, dev_eps_sqs, dev_kahan_comps), _) in grouped_tensors.items(): diff --git a/pyproject.toml b/pyproject.toml index ea2187d..1352337 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,9 +20,9 @@ classifiers = [ dependencies 
= ["torch>=1.13", "packaging>=21.3"] [project.optional-dependencies] -test = ["pytest>=7.4.3", "ruff>=0.1.3", "pytest-md>=0.2.0", "numpy>=1.23"] -docs = ["mkdocs-material>=9.4.7", "mkdocstrings>=0.23.0", "mkdocstrings-python>=1.7.3", "black>=23.10.1", "mkdocs-caption>=0.0.9"] -dev = ["pytest>=7.4.3", "ruff>=0.1.3", "mkdocs-material>=9.4.7", "mkdocstrings>=0.23.0", "mkdocstrings-python>=1.7.3", "black>=23.10.1", "mkdocs-caption>=0.0.9"] +test = ["pytest>=8.1.1", "ruff>=0.3.2", "pytest-md>=0.2.0", "numpy>=1.23"] +docs = ["mkdocs-material>=9.4.7", "mkdocstrings>=0.24.1", "mkdocstrings-python>=1.8.0", "black>=24.2.0", "mkdocs-caption>=1.0.0"] +dev = ["pytest>=8.1.1", "ruff>=0.3.2", "mkdocs-material>=9.4.7", "mkdocstrings>=0.24.1", "mkdocstrings-python>=1.8.0", "black>=24.2.0", "mkdocs-caption>=1.0.0"] [project.urls] "Homepage" = "https://optimi.benjaminwarner.dev" @@ -38,27 +38,23 @@ packages = ["optimi"] [tool.pytest.ini_options] testpaths = ["tests"] -markers = ["cpu", "cuda"] +markers = ["cpu", "cuda", "adam", "adan", "lion", "radam", "ranger", "sgd", "stableadam"] [tool.ruff] line-length = 140 -select = ["E", "W", "F", "I", "D", "UP"] -extend-ignore = ["D100", "D107", "D206", "D300", "E111", "E114", "E117"] extend-exclude = ["tests", "docs"] src = ["optimi"] -[tool.ruff.extend-per-file-ignores] -"__init__.py" = ["D104", "F401", "I002"] -"utils.py" = ["I002"] - [tool.ruff.format] exclude = ["tests", "docs"] -[tool.ruff.isort] -required-imports = ["from __future__ import annotations"] - -[tool.ruff.pycodestyle] -max-doc-length = 100 +[tool.ruff.lint] +select = ["E", "W", "F", "I", "D", "UP"] +extend-ignore = ["D100", "D107", "D206", "D300", "E111", "E114", "E117"] +isort.required-imports = ["from __future__ import annotations"] +pycodestyle.max-doc-length = 100 +pydocstyle.convention = "google" -[tool.ruff.pydocstyle] -convention = "google" \ No newline at end of file +[tool.ruff.lint.extend-per-file-ignores] +"__init__.py" = ["D104", "F401", "I002"] +"utils.py" = ["I002"] \ No newline at end of file diff --git a/tests/adam_test.py b/tests/adam_test.py index 9d7c37a..501fc29 100644 --- a/tests/adam_test.py +++ b/tests/adam_test.py @@ -8,7 +8,7 @@ from tests.optimizer_test import (buffer, run_optimizer, gradient_release, cpu_dim1, cpu_dim2, cpu_gtype, cpu_ftype, cuda_dim1, cuda_dim2, cuda_gtype, cuda_ftype, gr_dim1, - gr_dim2, gr_dtype, gr_ftype) + gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation) optimizers = {} @@ -36,6 +36,7 @@ cpu_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cpu_values] @pytest.mark.cpu +@pytest.mark.adam @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cpu_values, ids=cpu_names) def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str): run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cpu'), buffer) @@ -46,16 +47,29 @@ def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ft cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] @pytest.mark.cuda +@pytest.mark.adam @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names) def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str): run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'), buffer, iterations=80) + cuda_values = list(product(gr_dim1, gr_dim2, gr_dtype, optimizer_names, gr_ftype)) cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for 
diff --git a/tests/adam_test.py b/tests/adam_test.py
index 9d7c37a..501fc29 100644
--- a/tests/adam_test.py
+++ b/tests/adam_test.py
@@ -8,7 +8,7 @@ from tests.optimizer_test import (buffer, run_optimizer, gradient_release,
                                   cpu_dim1, cpu_dim2, cpu_gtype, cpu_ftype,
                                   cuda_dim1, cuda_dim2, cuda_gtype, cuda_ftype, gr_dim1,
-                                  gr_dim2, gr_dtype, gr_ftype)
+                                  gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation)


 optimizers = {}
@@ -36,6 +36,7 @@ cpu_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cpu_values]

 @pytest.mark.cpu
+@pytest.mark.adam
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cpu_values, ids=cpu_names)
 def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cpu'), buffer)
@@ -46,16 +47,29 @@ def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ft
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

 @pytest.mark.cuda
+@pytest.mark.adam
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'), buffer, iterations=80)


+
 cuda_values = list(product(gr_dim1, gr_dim2, gr_dtype, optimizer_names, gr_ftype))
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

 @pytest.mark.cuda
+@pytest.mark.adam
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_gradient_release(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     gradient_release(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
-                     iterations=80, framework_opt_step=torch.rand(1).item() > 0.5)
\ No newline at end of file
+                     framework_opt_step=torch.rand(1).item() > 0.5)
+
+
+@pytest.mark.cuda
+@pytest.mark.adam
+@pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
+def test_optimizer_accumulation(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
+    if optim_name in ["adam_l2"]:
+        pytest.skip("Skip tests for Adam with L2 weight decay.")
+    optimizer_accumulation(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
+                           framework_opt_step=torch.rand(1).item() > 0.5)
\ No newline at end of file
diff --git a/tests/adan_test.py b/tests/adan_test.py
index 0eeb5e7..ff9c3ce 100644
--- a/tests/adan_test.py
+++ b/tests/adan_test.py
@@ -8,7 +8,7 @@ from tests.optimizer_test import (buffer, run_optimizer, gradient_release,
                                   cpu_dim1, cpu_dim2, cpu_gtype, cpu_ftype,
                                   cuda_dim1, cuda_dim2, cuda_gtype, cuda_ftype, gr_dim1,
-                                  gr_dim2, gr_dtype, gr_ftype)
+                                  gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation)


@@ -34,6 +34,7 @@ cpu_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cpu_values]

 @pytest.mark.cpu
+@pytest.mark.adan
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cpu_values, ids=cpu_names)
 def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cpu'), buffer)
@@ -44,17 +45,28 @@ def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ft
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

 @pytest.mark.cuda
+@pytest.mark.adan
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     # Adan bfloat16 updates are noisier, so GPU currently doesn't use longer test iterations
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'), buffer, iterations=20 if gtype==torch.bfloat16 else 80)


+
 cuda_values = list(product(gr_dim1, gr_dim2, gr_dtype, optimizer_names, gr_ftype))
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

 @pytest.mark.cuda
+@pytest.mark.adan
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_gradient_release(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     gradient_release(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
-                     iterations=80, framework_opt_step=torch.rand(1).item() > 0.5)
\ No newline at end of file
+                     framework_opt_step=torch.rand(1).item() > 0.5)
+
+
+@pytest.mark.cuda
+@pytest.mark.adan
+@pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
+def test_optimizer_accumulation(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
+    optimizer_accumulation(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
+                           framework_opt_step=torch.rand(1).item() > 0.5)
\ No newline at end of file
diff --git a/tests/anyadam_test.py b/tests/anyadam_test.py
index 5e84637..de64a16 100644
--- a/tests/anyadam_test.py
+++ b/tests/anyadam_test.py
@@ -30,6 +30,7 @@ cpu_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cpu_values]

 @pytest.mark.cpu
+@pytest.mark.adam
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cpu_values, ids=cpu_names)
 def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cpu'), buffer, any_precision=True)
@@ -41,6 +42,7 @@ def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ft
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

 @pytest.mark.cuda
+@pytest.mark.adam
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'), buffer, iterations=80, any_precision=True)
\ No newline at end of file
diff --git a/tests/lion_test.py b/tests/lion_test.py
index 468e080..f02fb4b 100644
--- a/tests/lion_test.py
+++ b/tests/lion_test.py
@@ -8,7 +8,7 @@ from tests.optimizer_test import (buffer, run_optimizer, gradient_release,
                                   cpu_dim1, cpu_dim2, cpu_gtype, cpu_ftype,
                                   cuda_dim1, cuda_dim2, cuda_gtype, cuda_ftype, gr_dim1,
-                                  gr_dim2, gr_dtype, gr_ftype)
+                                  gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation)


@@ -31,6 +31,7 @@ cpu_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cpu_values]

 @pytest.mark.cpu
+@pytest.mark.lion
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cpu_values, ids=cpu_names)
 def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cpu'), buffer)
@@ -41,6 +42,7 @@ def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ft
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

 @pytest.mark.cuda
+@pytest.mark.lion
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'), buffer, iterations=80)
@@ -51,7 +53,16 @@ def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, f
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

 @pytest.mark.cuda
+@pytest.mark.lion
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_gradient_release(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     gradient_release(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
-                     iterations=80, framework_opt_step=torch.rand(1).item() > 0.5)
\ No newline at end of file
+                     framework_opt_step=torch.rand(1).item() > 0.5)
+
+
+@pytest.mark.cuda
+@pytest.mark.lion
+@pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
+def test_optimizer_accumulation(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
+    optimizer_accumulation(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
+                           framework_opt_step=torch.rand(1).item() > 0.5)
\ No newline at end of file
diff --git a/tests/optimizer_test.py b/tests/optimizer_test.py
index 0ae8b6f..521d2b3 100644
--- a/tests/optimizer_test.py
+++ b/tests/optimizer_test.py
@@ -122,13 +122,13 @@ def run_optimizer(optimizers:dict, dim1:int, dim2:int, gtype:torch.dtype, optim_

 def gradient_release(optimizers:dict, dim1:int, dim2:int, dtype:torch.dtype, optim_name:str,
-                     ftype:str, device:torch.device, iterations:int=20, framework_opt_step:bool=False):
+                     ftype:str, device:torch.device, iterations:int=80, framework_opt_step:bool=False):

     def optimizer_hook(parameter) -> None:
         torch_optimizers[parameter].step()
         torch_optimizers[parameter].zero_grad()

-    # Since Lion & Adan can have noisy updates, allow up to 10 errors
-    max_error_count = 10
+    # Since Lion & Adan can have noisy updates, allow up to 12 errors
+    max_error_count = 12

     if dtype == torch.float32:
         atol, rtol = 1e-6, 1e-5
@@ -192,16 +192,79 @@ def optimizer_hook(parameter) -> None:
         optimi_optimizer.step()
         optimi_optimizer.zero_grad()

-    assert_most_approx_close(m1.fc1.weight, m2.fc1.weight, rtol=rtol, atol=atol, max_error_count=max_error_count, name='PyTorch-PyTorch: ')
-    assert_most_approx_close(m1.fc2.weight, m2.fc2.weight, rtol=rtol, atol=atol, max_error_count=max_error_count, name='PyTorch-PyTorch: ')
-    assert_most_approx_close(m1.fc1.weight, m3.fc1.weight, rtol=rtol, atol=atol, max_error_count=max_error_count, name='PyTorch-Optimi: ')
-    assert_most_approx_close(m1.fc2.weight, m3.fc2.weight, rtol=rtol, atol=atol, max_error_count=max_error_count, name='PyTorch-Optimi: ')
+    assert_most_approx_close(m1.fc1.weight, m2.fc1.weight, rtol=rtol, atol=atol,
+                             max_error_count=max_error_count, name='PyTorch-PyTorch: ')
+    assert_most_approx_close(m1.fc2.weight, m2.fc2.weight, rtol=rtol, atol=atol,
+                             max_error_count=max_error_count, name='PyTorch-PyTorch: ')
+    assert_most_approx_close(m1.fc1.weight, m3.fc1.weight, rtol=rtol, atol=atol,
+                             max_error_count=max_error_count, name='PyTorch-Optimi: ')
+    assert_most_approx_close(m1.fc2.weight, m3.fc2.weight, rtol=rtol, atol=atol,
+                             max_error_count=max_error_count, name='PyTorch-Optimi: ')

     for h in pytorch_hooks:
         h.remove()
     remove_gradient_release(m3)


+def optimizer_accumulation(optimizers:dict, dim1:int, dim2:int, dtype:torch.dtype, optim_name:str,
+                           ftype:str, device:torch.device, iterations:int=80, framework_opt_step:bool=False):
+    # Since optimizer accumulation approximates gradient accumulation, the tolerances
+    # compared to normal optimizers are quite high despite the low number of iterations
+    # SGD will randomly error out unless max_error_rate is 30%, other optimizers only need 3.5%
+    max_error_rate = 0.30 if 'sgd' in optim_name else 0.035
+    atol, rtol = 1e-2, 1e-2
+
+    m1 = MLP(dim1, dim2, device=device, dtype=dtype)
+    m2 = MLP(dim1, dim2, device=device, dtype=dtype)
+    m2.load_state_dict(m1.state_dict())
+
+    regular_optimizer = load_optimizer(m1.parameters(), optimizers, optim_name, 0, ftype)
+
+
+    # Optimi Method
+    # add the gradient release flag to the optimizer kwargs
+    optimizers[optim_name][1]['kwargs']['gradient_release'] = True
+    optimi_optimizer = load_optimizer(m2.parameters(), optimizers, optim_name, 1, ftype)
+
+    prepare_for_gradient_release(m2, optimi_optimizer)
+
+    gradient_accumulation_steps = 4
+
+    # Training loop
+    for i in range(iterations):
+        input1 = torch.randn(1, dim1, device=device, dtype=dtype)
+        input2 = input1.clone()
+        target1 = torch.randn(1, 1, device=device, dtype=dtype)
+        target2 = target1.clone()
+
+        optimi_optimizer.optimizer_accumulation = (i+1) % gradient_accumulation_steps != 0
+
+        output1 = m1(input1)
+        output2 = m2(input2)
+
+        loss1 = torch.nn.functional.mse_loss(output1, target1)
+        loss2 = torch.nn.functional.mse_loss(output2, target2)
+
+        loss1.backward()
+        loss2.backward()
+
+        if not optimi_optimizer.optimizer_accumulation:
+            regular_optimizer.step()
+            regular_optimizer.zero_grad()
+
+        # simulates using an optimi gradient release optimizer in a framework
+        # where the optimizer step and zero_grad cannot be disabled.
+        if framework_opt_step:
+            optimi_optimizer.step()
+            optimi_optimizer.zero_grad()
+
+    # unlike other tests, compare that the weights are in the same approximate range at the end of training
+    assert_most_approx_close(m1.fc1.weight, m2.fc1.weight, rtol=rtol, atol=atol, max_error_rate=max_error_rate)
+    assert_most_approx_close(m1.fc2.weight, m2.fc2.weight, rtol=rtol, atol=atol, max_error_rate=max_error_rate)
+
+    remove_gradient_release(m2)
+
+
 buffer = io.BytesIO()
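The new `optimizer_accumulation` test helper above trains a regular optimizer and an optimi gradient-release optimizer side by side, steps the regular optimizer only on iterations where the optimi optimizer leaves accumulation mode, and compares the final weights with a loose fraction-based tolerance (`max_error_rate`) instead of the per-element error budget (`max_error_count`) used by `gradient_release`. The sketch below approximates what such a budgeted closeness check does, inferred from the call sites above; the repository's actual `assert_most_approx_close` helper may differ in its defaults and messages:

```python
import torch

def approx_close_with_budget(a: torch.Tensor, b: torch.Tensor, rtol: float, atol: float,
                             max_error_count: int = 0, max_error_rate: float = 0.0,
                             name: str = "") -> None:
    # count elements that fall outside the combined absolute/relative tolerance
    mismatched = (~torch.isclose(a, b, rtol=rtol, atol=atol)).sum().item()
    # allow either a fixed number of outliers or a fraction of all elements
    budget = max(max_error_count, int(max_error_rate * a.numel()))
    assert mismatched <= budget, f"{name}{mismatched} mismatched elements exceed budget of {budget}"
```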
diff --git a/tests/radam_test.py b/tests/radam_test.py
index 50cddfe..ec6dad9 100644
--- a/tests/radam_test.py
+++ b/tests/radam_test.py
@@ -8,7 +8,7 @@ from tests.optimizer_test import (buffer, run_optimizer, gradient_release,
                                   cpu_dim1, cpu_dim2, cpu_gtype, cpu_ftype,
                                   cuda_dim1, cuda_dim2, cuda_gtype, cuda_ftype, gr_dim1,
-                                  gr_dim2, gr_dtype, gr_ftype)
+                                  gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation)

 # PyTorch's RAdam adds epsilon before debiasing V while Optimi debases before.
 # RAdam tests with a smaller epsilon then other optimizers to prevent numerical divergances.
@@ -36,6 +36,7 @@ cpu_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cpu_values]

 @pytest.mark.cpu
+@pytest.mark.radam
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cpu_values, ids=cpu_names)
 def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cpu'), buffer)
@@ -46,6 +47,7 @@ def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ft
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

 @pytest.mark.cuda
+@pytest.mark.radam
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'), buffer, iterations=80)
@@ -56,7 +58,18 @@ def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, f
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

 @pytest.mark.cuda
+@pytest.mark.radam
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_gradient_release(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     gradient_release(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
-                     iterations=80, framework_opt_step=torch.rand(1).item() > 0.5)
\ No newline at end of file
+                     framework_opt_step=torch.rand(1).item() > 0.5)
+
+
+@pytest.mark.cuda
+@pytest.mark.radam
+@pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
+def test_optimizer_accumulation(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
+    if optim_name in ["radam_l2"]:
+        pytest.skip("Skip tests for RAdam with L2 weight decay.")
+    optimizer_accumulation(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
+                           framework_opt_step=torch.rand(1).item() > 0.5)
\ No newline at end of file
diff --git a/tests/ranger_test.py b/tests/ranger_test.py
index b002834..f3487cc 100644
--- a/tests/ranger_test.py
+++ b/tests/ranger_test.py
@@ -8,7 +8,7 @@ from tests.optimizer_test import (buffer, run_optimizer, gradient_release,
                                   cpu_dim1, cpu_dim2, cpu_gtype, cpu_ftype,
                                   cuda_dim1, cuda_dim2, cuda_gtype, cuda_ftype, gr_dim1,
-                                  gr_dim2, gr_dtype, gr_ftype)
+                                  gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation)

 # The reference Ranger adds epsilon before debiasing V while Optimi debases before.
 # Ranger tests with a smaller epsilon then other optimizers to prevent numerical divergances.
@@ -28,6 +28,7 @@ cpu_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cpu_values]

 @pytest.mark.cpu
+@pytest.mark.ranger
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cpu_values, ids=cpu_names)
 def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cpu'), buffer)
@@ -38,6 +39,7 @@ def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ft
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

 @pytest.mark.cuda
+@pytest.mark.ranger
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     # test ranger longer due to the lookahead step
@@ -49,8 +51,17 @@ def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, f
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

 @pytest.mark.cuda
+@pytest.mark.ranger
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_gradient_release(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     # test ranger longer due to the lookahead step
     gradient_release(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
-                     iterations=160, framework_opt_step=torch.rand(1).item() > 0.5)
\ No newline at end of file
+                     iterations=160, framework_opt_step=torch.rand(1).item() > 0.5)
+
+
+@pytest.mark.cuda
+@pytest.mark.ranger
+@pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
+def test_optimizer_accumulation(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
+    optimizer_accumulation(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
+                           framework_opt_step=torch.rand(1).item() > 0.5)
\ No newline at end of file
["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] @pytest.mark.cuda +@pytest.mark.sgd @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names) def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str): run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'), buffer, iterations=80) @@ -57,7 +59,19 @@ def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, f cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] @pytest.mark.cuda +@pytest.mark.sgd @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names) def test_gradient_release(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str): gradient_release(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'), - iterations=80, framework_opt_step=torch.rand(1).item() > 0.5) \ No newline at end of file + framework_opt_step=torch.rand(1).item() > 0.5) + + +@pytest.mark.cuda +@pytest.mark.sgd +@pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names) +def test_optimizer_accumulation(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str): + if optim_name in ["sgd", "sgd_l2"]: + pytest.skip("Skip tests for SGD and SGD with L2 weight decay.") + # SGD will error out more often if iterations is the default of 80 + optimizer_accumulation(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'), + iterations=20, framework_opt_step=torch.rand(1).item() > 0.5) \ No newline at end of file diff --git a/tests/stableadam_test.py b/tests/stableadam_test.py index 7d682c4..86c7040 100644 --- a/tests/stableadam_test.py +++ b/tests/stableadam_test.py @@ -8,7 +8,7 @@ from tests.optimizer_test import (buffer, run_optimizer, gradient_release, cpu_dim1, cpu_dim2, cpu_gtype, cpu_ftype, cuda_dim1, cuda_dim2, cuda_gtype, cuda_ftype, gr_dim1, - gr_dim2, gr_dtype, gr_ftype) + gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation) @@ -31,6 +31,7 @@ cpu_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cpu_values] @pytest.mark.cpu +@pytest.mark.stableadam @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cpu_values, ids=cpu_names) def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str): run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cpu'), buffer) @@ -41,16 +42,27 @@ def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ft cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] @pytest.mark.cuda +@pytest.mark.stableadam @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names) def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str): run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'), buffer, iterations=80) + cuda_values = list(product(gr_dim1, gr_dim2, gr_dtype, optimizer_names, gr_ftype)) cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] @pytest.mark.cuda +@pytest.mark.stableadam @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names) def test_gradient_release(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str): gradient_release(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'), - iterations=80, 
diff --git a/tests/stableadam_test.py b/tests/stableadam_test.py
index 7d682c4..86c7040 100644
--- a/tests/stableadam_test.py
+++ b/tests/stableadam_test.py
@@ -8,7 +8,7 @@ from tests.optimizer_test import (buffer, run_optimizer, gradient_release,
                                   cpu_dim1, cpu_dim2, cpu_gtype, cpu_ftype,
                                   cuda_dim1, cuda_dim2, cuda_gtype, cuda_ftype, gr_dim1,
-                                  gr_dim2, gr_dtype, gr_ftype)
+                                  gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation)


@@ -31,6 +31,7 @@ cpu_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cpu_values]

 @pytest.mark.cpu
+@pytest.mark.stableadam
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cpu_values, ids=cpu_names)
 def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cpu'), buffer)
@@ -41,16 +42,27 @@ def test_optimizer_cpu(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ft
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

 @pytest.mark.cuda
+@pytest.mark.stableadam
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_optimizer_cuda(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     run_optimizer(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'), buffer, iterations=80)


+
 cuda_values = list(product(gr_dim1, gr_dim2, gr_dtype, optimizer_names, gr_ftype))
 cuda_names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}{}".format(*vals) for vals in cuda_values]

 @pytest.mark.cuda
+@pytest.mark.stableadam
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
 def test_gradient_release(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
     gradient_release(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
-                     iterations=80, framework_opt_step=torch.rand(1).item() > 0.5)
\ No newline at end of file
+                     framework_opt_step=torch.rand(1).item() > 0.5)
+
+
+@pytest.mark.cuda
+@pytest.mark.stableadam
+@pytest.mark.parametrize("dim1, dim2, gtype, optim_name, ftype", cuda_values, ids=cuda_names)
+def test_optimizer_accumulation(dim1:int, dim2:int, gtype:torch.dtype, optim_name:str, ftype:str):
+    optimizer_accumulation(optimizers, dim1, dim2, gtype, optim_name, ftype, torch.device('cuda'),
+                           framework_opt_step=torch.rand(1).item() > 0.5)
\ No newline at end of file