From c1d5410026b0c98b60c798440b0b47f63b26e5c2 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 20 Feb 2024 22:24:05 -0600 Subject: [PATCH] add examples to documentation --- docs/foreach.md | 31 +++++++++++++++++++++++++++- docs/fully_decoupled_weight_decay.md | 22 ++++++++++++++++++++ docs/kahan_summation.md | 27 ++++++++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/docs/foreach.md b/docs/foreach.md index fcbd9cb..f93aa36 100644 --- a/docs/foreach.md +++ b/docs/foreach.md @@ -8,4 +8,33 @@ Like PyTorch, optimi supports foreach implementations of all optimizers. Foreach Foreach implementations can increase optimizer peak memory usage. optimi attempts to reduce this extra overhead by reusing the gradient buffer for temporary variables. If the gradients are required between the optimization step and [gradient reset step](https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html#torch.optim.Optimizer.zero_grad), set `foreach=False` to use the for-loop implementation. -If unspecified `foreach=None`, optimi will use the foreach implementation if training on a Cuda device. \ No newline at end of file +??? warning "Important: Foreach Requires PyTorch 2.1+" + + optimi’s foreach implementations require PyTorch 2.1 or newer. + +If unspecified `foreach=None`, optimi will use the foreach implementation if training on a Cuda device. + +## Example + +Using a foreach implementation is as simple as calling + +```python +import torch +from torch import nn +from optimi import AdamW + +# create model +model = nn.Linear(20, 1, device='cuda') + +# initialize any optmi optimizer with `foreach=True` +# models on a cuda device will default to `foreach=True` +opt = AdamW(model.parameters(), lr=1e-3, foreach=True) + +# forward and backward +loss = model(torch.randn(20)) +loss.backward() + +# optimizer step is the foreach implementation +opt.step() +opt.zero_grad() +``` \ No newline at end of file diff --git a/docs/fully_decoupled_weight_decay.md b/docs/fully_decoupled_weight_decay.md index 98500e9..b9f99b1 100644 --- a/docs/fully_decoupled_weight_decay.md +++ b/docs/fully_decoupled_weight_decay.md @@ -28,6 +28,28 @@ For example, to match [AdamW’s](optimizers/adamw.md) default decoupled weight By default, optimi optimizers assume `lr` is the maximum scheduled learning rate. This allows the applied weight decay $(\gamma_t/\gamma_\text{max})\lambda\bm{\theta}_{t-1}$ to match the learning rate schedule. Set `max_lr` if this is not the case. +## Example + +```python +import torch +from torch import nn +from optimi import AdamW + +# create model +model = nn.Linear(20, 1, dtype=torch.bfloat16) + +# initialize any optimi optimizer useing `decouple_lr=True` to enable fully +# decoupled weight decay. note `weight_decay` is lower then the default of 1e-2 +opt = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5, decouple_lr=True) + +# model is optimized using fully decoupled weight decay +loss = model(torch.randn(20, dtype=torch.bfloat16)) +loss.backward() + +opt.step() +opt.zero_grad() +``` + ## Algorithm The algorithm below shows the difference between PyTorch’s AdamW and optimi’s Adam with fully decoupled weight decay. diff --git a/docs/kahan_summation.md b/docs/kahan_summation.md index d51a9a4..6937a76 100644 --- a/docs/kahan_summation.md +++ b/docs/kahan_summation.md @@ -57,6 +57,33 @@ Calculating the total memory savings depends on [activations and batch size](htt Training in BFloat16 instead of mixed precision results in a ~10% speedup on a single GPU at the same batch size. BFloat16 training can further increase distributed training speed due to the halved bandwidth cost. +## Example + +Using Kahan summation with an optimi optimizer only requires a casting a model and optionally input into low precision (BFloat16 or Float16). Since Kahan summation is applied layer by layer, it works for models with standard and low precision weights. + +```python +import torch +from torch import nn +from optimi import AdamW + +# create or cast some model layers in low precision (bfloat16) +model = nn.Linear(20, 1, dtype=torch.bfloat16) + +# initialize any optmi optimizer with low precsion parameters +# Kahan summation is enabled since some model layers are bfloat16 +opt = AdamW(model.parameters(), lr=1e-3) + +# forward and backward, casting input to bfloat16 if needed +loss = model(torch.randn(20, dtype=torch.bfloat16)) +loss.backward() + +# optimizer step automatically uses Kahan summation for low precision layers +opt.step() +opt.zero_grad() +``` + +To disable Kahan Summation pass `kahan_summation=False` on optimizer initialiation. + ## Algorithm SGD with Kahan summation.