I'm trying to run Ring Attention on a machine with 6 A100 GPUs, and I'm finding that when I try to set the sequence parallelism dimension to anything other than a power of 2, the process crashes with a JAX partitioning error.
I'd be grateful for any insight into whether I'm doing something wrong in how I'm invoking the training script, and for any advice on how to work around this issue.
Steps to Reproduce
Consider the following script for invoking `llamabpt.train`:
For Configuration 1, where the sequence parallelism dimension is 4, the training script runs as expected without errors.
However, when I uncomment Configuration 2, where the sequence parallelism dimension is 3, the training script crashes with the following error:
```
ValueError: One of pjit outputs with pytree key path .params['params']['lm_head']['kernel'] was given the sharding of NamedSharding(mesh={'dp': 1, 'fsdp': 1, 'tp': 1, 'sp': 3}, spec=PartitionSpec(('fsdp', 'sp'), 'tp')), which implies that the global size of its dimension 0 should be divisible by 3, but it is equal to 2048 (full shape: (2048, 32000))
```
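I think the same constraint can be reproduced outside the training script. Below is a minimal sketch I put together from the error message alone (the fake-CPU-devices trick and the `init_lm_head` helper are just for illustration and are not part of `llamabpt`; the kernel shape, mesh axes, and partition spec are copied from the error above):

```python
import os
# Fake several CPU devices so a mesh can be built without GPUs (illustration only).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=6"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as PS

def init_lm_head(sp):
    # Mesh mirroring the one in the error: {'dp': 1, 'fsdp': 1, 'tp': 1, 'sp': sp}.
    devices = np.array(jax.devices()[:sp]).reshape(1, 1, 1, sp)
    mesh = Mesh(devices, ("dp", "fsdp", "tp", "sp"))
    # lm_head kernel spec from the error: dim 0 over ('fsdp', 'sp'), dim 1 over 'tp'.
    out_sharding = NamedSharding(mesh, PS(("fsdp", "sp"), "tp"))
    init = jax.jit(lambda: jnp.zeros((2048, 32000)), out_shardings=out_sharding)
    return init()

init_lm_head(4)  # 2048 % (1 * 4) == 0, so this initializes fine
init_lm_head(3)  # 2048 % (1 * 3) != 0, so this should raise the same ValueError
```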
In the actual training run, the error occurs during the first call to `sharded_init_fn`.

I would expect Configuration 2 to run successfully, because the total sequence length (98304) is a multiple of the sequence parallelism dimension (3).

Generalizing to other values of the sequence parallelism dimension, I find that:
- Setting `SEQ_PAR_DIM` to either 2 or 4 runs successfully.
- Setting `SEQ_PAR_DIM` to either 3 or 6 crashes with a partitioning error.