Question: How to use Float8InferenceLinear with FSDP1/2? #704
Comments
Unfortunately, Float8InferenceLinear is being developed against the latest PyTorch nightly and has not been thoroughly tested on older versions of PyTorch. If it is possible for you to update your PyTorch version, that is recommended. If the problem still persists after updating and you are able to create a minimal reproducer, we can look into this.
@drisspg Got it. Thank you! To confirm,
Today this would indeed install the 2.5 dev build for today's date, but generally, for any feature leveraging torch.compile, you want to either be on the latest stable release (today that is 2.4) or use the nightlies.
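For example, a small runtime guard can flag an old build early (a minimal sketch; the 2.4 threshold just mirrors the stable release mentioned above):

```python
import torch

# Features that lean on torch.compile are developed against recent PyTorch,
# so warn when running on an older build (2.4 was the latest stable at the time).
major, minor = (int(v) for v in torch.__version__.split(".")[:2])
if (major, minor) < (2, 4):
    print(f"PyTorch {torch.__version__} detected; consider >=2.4 or a nightly build.")
```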
Thanks @msaroufim! Is the cu118 version also supported and tested (if I disable torch.compile and FSDP2 DTensor and just use FSDP1)? Let me do a quick test and check. Thank you!
A quick update: it turns out that there might be some issues with cu118; exploring other options now and probably will have to use the 12.1 runtime version instead.

Update: it seems the root cause is a libnvrtc loading issue.

For 11.8: `Could not load library libnvrtc.so.11.2. Error: libnvrtc.so.11.2: cannot open shared object file: No such file or directory`

For 12.1:
I'd try isolating things in a fresh conda environment. Also, if you're mucking around with CUDA versions, keep in mind that the torchao binaries on PyPI are built against CUDA 12.1, so I would recommend installing ao from source or downloading it from the PyTorch index.
Thank you! Resolved the above issue by adding the current path to `LD_LIBRARY_PATH`.
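For anyone hitting the same libnvrtc error, a quick sanity check before importing torch (a minimal sketch; the toolkit path in the comment is an assumption):

```python
import ctypes

# The dynamic loader only honors LD_LIBRARY_PATH set before the process starts,
# e.g. `export LD_LIBRARY_PATH=/usr/local/cuda-11.8/lib64:$LD_LIBRARY_PATH`
# (the path is an assumption about where the 11.8 toolkit lives).
try:
    ctypes.CDLL("libnvrtc.so.11.2")
    print("libnvrtc.so.11.2 is loadable")
except OSError as err:
    print(f"libnvrtc.so.11.2 not found: {err}")

import torch
print(torch.__version__, torch.version.cuda)  # CUDA version this wheel was built against
```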
Faced the same issue when testing the Mixtral 8x7B model (the gated routing layer has been excluded) with the layer-replacement + FSDP code below:
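(The original snippet is not reproduced here; the following is a minimal sketch of the described pattern, replacing linears with Float8InferenceLinear and then wrapping with FSDP. The import path, the `module_filter_fn` kwarg, and the gate filter are assumptions and may differ across float8 versions.)

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM

# Import path may be `float8_experimental.inference` or `torchao.float8.inference`
# depending on the installed version; the names below follow that API.
from float8_experimental.inference import ActivationCasting, QuantConfig, quantize_to_float8

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.bfloat16
)

# Swap nn.Linear -> Float8InferenceLinear with dynamic activation casting,
# skipping the MoE gate/router. The `module_filter_fn` kwarg and its signature
# are assumptions -- check the installed float8 version for the exact name.
quantize_to_float8(
    model,
    QuantConfig(ActivationCasting.DYNAMIC),
    module_filter_fn=lambda mod, fqn: "gate" not in fqn,
)

# FSDP1 wrapping; assumes torch.distributed has already been initialized
# (e.g. via torchrun) with one GPU per rank.
model = FSDP(model, device_id=torch.cuda.current_device())
model.eval()
```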
Also tried uncommenting the line
@qingquansong thanks! do you have a minimal repro so we can take a look?
Let me create a mini Mixtral model with some synthetic data.
Update:
Another thing I'm not sure about: if I just want to do inference, how should I set the following 6 args plus the FSDP args? The speed seems to slow down with the FP8 layer in this case, and memory is also not reduced as much as expected. Setting the input config to DYNAMIC seems to make things faster, but still comparable with bf16 for Mixtral 8x7B.
**Save this script in a file and run it to reproduce.**
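Separately, a hedged sketch of inference-oriented FSDP settings for the question above (the values are assumptions, not a verified recipe for Mixtral + FP8):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# `model` is a float8-converted model (not yet FSDP-wrapped).
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # shard params across ranks
    device_id=torch.cuda.current_device(),
    use_orig_params=True,                           # needed if torch.compile is added later
)
model.eval()

# Inference only: run the forward pass under no_grad with a synthetic batch.
input_ids = torch.randint(0, 32000, (1, 128), device="cuda")
with torch.no_grad():
    out = model(input_ids)
```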
For the speed / memory issue, I guess it's related to not using torch.compile, based on the related ticket #685 ([FP8] performance degradation in speed and memory without compile). I'll check if I can use torch.compile here. Thanks.
Currently it's a bit blocked on torch.compile + Mixtral. The context of using torch.compile is that it seems to be required in combination with the FP8 linear to improve speed, as discussed in some threads: pytorch/torchtitan#462 (comment). The Hugging Face Mixtral model does not directly support torch.compile, as stated here, mainly because the sparse MoE routes a dynamic number of tokens to different experts; there is also an ongoing effort on this: huggingface/transformers#30793. I've tried the option in gpt-fast (similar to the above PR's change of converting to a fused MoE), but it's more suitable for the fast text-generation phase with small batch sizes and has high memory consumption for a large-batch prefill stage. It could also break the structure of the linear layers, making it harder to swap them for FP8 linears directly.
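One workaround worth trying, sketched below, is to compile only the dense submodules so the sparse expert routing stays in eager mode (the attribute paths assume the Hugging Face Mixtral layout and are not verified against this setup):

```python
import torch

# Applied to the plain (un-wrapped) model: compile the attention blocks and the
# per-expert MLPs individually, leaving the MoE router in eager mode.
# dynamic=True because the number of tokens routed to each expert varies per call.
for layer in model.model.layers:
    layer.self_attn = torch.compile(layer.self_attn, dynamic=True)
    experts = layer.block_sparse_moe.experts
    for i, expert in enumerate(experts):
        experts[i] = torch.compile(expert, dynamic=True)
```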
I put some of my raw test scripts here in case anyone is interested: https://github.com/qingquansong/fp8_fsdp_test. Sorry that I didn't change the model and data local paths.
Hey Team,
I'm trying to use FSDP1/2 with Float8InferenceLinear but seem to have some issues (with torch 2.3.1+cu118). Do you suggest bumping to a higher version of torch and trying again, or maybe using the training setup without the inference layer? I also tried using the Float8Linear layer without the quantization function that converts to Float8InferenceLinear, but I face some issues when using FSDP1: when computing the amax, some input x tensors are empty (x.numel() == 0) and some are NaN.
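A minimal sketch of one way to surface those empty / NaN inputs with forward pre-hooks (filtering on plain nn.Linear here; the float8 linear class could be used instead if importable):

```python
import torch
import torch.nn as nn

def check_inputs(module, args):
    # args is the positional-argument tuple passed to forward()
    x = args[0]
    if x.numel() == 0:
        print(f"{module.__class__.__name__}: empty input tensor")
    elif torch.isnan(x).any():
        print(f"{module.__class__.__name__}: NaNs in input")

# `model` is the FSDP-wrapped model under test.
for name, mod in model.named_modules():
    if isinstance(mod, nn.Linear):
        mod.register_forward_pre_hook(check_inputs)
```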
Best regards,
QQ