
Commit

add examples to documentation
warner-benjamin committed Feb 21, 2024
1 parent b31c6b7 commit c1d5410
Showing 3 changed files with 79 additions and 1 deletion.
31 changes: 30 additions & 1 deletion docs/foreach.md
@@ -8,4 +8,33 @@ Like PyTorch, optimi supports foreach implementations of all optimizers. Foreach

Foreach implementations can increase optimizer peak memory usage. optimi attempts to reduce this extra overhead by reusing the gradient buffer for temporary variables. If the gradients are required between the optimization step and [gradient reset step](https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html#torch.optim.Optimizer.zero_grad), set `foreach=False` to use the for-loop implementation.

??? warning "Important: Foreach Requires PyTorch 2.1+"

    optimi’s foreach implementations require PyTorch 2.1 or newer.

If left unspecified (`foreach=None`), optimi will use the foreach implementation when training on a CUDA device.

## Example

Using a foreach implementation is as simple as initializing any optimi optimizer with `foreach=True`:

```python
import torch
from torch import nn
from optimi import AdamW

# create model
model = nn.Linear(20, 1, device='cuda')

# initialize any optimi optimizer with `foreach=True`
# models on a CUDA device will default to `foreach=True`
opt = AdamW(model.parameters(), lr=1e-3, foreach=True)

# forward and backward
loss = model(torch.randn(20, device='cuda'))
loss.backward()

# optimizer step is the foreach implementation
opt.step()
opt.zero_grad()
```
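
If the gradients must stay intact between `opt.step()` and `opt.zero_grad()`, the note above recommends `foreach=False`. A minimal sketch mirroring the example above with the for-loop implementation (only the `foreach` argument changes):

```python
import torch
from torch import nn
from optimi import AdamW

model = nn.Linear(20, 1, device='cuda')

# `foreach=False` selects the for-loop implementation, which does not reuse
# the gradient buffer for temporaries, so gradients remain valid after the step
opt = AdamW(model.parameters(), lr=1e-3, foreach=False)

loss = model(torch.randn(20, device='cuda'))
loss.backward()

opt.step()
# gradients can still be inspected or logged here before being reset
grad_norm = model.weight.grad.norm()
opt.zero_grad()
```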
22 changes: 22 additions & 0 deletions docs/fully_decoupled_weight_decay.md
@@ -28,6 +28,28 @@ For example, to match [AdamW’s](optimizers/adamw.md) default decoupled weight

By default, optimi optimizers assume `lr` is the maximum scheduled learning rate. This allows the applied weight decay $(\gamma_t/\gamma_\text{max})\lambda\bm{\theta}_{t-1}$ to match the learning rate schedule. Set `max_lr` if this is not the case.

## Example

```python
import torch
from torch import nn
from optimi import AdamW

# create model
model = nn.Linear(20, 1, dtype=torch.bfloat16)

# initialize any optimi optimizer using `decouple_lr=True` to enable fully
# decoupled weight decay. note `weight_decay` is lower than the default of 1e-2
opt = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5, decouple_lr=True)

# forward and backward
loss = model(torch.randn(20, dtype=torch.bfloat16))
loss.backward()

# optimizer step applies fully decoupled weight decay
opt.step()
opt.zero_grad()
```
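
As noted above, set `max_lr` when `lr` is not the maximum scheduled learning rate so the applied weight decay follows the schedule. A minimal sketch, assuming optimi optimizers can be driven by a standard PyTorch scheduler such as `OneCycleLR` peaking at 1e-3:

```python
import torch
from torch import nn
from optimi import AdamW

model = nn.Linear(20, 1)

# the schedule peaks at 1e-3 while `lr` starts lower, so pass `max_lr` explicitly
opt = AdamW(model.parameters(), lr=4e-5, max_lr=1e-3, weight_decay=1e-5,
            decouple_lr=True)
scheduler = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3, total_steps=100)

for _ in range(100):
    loss = model(torch.randn(20))
    loss.backward()
    opt.step()
    opt.zero_grad()
    scheduler.step()
```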

## Algorithm

The algorithm below shows the difference between PyTorch’s AdamW and optimi’s Adam with fully decoupled weight decay.
27 changes: 27 additions & 0 deletions docs/kahan_summation.md
@@ -57,6 +57,33 @@ Calculating the total memory savings depends on [activations and batch size](htt

Training in BFloat16 instead of mixed precision results in a ~10% speedup on a single GPU at the same batch size. BFloat16 training can further increase distributed training speed due to the halved bandwidth cost.

## Example

Using Kahan summation with an optimi optimizer only requires casting the model, and optionally the input, to low precision (BFloat16 or Float16). Since Kahan summation is applied layer by layer, it works for models with both standard and low precision weights.

```python
import torch
from torch import nn
from optimi import AdamW

# create or cast some model layers in low precision (bfloat16)
model = nn.Linear(20, 1, dtype=torch.bfloat16)

# initialize any optimi optimizer with low precision parameters
# Kahan summation is enabled since some model layers are bfloat16
opt = AdamW(model.parameters(), lr=1e-3)

# forward and backward, casting input to bfloat16 if needed
loss = model(torch.randn(20, dtype=torch.bfloat16))
loss.backward()

# optimizer step automatically uses Kahan summation for low precision layers
opt.step()
opt.zero_grad()
```

To disable Kahan summation, pass `kahan_summation=False` at optimizer initialization.
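
For example, a minimal sketch that keeps BFloat16 parameters but opts out of Kahan summation:

```python
import torch
from torch import nn
from optimi import AdamW

model = nn.Linear(20, 1, dtype=torch.bfloat16)

# `kahan_summation=False` disables the compensation buffer even though
# the parameters are in bfloat16
opt = AdamW(model.parameters(), lr=1e-3, kahan_summation=False)
```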

## Algorithm

SGD with Kahan summation.
