FSDP+LORA on multiple gpu(A100 80gb*4) ValueError: Cannot flatten integer dtype tensors #2250

Open
6 of 8 tasks
Paxwell-Paxwell opened this issue Jan 10, 2025 · 2 comments
Labels
bug Something isn't working


@Paxwell-Paxwell

Please check that this issue hasn't been reported before.

  • I searched previous bug reports and didn't find any similar reports.

Expected Behavior

The LoRA configuration should work with FSDP.

Current behaviour

[rank0]: raise ValueError("Cannot flatten integer dtype tensors")
[rank0]: ValueError: Cannot flatten integer dtype tensors
[rank1]: Traceback (most recent call last):

Steps to reproduce

Run Axolotl with LoRA + FSDP on multiple GPUs (4x NVIDIA A100 80GB).
- torch version: 2.5.1
- axolotl version: 0.6.0

Config yaml

base_model: meta-llama/Llama-3.3-70B-Instruct
model_type: LlamaForCausalLM
processing_class: AutoTokenizer

plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true

load_in_8bit: true
load_in_4bit: false
strict: false

chat_template: llama3
datasets:

dataset_prepared_path: ./workspace/aiLawData/last_run_prepared
val_set_size: 0
output_dir: ./workspace/aiLawData/outputs/Llama-3.3-70B-memo-law-Instruct-lora-r256-v3
hub_model_id: PaxwellPaxwell/Llama-3.3-70B-Memo-law-Instruct-adapter-lora-r256-v3
sequence_len: 12000
sample_packing: false
pad_to_sequence_len: true

adapter: lora
lora_model_dir:
lora_r: 256
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project: Ai-Law
wandb_entity:
wandb_watch:
wandb_name: Llama-3.3-70B-memo-law-Instruct-adapter-lora-r256-v3
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 10
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
auto_resume_from_checkpoints: true

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing_kwargs:
  use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 100
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 2
debug:
deepspeed:
weight_decay: 0.0

fsdp:
  - full_shard
  - auto_wrap

fsdp_config:
  activation_checkpointing: true
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD

special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot_id|>
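One plausible source of integer-dtype parameters in this config is load_in_8bit: true. This is an assumption, not a diagnosis confirmed in this thread: 8-bit loading through bitsandbytes keeps the frozen base weights as int8, and FSDP refuses to flatten non-floating-point tensors. The sketch below checks a hand-copied subset of the posted config for that combination; the dict form and the function name integer_weight_risks are hypothetical, and parsing the YAML itself is omitted.

```python
# Hypothetical dict form of a subset of the config posted above.
POSTED_CFG = {
    "load_in_8bit": True,
    "load_in_4bit": False,
    "adapter": "lora",
    "fsdp": ["full_shard", "auto_wrap"],
}


def integer_weight_risks(cfg: dict) -> list[str]:
    """Flag settings that may leave integer-dtype parameters for FSDP to flatten.

    Assumption: 8-bit loading stores the frozen base weights as int8, and
    FSDP's flat-parameter construction rejects non-floating-point tensors.
    """
    risks = []
    # Both FSDP sharding and 8-bit loading are enabled.
    if cfg.get("fsdp") and cfg.get("load_in_8bit"):
        risks.append("load_in_8bit + fsdp: base weights may be int8")
    return risks


if __name__ == "__main__":
    for risk in integer_weight_risks(POSTED_CFG):
        print(risk)
```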

Possible solution

No response

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.10.13

axolotl branch-commit

main/latest

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
Paxwell-Paxwell added the bug (Something isn't working) label on Jan 10, 2025
Paxwell-Paxwell changed the title from "ValueError: Cannot flatten integer dtype tensors (help please) on multiple gpu(A100 80gb*4)" to "FSDP+LORA ValueError: Cannot flatten integer dtype tensors (help please) on multiple gpu(A100 80gb*4)" on Jan 10, 2025
Paxwell-Paxwell changed the title from "FSDP+LORA ValueError: Cannot flatten integer dtype tensors (help please) on multiple gpu(A100 80gb*4)" to "FSDP+LORA on multiple gpu(A100 80gb*4) ValueError: Cannot flatten integer dtype tensors" on Jan 10, 2025
@NJordan72 (Contributor)

Can you post the stack trace so we can see what is throwing the error? I had a similar problem earlier this week, and depending on where it was coming from, I either had to set gradient_accumulation_steps to 1 or turn off the Liger cross-entropy kernel.
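The two workarounds in the comment above map to two config changes: gradient_accumulation_steps: 1, or liger_fused_linear_cross_entropy: false. A minimal sketch, assuming plain data parallelism across the 4 reported GPUs, derives both variants from the relevant keys of the posted config and shows the trade-off of the first one: the number of sequences per optimizer step (micro_batch_size x gradient_accumulation_steps x GPUs) drops from 2 x 4 x 4 = 32 to 2 x 1 x 4 = 8. The helper names are hypothetical.

```python
import copy

# Relevant keys copied from the config posted in this issue.
BASE_CFG = {
    "micro_batch_size": 2,
    "gradient_accumulation_steps": 4,
    "liger_fused_linear_cross_entropy": True,
}
NUM_GPUS = 4  # 4x A100 80GB, as reported in the issue


def effective_batch_size(cfg: dict, world_size: int) -> int:
    """Sequences per optimizer step across all data-parallel ranks."""
    return cfg["micro_batch_size"] * cfg["gradient_accumulation_steps"] * world_size


def workaround_variants(cfg: dict) -> dict[str, dict]:
    """Return copies of cfg with each suggested workaround applied separately."""
    no_grad_accum = copy.deepcopy(cfg)
    no_grad_accum["gradient_accumulation_steps"] = 1

    no_liger_ce = copy.deepcopy(cfg)
    no_liger_ce["liger_fused_linear_cross_entropy"] = False

    return {"grad_accum_1": no_grad_accum, "no_liger_ce": no_liger_ce}


if __name__ == "__main__":
    print("base", effective_batch_size(BASE_CFG, NUM_GPUS))  # base 32
    for name, variant in workaround_variants(BASE_CFG).items():
        # grad_accum_1 8, then no_liger_ce 32
        print(name, effective_batch_size(variant, NUM_GPUS))
```

If the first workaround is kept, raising micro_batch_size (memory permitting) is one way to recover part of the original effective batch size.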

@Paxwell-Paxwell (Author) commented Jan 10, 2025


super().__init__(*_args, **kwargs)
[rank2]: Traceback (most recent call last):
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank2]: return _run_code(code, main_globals, None,
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/runpy.py", line 86, in _run_code
[rank2]: exec(code, run_globals)
[rank2]: File "/home/b6410500220/desktop/LLM_FINETUNE/axolotl/src/axolotl/cli/train.py", line 58, in <module>
[rank2]: fire.Fire(do_cli)
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/fire/core.py", line 135, in Fire
[rank2]: component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/fire/core.py", line 468, in _Fire
[rank2]: component, remaining_args = _CallAndUpdateTrace(
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank2]: component = fn(*varargs, **kwargs)
[rank2]: File "/home/b6410500220/desktop/LLM_FINETUNE/axolotl/src/axolotl/cli/train.py", line 34, in do_cli
[rank2]: return do_train(parsed_cfg, parsed_cli_args)
[rank2]: File "/home/b6410500220/desktop/LLM_FINETUNE/axolotl/src/axolotl/cli/train.py", line 47, in do_train
[rank2]: model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
[rank2]: File "/home/b6410500220/desktop/LLM_FINETUNE/axolotl/src/axolotl/train.py", line 202, in train
[rank2]: trainer.train(resume_from_checkpoint=resume_from_checkpoint)
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/transformers/trainer.py", line 2155, in train
[rank2]: return inner_training_loop(
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/transformers/trainer.py", line 2313, in _inner_training_loop
[rank2]: self.model = self.accelerator.prepare(self.model)
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1339, in prepare
[rank2]: result = tuple(
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1340, in <genexpr>
[rank2]: self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1215, in _prepare_one
[rank2]: return self.prepare_model(obj, device_placement=device_placement)
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1512, in prepare_model
[rank2]: model = FSDP(model, **kwargs)
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 483, in __init__
[rank2]: _auto_wrap(
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
[rank2]: _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 545, in _recursive_wrap
[rank2]: wrapped_child, num_wrapped_params = _recursive_wrap(
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 545, in _recursive_wrap
[rank2]: wrapped_child, num_wrapped_params = _recursive_wrap(
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 545, in _recursive_wrap
[rank2]: wrapped_child, num_wrapped_params = _recursive_wrap(
[rank2]: [Previous line repeated 2 more times]
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 563, in _recursive_wrap
[rank2]: return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 492, in _wrap
[rank2]: return wrapper_cls(module, **kwargs)
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
[rank2]: _init_param_handle_from_module(
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 636, in _init_param_handle_from_module
[rank2]: _init_param_handle_from_params(state, managed_params, fully_sharded_module)
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 648, in _init_param_handle_from_params
[rank2]: handle = FlatParamHandle(
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 584, in __init__
[rank2]: self._init_flat_param_and_metadata(
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 634, in _init_flat_param_and_metadata
[rank2]: ) = self._validate_tensors_to_flatten(params)
[rank2]: File "/home/b6410500220/.conda/envs/llmenv/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 770, in _validate_tensors_to_flatten
[rank2]: raise ValueError("Cannot flatten integer dtype tensors")
[rank2]: ValueError: Cannot flatten integer dtype tensors
[rank3]: Traceback (most recent call last):

I tried your solution, but I still get the same error.
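The traceback shows the error is raised while accelerator.prepare wraps the model with FSDP (FSDP(model, **kwargs) → FlatParamHandle → _validate_tensors_to_flatten), i.e. before any training step runs. That check appears to reject any parameter that is not floating point. A minimal diagnostic sketch, assuming you can get a handle on the loaded model before it is wrapped, lists such parameters by name; find_non_float_params and ToyQuantizedLayer are hypothetical names used only for illustration, and ToyQuantizedLayer merely imitates a layer with an int8 weight.

```python
import torch
from torch import nn


def find_non_float_params(module: nn.Module) -> list[tuple[str, torch.dtype]]:
    """Return (name, dtype) for every parameter that is not floating point.

    These are the tensors FSDP's flat-parameter construction cannot flatten.
    """
    return [
        (name, param.dtype)
        for name, param in module.named_parameters()
        if not param.is_floating_point()
    ]


class ToyQuantizedLayer(nn.Module):
    """Hypothetical stand-in for a layer whose weight is stored as int8."""

    def __init__(self) -> None:
        super().__init__()
        # Integer parameters cannot require gradients, so freeze it.
        self.weight = nn.Parameter(
            torch.zeros(4, 4, dtype=torch.int8), requires_grad=False
        )
        self.bias = nn.Parameter(torch.zeros(4))


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(4, 4), ToyQuantizedLayer())
    for name, dtype in find_non_float_params(model):
        print(f"{name}: {dtype}")  # prints "1.weight: torch.int8"
```

On the real model, a non-empty result would point at the modules FSDP is failing to flatten.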
