-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[BUG] partition_balanced return wrong result. (#4312)
# Background In pipeline parallelism, deepspeed uses `ds_utils.partition_balanced` to balance the partitioning of the model according to the number of parameters or class names. https://github.com/microsoft/DeepSpeed/blob/581e44dd1ab3c409a5905335867c761d5cb4db5b/deepspeed/runtime/pipe/module.py#L380-L395 # What wrong? ``` >>> import deepspeed >>> deepspeed.__version__ '0.10.3+542dc0d5' >>> from deepspeed.runtime import utils as ds_utils >>> ds_utils.partition_balanced([1, 1, 1, 1, 1], 4) [0, 2, 4, 5, 5] >>> ``` the result [0, 2, 4, 5, 5] means [2, 2, 1, 0] layers for each part, which is not balanced at all. the last part will throw an exception because there are no parameters to training. i add some unit test for this function, and i will fix it later if anyone need it. --------- Co-authored-by: Olatunji Ruwase <[email protected]>
- Loading branch information
Showing
2 changed files
with
62 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
from deepspeed.runtime import utils as ds_utils | ||
|
||
|
||
def check_partition(weights, num_parts, target_diff): | ||
result = ds_utils.partition_balanced(weights=weights, num_parts=num_parts) | ||
|
||
parts_sum = [] | ||
for b, e in zip(result[:-1], result[1:]): | ||
parts_sum.append(sum(weights[b:e])) | ||
|
||
assert max(parts_sum) - min( | ||
parts_sum | ||
) == target_diff, f"ds_utils.partition_balanced(weights={weights}, num_parts={num_parts}) return {result}" | ||
|
||
|
||
def test_partition_balanced(): | ||
check_partition([1, 2, 1], 4, target_diff=2) | ||
check_partition([1, 1, 1, 1], 4, target_diff=0) | ||
check_partition([1, 1, 1, 1, 1], 4, target_diff=1) | ||
check_partition([1, 1, 1, 1, 0, 1], 4, target_diff=1) |