
[low-bit optim] Add coat for float8 optimizer #1231

Draft · wants to merge 55 commits into main

Conversation


@MirMustafaAli MirMustafaAli commented Nov 6, 2024

This is a Work in Progress PR for #1190.

As a draft PR, I have followed the first piece of advice from @gau-nernst, "extending OptimStateFp8". I created a separate dynamic-range-expansion function instead of a different quantize_fp8 method: the expansion is applied before quantization to achieve a larger representable range for the float8 data type, and the class stores the value k needed to invert the expansion after dequantization.

Requirements:
TBA
Additional Code/logic Added:
TBA
Logic/Code changes to existing codebase:
TBA
Outcome:
TBA
Scope of Usage:
TBA
Example
TBA

Changes:

  • Dynamic Range Expansion function: implements the expansion formula from the paper.
  • OptimStateFp8WithDynamicRangeExpansion class: extends OptimStateFp8, referencing its implementation; only the dequantize method is overridden.
  • Implemented aten.copy_.default and aten._to_copy.default for OptimStateFp8WithDynamicRangeExpansion.
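For context, the dynamic range expansion in the COAT paper maps each optimizer-state value x to sign(x)·|x|^k before FP8 quantization, where k is chosen so the tensor's dynamic range matches that of the FP8 format, and dequantization applies the inverse power 1/k. A minimal NumPy sketch of the idea (the function names and the E4M3 range constants here are illustrative assumptions, not the PR's actual code):

```python
import numpy as np

# Illustrative float8 E4M3 range constants; assumptions for this sketch,
# not values taken from torchao.
E4M3_MAX = 448.0
E4M3_MIN = 2.0 ** -9

def expand(x: np.ndarray) -> tuple[np.ndarray, float]:
    """Expand x's dynamic range toward FP8's before quantization."""
    absx = np.abs(x[x != 0])
    range_x = absx.max() / absx.min()        # dynamic range of the tensor
    range_fp8 = E4M3_MAX / E4M3_MIN          # dynamic range of E4M3
    k = np.log(range_fp8) / np.log(range_x)  # per-tensor exponent k
    return np.sign(x) * np.abs(x) ** k, k

def contract(y: np.ndarray, k: float) -> np.ndarray:
    """Inverse map, applied after dequantization."""
    return np.sign(y) * np.abs(y) ** (1.0 / k)

x = np.array([1e-6, 3e-4, 0.02, 0.9])
y, k = expand(x)
x_back = contract(y, k)
assert np.allclose(x, x_back)  # round trip (without quantization) is exact
```

Without quantization in between, the round trip is exact; the benefit shows up once the expanded values are actually cast to float8, where they waste fewer of the format's exponent bins.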

Benchmarks

Parameters

| Parameter | Value |
|---|---|
| Learning Rate (lr) | 0.0001 |
| Automatic Mixed Precision (amp) | bf16 |
| Seed | 42 |
| Model | timm/vit_base_patch16_224.augreg_in21k |
| Optimizer (optim) | AdamWFp8Ao_coat |
| Compile | False |
| Profile | False |
| Project | COAT-benchmarking |
| Number of Epochs | 10 |
| Run Name | AdamWFp8Ao_coat |
| Full BF16 | False |
| Number of Workers | 4 |
| Batch Size | 1024 |
| Weight Decay | 0 |
| Channels Last | False |
| Optimizer CPU Offload | None |
| Cosine LR Scheduler | False |
| Checkpoint Activations | False |

Results

[W&B charts: training runs, Nov 14–15, 2024]


pytorch-bot bot commented Nov 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1231

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 6, 2024
@MirMustafaAli MirMustafaAli marked this pull request as draft November 6, 2024 11:35
@gau-nernst (Collaborator)

I was thinking you can just add a flag to the current OptimStateFp8, something like dynamic_range_expansion: bool, instead of subclass-ing it.
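The flag-based design being suggested might look roughly like this — a hypothetical sketch, not torchao's actual OptimStateFp8, with float16 casting standing in for real FP8 quantization and a fixed expansion exponent for simplicity:

```python
import numpy as np

class OptimStateFp8Sketch:
    """Hypothetical flag-gated optimizer state; names and internals are
    illustrative, not torchao's OptimStateFp8 implementation."""

    def __init__(self, x: np.ndarray, dynamic_range_expansion: bool = False):
        self.dynamic_range_expansion = dynamic_range_expansion
        self.k = 1.0
        if dynamic_range_expansion:
            self.k = 0.5                      # illustrative fixed exponent
            x = np.sign(x) * np.abs(x) ** self.k
        self.scale = np.abs(x).max()          # per-tensor absmax scale
        # float16 cast stands in for the real FP8 quantization step
        self.codes = (x / self.scale).astype(np.float16)

    def dequantize(self) -> np.ndarray:
        y = self.codes.astype(np.float32) * self.scale
        if self.dynamic_range_expansion:
            y = np.sign(y) * np.abs(y) ** (1.0 / self.k)  # undo expansion
        return y

state = OptimStateFp8Sketch(np.array([0.25, 1.0]), dynamic_range_expansion=True)
print(state.dequantize())
```

The point of the flag is that the expansion happens inside quantization and its inverse inside dequantize, keyed off the same attribute, so callers construct and use the state identically with or without COAT.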

@MirMustafaAli (Author)

> I was thinking you can just add a flag to the current OptimStateFp8, something like dynamic_range_expansion: bool, instead of subclass-ing it.

I have added the flag to OptimStateFp8. Could you verify it's right?

@gau-nernst (Collaborator)

I think this requires a bit more work. You need to verify that you can create an optimizer with this (add test to https://github.com/pytorch/ao/blob/main/test/prototype/test_low_bit_optim.py) as well do some short training runs for sanity checks (using https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_low_bit_adam.py).

I think for merging the PR, we should wait for the official code release to check numeric against them.

If you don't mind, we can discuss more details in GPU-MODE discord group https://discord.gg/gpumode. Just create a thread under torchao and tag me in (@gau.nernst)

@MirMustafaAli (Author)

MirMustafaAli commented Nov 6, 2024

I understand the situation around merging the PR and will be glad to keep working on this issue. Creating a thread in GPU-MODE.

@gau-nernst gau-nernst added the topic: new feature Use this tag if this PR adds a new feature label Nov 13, 2024
@gau-nernst (Collaborator) left a comment

Thanks for the update. The PR is coming out nicely. There are some failing CI tests. Can you fix them, including the ruff linter?

Some extra items once that is finished:

  • Update doc (link to the paper + usage)
  • Run the benchmark for a sanity check: https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_low_bit_adam.py. I'm thinking of comparing the BF16 baseline, the FP8 optimizer, and the FP8 COAT optimizer. Feel free to select a benchmark config suitable for you, and add the benchmark results to this PR description. Ideally it should show that FP8 COAT is better than plain FP8 (though we might not observe it).

yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* Show a8wxdq load error only when the quant is used

* Update Error check