Fix loading a universal checkpoint #5263

Merged: 6 commits into master, Mar 13, 2024
Conversation

tohtana (Contributor) commented Mar 12, 2024

This PR fixes the following two points regarding checkpoint loading.

  • Load optimizer states
    With #5104, we removed the optimizer's `step()` call on initialization. This made DeepSpeed's parameter updates match PyTorch's normal behavior. However, the optimizer states no longer have keys when we load a checkpoint.
    For legacy/elastic checkpoints, that PR changed the checkpoint loaders to create keys and buffers on loading. However, the loader for universal checkpoints still relies on keys in the optimizer states. As a result, loading a universal checkpoint fails.
    This PR fixes the loader to find optimizer state keys from the given checkpoint.

  • Resume step count (2943e6a)
    The checkpoint loader for a universal checkpoint resumes the optimizer's step count only when the param group already has `step`. But some optimizers create the key `step` in a param group at the first call of `step()` (e.g. Apex [Fused Adam](https://github.com/NVIDIA/apex/blob/810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c/apex/optimizers/fused_adam.py#L154)). In this case, the step count is not restored. This PR changes the behavior to always set the step count in a param group, as sketched below.
    This PR also stops incrementing the step count when loading. I didn't see why we need to increment the step count in my small example, but we may need a discussion to consider various cases.
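A minimal sketch of the intended behavior (an illustrative helper, not DeepSpeed's actual `update_optimizer_step` implementation), assuming a torch-style optimizer:

```python
def restore_step_count(optimizer, step):
    """Illustrative fix: always set the step count when restoring."""
    for group in optimizer.param_groups:
        # The old loader only updated "step" when the key already existed,
        # which skips optimizers (e.g. Apex FusedAdam) that create "step"
        # lazily on the first step() call. Setting it unconditionally
        # ensures the resumed run continues from the right count.
        group["step"] = step
        for param in group["params"]:
            state = optimizer.state.get(param)
            if state is not None and "step" in state:
                state["step"] = step
```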

```diff
@@ -2785,7 +2785,7 @@ def load_checkpoint(self,
         if self.load_universal_checkpoint():
             self.optimizer.update_lp_params()
             if load_zero_checkpoint:
-                self.update_optimizer_step(step=client_states['iteration'] + 1)
+                self.update_optimizer_step(step=client_states['iteration'])
```
A contributor commented on this diff:
@mosheisland, FYI. Based on #4588, is there a potential off-by-one issue here?

@tohtana tohtana marked this pull request as ready for review March 12, 2024 23:50
@tohtana tohtana added this pull request to the merge queue Mar 13, 2024
@tohtana tohtana removed this pull request from the merge queue due to a manual request Mar 13, 2024
@tjruwase tjruwase added this pull request to the merge queue Mar 13, 2024
Merged via the queue into master with commit b112c99 Mar 13, 2024
12 checks passed
github-merge-queue bot pushed a commit that referenced this pull request Mar 28, 2024
This PR includes the following improvements regarding universal
checkpoints.

- Restoring step

A universal checkpoint saves the training step count taken from the
engine. In #5263, we changed the loader to always set this count,
restoring the training step count to the optimizer's per-parameter
states (`optimizer_state['state'][param]['step']`) and to each param
group. However, this approach does not restore the optimizer's state and
param groups precisely, because optimizers handle `step` differently.

Torch's Adam doesn't create `step` in its param groups and only uses
`optimizer_state['state'][param]['step']`. Apex's FusedAdam only uses
`step` in its param groups. DeepSpeed's FusedAdam creates `step` in its
param groups but never updates it; it only uses
`optimizer_state['state'][param]['step']`.
Consequently, setting `step` everywhere leads to discrepancies between
the restored and original states of the optimizer and param groups.

This PR modifies the restoration process to ensure that the step number
in the optimizer's state and param groups matches the original setup,
effectively aligning the restored and original optimizer states and
param groups. The snippet below shows where torch's Adam, for example,
keeps its step count.
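As a quick, self-contained illustration of these differences (stock
PyTorch only; Apex's and DeepSpeed's FusedAdam behave as described
above):

```python
import torch

# After one update, torch's Adam keeps "step" only in the per-parameter
# state, not in the param group. This is what makes a one-size-fits-all
# restore of the step count incorrect across optimizers.
model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

model(torch.randn(8, 4)).sum().backward()
opt.step()

param = next(model.parameters())
print("step in param group:", "step" in opt.param_groups[0])  # False
print("per-param step:", opt.state[param]["step"])  # 1 (a tensor in newer torch)
```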

- Unit tests of DP size scaling

This PR also adds unit tests to verify universal checkpointing. They run
training with data parallelism (DP), save a checkpoint, and convert it
into a universal checkpoint. Then they load the checkpoint with a
different DP size and validate that the parameters and the all-gathered
(ZeRO stage 1/2) optimizer states match. An outline of this flow appears
below.
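A hypothetical outline of such a test (all helper names here are
illustrative placeholders, not the actual DeepSpeed test utilities):

```python
def test_universal_checkpoint_dp_resize(tmpdir):
    # Train with one DP size and save a regular ZeRO checkpoint.
    ckpt_dir = train_and_save(dp_size=4, steps=10, save_dir=tmpdir)  # assumed helper

    # Convert it into a universal checkpoint (ds_to_universal-style step).
    univ_dir = convert_to_universal(ckpt_dir)  # assumed helper

    # Reload with a different DP size and compare against the original size.
    resumed = resume_training(univ_dir, dp_size=2)   # assumed helper
    baseline = resume_training(univ_dir, dp_size=4)  # assumed helper

    assert params_allclose(resumed, baseline)  # assumed helper
    # For ZeRO stage 1/2, all-gather the partitioned optimizer states
    # before comparing.
    assert gathered_optim_states_allclose(resumed, baseline)  # assumed helper
```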

- Fix bug when loading with `load_optimizer_states=False`

The loader didn't load parameters from a universal checkpoint when
`load_optimizer_states=False`. Commit c8c0498 fixes this issue.
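For reference, this is the affected call pattern, assuming `engine` is
an initialized DeepSpeed engine and the checkpoint path is a
placeholder:

```python
# Load module weights only; optimizer state is skipped. Before commit
# c8c0498, this path also failed to load the parameters themselves from
# a universal checkpoint.
load_path, client_state = engine.load_checkpoint(
    "checkpoints/universal_step1000",  # assumed path
    load_optimizer_states=False,
)
```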
rraminen pushed a commit to ROCm/DeepSpeed that referenced this pull request May 9, 2024
rraminen pushed a commit to ROCm/DeepSpeed that referenced this pull request May 9, 2024
dbyoung18 pushed a commit to dbyoung18/DeepSpeed that referenced this pull request Jun 11, 2024
dbyoung18 pushed a commit to dbyoung18/DeepSpeed that referenced this pull request Jun 11, 2024