def check_cuda_p2p_ib_support() -> bool:
    """
    Report whether CUDA peer-to-peer / InfiniBand communication looks usable.

    The check is layered: accelerate's own ``check_cuda_p2p_ib_support`` is
    consulted first, and a negative answer from it is final. Otherwise the
    GPUs reported by ``get_gpu_info`` are inspected, and support is denied
    when more than one but fewer than eight devices are present and any
    device name contains a known-unsupported model marker.

    Returns:
        ``False`` if accelerate reports no support, or if a partial multi-GPU
        setup (2-7 devices) includes an unsupported model. ``True`` otherwise,
        including when the GPU information cannot be obtained.
    """
    if not accelerate_check_cuda_p2p_ib_support():
        return False

    # Short model markers; reported device names are typically longer
    # (vendor prefix / suffix), so they are matched as substrings of the name.
    unsupported_devices = {"RTX 6000 Ada"}
    try:
        device_names, device_count = get_gpu_info()
        # Only a multi-GPU setup with fewer than eight devices is restricted
        # (presumably eight devices means a full node -- confirm).
        if 1 < device_count < 8:
            if any(
                unsupported_device in device_name
                for device_name in device_names
                for unsupported_device in unsupported_devices
            ):
                return False
    except Exception:  # pylint: disable=broad-except # nosec
        # Deliberately best-effort: if the GPU query or its result is unusable,
        # fall through and keep the positive answer from accelerate.
        pass
    return True
os.environ["NCCL_P2P_DISABLE"] = "1" if cfg.fsdp: setup_fsdp_envs(cfg) elif cfg.deepspeed: