This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Add option for recomputing the casted weight during backwards #186

Open · wants to merge 12 commits into main

Conversation

@drisspg (Contributor) commented on Jan 13, 2024

Summary

See #185 for more detail.

Disclaimer

Ughh, idk. PT2 doesn't let me control what gets recomputed, and I am having trouble interpreting the tea leaves.

For now, ignore all the performance numbers below except for max memory usage. The min-cut partitioner is actually undoing the recompute for backwards and saving the casted weight tensor. cc @Chillee
See: pytorch/pytorch#117901
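
For context, the core mechanic being benchmarked is roughly the following (a minimal sketch, not the actual PR code: the real implementation does a scaled float8 cast, while this sketch uses a plain bf16 cast as a stand-in so it runs anywhere):

```python
import torch
import torch.nn.functional as F

# Minimal sketch of the idea: save only the original high-precision weight for
# backward and re-materialize the low-precision copy there, instead of keeping
# the casted weight alive between forward and backward.
# A bf16 cast stands in for the real scaled float8 cast; x is assumed 2D.
class RecomputeCastedWeightLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)    # save the original weight, not the cast
        w_cast = weight.to(torch.bfloat16)  # "casted" weight, freed after forward
        return F.linear(x.to(torch.bfloat16), w_cast).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        w_cast = weight.to(torch.bfloat16)  # recompute the cast in backward
        g = grad_out.to(torch.bfloat16)
        grad_x = (g @ w_cast).to(x.dtype)
        grad_w = (g.t() @ x.to(torch.bfloat16)).to(weight.dtype)
        return grad_x, grad_w
```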

Single GPU Linear numbers:

| name      | shape               | ref_dtype      | compiled   | recompute_weight_cast   |   ref_time_sec |   pt_fp8_time_sec |   pt_fp8_speedup |
|:----------|:--------------------|:---------------|:-----------|:------------------------|---------------:|------------------:|-----------------:|
| attn.wqkv | (16384, 8192, 1280) | torch.bfloat16 | True       | True                    |     0.00211272 |        0.00207579 |          1.01779 |
| attn.wqkv | (16384, 8192, 1280) | torch.bfloat16 | True       | False                   |     0.00211962 |        0.00208095 |          1.01858 |
| attn.w0   | (16384, 1024, 8192) | torch.bfloat16 | True       | True                    |     0.00187907 |        0.00149616 |          1.25593 |
| attn.w0   | (16384, 1024, 8192) | torch.bfloat16 | True       | False                   |     0.00187947 |        0.00149665 |          1.25578 |
| ffn.w13   | (16384, 8192, 7168) | torch.bfloat16 | True       | True                    |     0.0102547  |        0.00680098 |          1.50782 |
| ffn.w13   | (16384, 8192, 7168) | torch.bfloat16 | True       | False                   |     0.0102781  |        0.00680872 |          1.50954 |
| ffn.w2    | (16384, 3584, 8192) | torch.bfloat16 | True       | True                    |     0.00538504 |        0.00370726 |          1.45257 |
| ffn.w2    | (16384, 3584, 8192) | torch.bfloat16 | True       | False                   |     0.00539845 |        0.00370568 |          1.4568  |
| attn.wqkv | (16384, 8192, 1280) | torch.float16  | True       | True                    |     0.0021861  |        0.0020997  |          1.04115 |
| attn.wqkv | (16384, 8192, 1280) | torch.float16  | True       | False                   |     0.00217873 |        0.00210146 |          1.03677 |
| attn.w0   | (16384, 1024, 8192) | torch.float16  | True       | True                    |     0.00188072 |        0.00147959 |          1.27111 |
| attn.w0   | (16384, 1024, 8192) | torch.float16  | True       | False                   |     0.00188136 |        0.00148019 |          1.27103 |
| ffn.w13   | (16384, 8192, 7168) | torch.float16  | True       | True                    |     0.0101473  |        0.00671181 |          1.51186 |
| ffn.w13   | (16384, 8192, 7168) | torch.float16  | True       | False                   |     0.0101678  |        0.00670741 |          1.51591 |
| ffn.w2    | (16384, 3584, 8192) | torch.float16  | True       | True                    |     0.00545398 |        0.00362562 |          1.50429 |
| ffn.w2    | (16384, 3584, 8192) | torch.float16  | True       | False                   |     0.00544952 |        0.00362146 |          1.50478 |

FSDP Memory Usage

Verified on single node 8-gpu FSDP that the memory usage is no longer scaling:

| Configuration                            | Max Memory Used Before this PR | Max Memory Used After this PR |
|:-----------------------------------------|-------------------------------:|------------------------------:|
| bf16                                     | 31.12 GiB                      | 31.12 GiB                     |
| dynamic_linear, cache casted weight      | 36.63 GiB                      | 36.06 GiB                     |
| dynamic_linear, recompute casted weight  | N/A                            | 29.86 GiB                     |

FSDP Performance

Using a single-node, 8-GPU FSDP setup with compile:

| Configuration                            | Before this PR (it/s) | After this PR (it/s) |
|:-----------------------------------------|----------------------:|---------------------:|
| bf16                                     | 2.01                  | 1.99                 |
| dynamic_linear, cache casted weight      | 2.35                  | 2.30                 |
| dynamic_linear, recompute casted weight  | N/A                   | 2.30                 |
| delayed_linear, cache casted weight      | 2.15                  | 2.09                 |
| delayed_linear, recompute casted weight  | N/A                   | 2.08                 |

Single GPU Memory usage

In eager mode, using this test script: https://gist.github.com/drisspg/75a792f97f5b8fa77f32af7f5280bae5

I am seeing the following max memory usage:

- Recompute = False: Max CUDA Memory Used: 1.8438 GiB
- Recompute = True: Max CUDA Memory Used: 1.7032 GiB

A difference of ~0.14 GiB. We should expect a memory saving of (4096**2) * (1 byte) * 10 (layers) / 1024**3 (bytes per GiB) = 0.15625 GiB.
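
As a sanity check on that arithmetic (assuming, as in the gist, 10 swapped linears with 4096x4096 weights and one byte per fp8 element):

```python
# Expected saving from no longer keeping the casted fp8 weights alive.
n_layers = 10
weight_elems = 4096 * 4096   # elements per weight
bytes_per_fp8 = 1
saving_gib = n_layers * weight_elems * bytes_per_fp8 / 1024**3
print(f"{saving_gib:.5f} GiB")   # 0.15625 GiB, in line with the ~0.14 GiB observed
```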

Also verified by memory-traces in the gist

Questions

This is kind of a meaty PR that depends on a PyTorch PR (pytorch/pytorch#117667), but I am curious if people have strong feelings on the "UX".

I chose not to make the "recompute weight cast" a config setting, instead having it as a module attribute. swap_linear will set this for every linear it swaps; in theory, from_float is granular enough to do this on a per-linear basis.

Is there any reason why having it as a global config would be better (even a global config setting that alters the swap function's behavior)?
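
For concreteness, the per-module-attribute flavor described above looks roughly like this (a hypothetical sketch; the attribute name `recompute_weight_cast` matches what is discussed here, but the `swap_linear` / `from_float` signatures are simplified and may not match the real ones):

```python
import torch.nn as nn

# Hypothetical sketch of the module-attribute approach: the swap utility stamps
# the flag onto every linear it converts; from_float could equally take it per layer.
def swap_linear(module: nn.Module, float8_linear_cls, recompute_weight_cast: bool = False):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_child = float8_linear_cls.from_float(child)
            new_child.recompute_weight_cast = recompute_weight_cast
            setattr(module, name, new_child)
        else:
            swap_linear(child, float8_linear_cls, recompute_weight_cast)
    return module
```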

@drisspg drisspg requested a review from vkuzo January 13, 2024 02:16
@facebook-github-bot facebook-github-bot added the CLA Signed label Jan 13, 2024
@drisspg drisspg removed the request for review from vkuzo January 13, 2024 02:16
@drisspg drisspg changed the title Add option for recomputing the casted weight during backwards [WIP] Add option for recomputing the casted weight during backwards Jan 13, 2024
@drisspg drisspg marked this pull request as draft January 13, 2024 02:17
@drisspg (Contributor, Author) commented on Jan 13, 2024

ed(f"call_method {self} {name} {args} {kwargs}")
[2024-01-12 18:38:24,732] [8/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/drisspg/miniconda3/envs/nightly/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[2024-01-12 18:38:24,732] [8/0] torch._dynamo.variables.higher_order_ops: [ERROR]     raise Unsupported(msg)
[2024-01-12 18:38:24,732] [8/0] torch._dynamo.variables.higher_order_ops: [ERROR] torch._dynamo.exc.Unsupported: call_method GetAttrVariable(TensorVariable(), _data) stride [] {}

You love to see it! Why can't I call shape? Avoiding calling _data is tough.

@drisspg (Contributor, Author) commented on Jan 13, 2024

cc @bdhirsh As far as I can tell, this is erroring because of these calls to the tensor attributes: https://github.com/pytorch-labs/float8_experimental/pull/186/files#diff-00f68398c8aad5a3e946cccd7211a80841da9403d6c664452a45e04101bea6d6R84-R93

I know that in the past, anytime we try to access the subclass's attributes outside of the __torch_dispatch__ code, this errors. I don't have any idea how to work around this, since I think we need this autograd function and hence can't use __torch_dispatch__.
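
To make the failure mode concrete, here is a toy reconstruction of the pattern (illustrative only; the wrapper class and `_data` attribute below are stand-ins for the real Float8Tensor internals):

```python
import torch
from torch.utils._pytree import tree_map

# Toy wrapper subclass: the payload lives on `_data`, and ops are redirected to it
# inside __torch_dispatch__. Names are illustrative, not the library's definitions.
class ToyWrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device)

    def __init__(self, data):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        unwrap = lambda t: t._data if isinstance(t, ToyWrapperTensor) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

class AddWrapped(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        # Touching w._data here, i.e. outside __torch_dispatch__, is the kind of
        # attribute access Dynamo rejects above ("GetAttrVariable ... _data").
        return x + w._data

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None
```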

@drisspg drisspg force-pushed the enable_recompute_of_grad branch from 795ebbd to 682c2e8 Compare January 17, 2024 20:55
@drisspg drisspg force-pushed the enable_recompute_of_grad branch from c3f5c9a to 2ffcbe9 Compare January 18, 2024 00:30
@drisspg drisspg requested review from vkuzo, y-sq, bdhirsh and awgu January 18, 2024 02:10
@drisspg drisspg marked this pull request as ready for review January 18, 2024 02:29
@drisspg drisspg changed the title [WIP] Add option for recomputing the casted weight during backwards Add option for recomputing the casted weight during backwards Jan 18, 2024
@vkuzo (Contributor) commented on Jan 18, 2024

> I chose not to make the "recompute weight cast" a config setting, instead having it as a module attribute. swap_linear will set this for every linear it swaps; in theory, from_float is granular enough to do this on a per-linear basis.

The above makes sense to me for this particular setting, if we choose to have a setting. It would be nice to not have a setting at all unless we need it. I feel like FSDP is unusable for real workloads without this, so if the recomputation is fast enough why not just have it as the only path?

@vkuzo (Contributor) commented on Jan 18, 2024

> Verified on single node 8-gpu FSDP that the memory usage is no longer scaling

Great! Can we also post throughput metrics on 8-GPU FSDP? If there is a slowdown, having a smaller benchmark to capture and debug it would be useful.

@drisspg drisspg force-pushed the enable_recompute_of_grad branch from 26ce70d to d2da1ad Compare January 20, 2024 00:34