Skip to content

Commit

Permalink
add P2P env when multi-gpu but not the full node (#2041)
Browse files Browse the repository at this point in the history
Co-authored-by: Wing Lian <[email protected]>
  • Loading branch information
winglian and winglian authored Nov 12, 2024
1 parent 9f1cf9b commit ad435a3
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
25 changes: 25 additions & 0 deletions src/axolotl/utils/environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
utils to get GPU info for the current environment
"""
from accelerate.utils.environment import (
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
)
from accelerate.utils.environment import get_gpu_info


def check_cuda_p2p_ib_support():
    """
    Decide whether CUDA P2P / IB communication should be treated as supported.

    accelerate's own check is consulted first; a negative answer from it is
    final. Otherwise the GPUs reported by ``get_gpu_info`` are inspected: when
    there is more than one GPU but fewer than eight, and any reported device
    name contains one of the markers in ``unsupported_devices``, support is
    reported as unavailable.

    Returns:
        bool: ``False`` when accelerate reports no support or an unsupported
        device is found on such a host; ``True`` otherwise, including when
        the GPU information cannot be retrieved.
    """
    if not accelerate_check_cuda_p2p_ib_support():
        return False
    # Short model markers. NOTE(review): the reported device name is presumably
    # longer than the marker (vendor prefix / suffix) - confirm on real hardware.
    unsupported_devices = {"RTX 6000 Ada"}
    try:
        device_names, device_count = get_gpu_info()
        # Only hosts with 2..7 GPUs are affected by this extra check.
        if 1 < device_count < 8:
            if any(
                # The marker must be looked up inside the reported name, not
                # the other way around, so longer reported names still match.
                unsupported_device in device_name
                for device_name in device_names
                for unsupported_device in unsupported_devices
            ):
                return False
    except Exception:  # pylint: disable=broad-except # nosec
        # Best effort: if the GPU info cannot be queried or interpreted,
        # fall through and report support rather than failing the caller.
        pass
    return True
4 changes: 4 additions & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

LOG = get_logger("axolotl")
Expand Down Expand Up @@ -461,6 +462,9 @@ def setup_fsdp_envs(cfg):


def prepare_optim_env(cfg):
if not check_cuda_p2p_ib_support():
if os.getenv("NCCL_P2P_DISABLE") is None:
os.environ["NCCL_P2P_DISABLE"] = "1"
if cfg.fsdp:
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
Expand Down

0 comments on commit ad435a3

Please sign in to comment.