
Training config that worked with transformers v4.46.3 results in OOM error with v4.47.0 (using SFTTrainer) #35108

jjbuck opened this issue Dec 5, 2024 · 5 comments


jjbuck commented Dec 5, 2024

System Info

- `transformers` version: 4.47.0
- Platform: Linux-6.8.0-1015-aws-x86_64-with-glibc2.35
- Python version: 3.12.6
- Huggingface_hub version: 0.26.2
- Safetensors version: 0.4.5
- Accelerate version: 1.1.1
- Accelerate config:    not found
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: Yes
- Using GPU in script?: Yes
- GPU type: NVIDIA A100-SXM4-40GB

Who can help?

@ArthurZucker @SunMarc @muellerzr

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Training with transformers==4.46.3 runs as expected. Upgrading to transformers==4.47.0 (without changing anything else) leads to an OOM error in the very first training step (see stack trace below).

Run command: accelerate launch --config_file ./accelerate_config.yaml train.py training=path/to/training_config

Accelerate Config

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
  activation_checkpointing: true
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Training Config

{'accelerator_config': {'dispatch_batches': None,
                        'even_batches': True,
                        'gradient_accumulation_kwargs': None,
                        'non_blocking': False,
                        'split_batches': False,
                        'use_seedable_sampler': True},
 'adafactor': False,
 'adam_beta1': 0.9,
 'adam_beta2': 0.999,
 'adam_epsilon': 1e-08,
 'attn_implementation': 'flash_attention_2',
 'auto_find_batch_size': False,
 'average_tokens_across_devices': False,
 'batch_eval_metrics': False,
 'bf16': 'auto',
 'bf16_full_eval': False,
 'chars_per_token': '<CHARS_PER_TOKEN>',
 'data_seed': None,
 'dataloader_drop_last': False,
 'dataloader_num_workers': 0,
 'dataloader_persistent_workers': False,
 'dataloader_pin_memory': True,
 'dataloader_prefetch_factor': None,
 'dataset_batch_size': 1000,
 'dataset_kwargs': {'skip_prepare_dataset': False},
 'ddp_backend': None,
 'ddp_broadcast_buffers': None,
 'ddp_bucket_cap_mb': None,
 'ddp_find_unused_parameters': None,
 'ddp_timeout': 1800,
 'debug': [],
 'deepspeed': None,
 'delete_ckpts': False,
 'disable_tqdm': False,
 'dispatch_batches': None,
 'do_eval': True,
 'do_predict': False,
 'do_train': False,
 'early_stopping_patience': 10,
 'eval_accumulation_steps': None,
 'eval_delay': 0,
 'eval_do_concat_batches': True,
 'eval_exampleset_info_path': '',
 'eval_exampleset_path': '',
 'eval_on_start': True,
 'eval_packing': False,
 'eval_steps': 10,
 'eval_strategy': 'steps',
 'eval_use_gather_object': False,
 'evaluation_strategy': None,
 'exampleset_info_path': '',
 'exampleset_path': '',
 'force_tokenize_data': False,
 'fp16': False,
 'fp16_backend': 'auto',
 'fp16_full_eval': False,
 'fp16_opt_level': 'O1',
 'fsdp': [],
 'fsdp_config': {'min_num_params': 0,
                 'xla': False,
                 'xla_fsdp_grad_ckpt': False,
                 'xla_fsdp_v2': False},
 'fsdp_min_num_params': 0,
 'fsdp_transformer_layer_cls_to_wrap': None,
 'full_determinism': False,
 'gradient_accumulation_steps': 4,
 'gradient_checkpointing': False,
 'gradient_checkpointing_kwargs': {'use_reentrant': False},
 'greater_is_better': False,
 'group_by_length': False,
 'half_precision_backend': 'auto',
 'hub_always_push': False,
 'hub_model_id': None,
 'hub_private_repo': None,
 'hub_strategy': 'every_save',
 'hub_token': '<HUB_TOKEN>',
 'ignore_data_skip': False,
 'include_for_metrics': [],
 'include_inputs_for_metrics': False,
 'include_num_input_tokens_seen': False,
 'include_tokens_per_second': False,
 'jit_mode_eval': False,
 'label_names': ['labels'],
 'label_smoothing_factor': 0.0,
 'learning_rate': 0.0002,
 'length_column_name': 'length',
 'load_best_model_at_end': True,
 'local_rank': 0,
 'log_level': 'passive',
 'log_level_replica': 'warning',
 'log_on_each_node': True,
 'logging_first_step': False,
 'logging_nan_inf_filter': True,
 'logging_steps': 1,
 'logging_strategy': 'steps',
 'lora_alpha': 32,
 'lora_dropout': 0.05,
 'lora_r': 16,
 'lora_target_modules': ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'up_proj', 'down_proj', 'gate_proj'],
 'lr_scheduler_kwargs': {},
 'lr_scheduler_type': 'cosine',
 'mask_instructions': True,
 'max_grad_norm': 1.0,
 'max_seq_length': 1024,
 'max_steps': 100,
 'meta_data': {},
 'metric_for_best_model': 'loss',
 'model_name_or_path': 'Qwen/Qwen2.5-7B-Instruct',
 'mp_parameters': '',
 'neftune_noise_alpha': None,
 'no_cuda': False,
 'num_of_sequences': 1024,
 'num_train_epochs': 3,
 'optim': 'adamw_torch',
 'optim_args': None,
 'optim_target_modules': None,
 'overwrite_output_dir': False,
 'packing': False,
 'past_index': -1,
 'per_device_eval_batch_size': 1,
 'per_device_train_batch_size': 1,
 'per_gpu_eval_batch_size': None,
 'per_gpu_train_batch_size': None,
 'prediction_loss_only': False,
 'push_to_hub': False,
 'push_to_hub_model_id': None,
 'push_to_hub_organization': None,
 'push_to_hub_token': '<PUSH_TO_HUB_TOKEN>',
 'ray_scope': 'last',
 'remove_unused_columns': True,
 'restore_callback_states_from_checkpoint': False,
 'resume_from_checkpoint': None,
 'save_on_each_node': False,
 'save_only_model': False,
 'save_safetensors': True,
 'save_steps': 20,
 'save_strategy': 'steps',
 'save_total_limit': None,
 'seed': 42,
 'skip_memory_metrics': True,
 'smoke_test': False,
 'split_batches': None,
 'tf32': None,
 'torch_compile': False,
 'torch_compile_backend': None,
 'torch_compile_mode': None,
 'torch_dtype': 'bfloat16',
 'torch_empty_cache_steps': None,
 'torchdynamo': None,
 'tpu_metrics_debug': False,
 'tpu_num_cores': None,
 'use_cpu': False,
 'use_ipex': False,
 'use_legacy_prediction_loop': False,
 'use_liger_kernel': False,
 'use_mps_device': False,
 'use_peft': False,
 'val_set_size': 0.0,
 'warmup_ratio': 0.1,
 'warmup_steps': 0,
 'weight_decay': 0.0}

Training script


from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer


def main(cfg):
    # sft_config, train_dataset and eval_dataset are derived from cfg;
    # their construction is omitted from this snippet.
    accelerator = Accelerator()

    model_kwargs = dict(
        attn_implementation=sft_config.attn_implementation,
        torch_dtype=sft_config.torch_dtype,
        use_cache=False,
    )
    model = AutoModelForCausalLM.from_pretrained(sft_config.model_name_or_path, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(sft_config.model_name_or_path, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=sft_config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=None,
        dataset_kwargs=sft_config.dataset_kwargs,
    )

    trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()  # config loading (see the run command above) omitted from this snippet

Stack trace

Traceback (most recent call last):
  File "/home/ubuntu/***/train.py", line 233, in main
    trainer.train()
  File "/home/ubuntu/***/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2164, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/***/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2522, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/***/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3653, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/***/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3709, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/ubuntu/***/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/***/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/***/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 864, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/***/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/***/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/***/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py", line 823, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/***/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py", line 811, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/***/.venv/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/***/.venv/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1184, in forward
    loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/***/.venv/lib/python3.12/site-packages/transformers/loss/loss_utils.py", line 36, in ForCausalLMLoss
    logits = logits.float()
             ^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.97 GiB. GPU 5 has a total capacity of 39.38 GiB of which 1.53 GiB is free. Including non-PyTorch memory, this process has 37.84 GiB memory in use. Of the allocated memory 35.69 GiB is allocated by PyTorch, and 521.06 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Expected behavior

Training should complete without errors.

Rocketknight1 (Member) commented:

cc @muellerzr @SunMarc

SunMarc (Member) commented Dec 6, 2024

Hey @jjbuck, thanks for the report. Can you try to find the commit that is causing this, e.g. with git bisect?

jjbuck (Author) commented Dec 6, 2024

@SunMarc It looks like it was commit 8b3b9b48fcd6bc06bd9c576f1b09266d577db25 (8b3b9b4).

Among other things, that commit removed `or self.is_fsdp_enabled` from the conditional `delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled`.

If I add that back, so that the line reads `delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled`, I no longer run into the OOM error.
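
For reference, a minimal sketch of the two versions of that conditional inside `Trainer._inner_training_loop` in `transformers/trainer.py` (the surrounding code and exact line numbers differ between releases):

    # transformers/trainer.py, inside Trainer._inner_training_loop
    # v4.47.0 (leads to the OOM under FSDP in this setup):
    delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled

    # Workaround described above, restoring the pre-4.47.0 behaviour so that the
    # optimizer is only created after the model has been wrapped/sharded by FSDP:
    delay_optimizer_creation = (
        is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
    )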

That said, it's not clear to me what motivated the change in the first place, so I don't know whether simply reverting it is appropriate.


github-actions bot commented Jan 5, 2025

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

SunMarc (Member) commented Jan 6, 2025

Thanks for figuring this out, and sorry for the late reply! We fixed the regression you reported here. LMK if this solves your issue!
