Skip to content

Commit

Permalink
[BUG] partition_balanced return wrong result. (#4312)
Browse files Browse the repository at this point in the history
# Background

In pipeline parallelism, deepspeed uses `ds_utils.partition_balanced` to
balance the partitioning of the model according to the number of
parameters or class names.

https://github.com/microsoft/DeepSpeed/blob/581e44dd1ab3c409a5905335867c761d5cb4db5b/deepspeed/runtime/pipe/module.py#L380-L395

# What is wrong?
```
>>> import deepspeed
>>> deepspeed.__version__
'0.10.3+542dc0d5'
>>> from deepspeed.runtime import utils as ds_utils
>>> ds_utils.partition_balanced([1, 1, 1, 1, 1], 4)
[0, 2, 4, 5, 5]
>>> 
```
The result [0, 2, 4, 5, 5] means the four parts receive [2, 2, 1, 0] layers
respectively, which is not balanced at all. The last part will throw an
exception because it has no parameters to train.

I have added some unit tests for this function, and I will fix it later if
anyone needs it.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
zjjMaiMai and tjruwase authored Dec 8, 2023
1 parent ce60708 commit 2bdf061
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 62 deletions.
99 changes: 37 additions & 62 deletions deepspeed/runtime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import psutil
import gc
from math import sqrt
from bisect import bisect_left
from packaging import version as pkg_version

import torch
Expand Down Expand Up @@ -570,67 +569,43 @@ def partition_uniform(num_items, num_parts):
return parts


def _lprobe(weights, num_parts, bottleneck):
num_items = len(weights)
total_weight = weights[-1]

# initialize partitioning
parts = [0] * (num_parts + 1)
for p in range(1, num_parts + 1):
parts[p] = num_items

bsum = bottleneck # running sum of target weight for pth partition
chunksize = num_items // num_parts
step = chunksize
for p in range(1, num_parts):
# Jump to the next bucket
while (step < num_items) and (weights[step] < bsum):
step += chunksize

# Find the end index of partition p
parts[p] = bisect_left(weights, bsum, lo=step - chunksize, hi=min(step, num_items))
# Nothing more to partition, return early
if parts[p] == num_items:
# See if the current partition is overweight.
part_size = weights[-1] - weights[parts[p - 1]]
return parts, part_size < bottleneck

# Next partition target
bsum = weights[parts[p] - 1] + bottleneck

return parts, bsum >= total_weight


def _rb_partition_balanced(weights, num_parts, eps):
    """Binary-search the smallest bottleneck that ``_lprobe`` accepts.

    Args:
        weights: cumulative weights; the last entry is treated as the total.
        num_parts: number of parts to cut.
        eps: tolerance at which the search interval is considered closed.

    Returns:
        The upper end of the final search interval.
    """
    grand_total = weights[-1]
    lo = grand_total / num_parts  # heaviest part in the best case
    hi = grand_total  # heaviest part in the worst case

    while hi > lo + eps:
        candidate = lo + ((hi - lo) / 2)
        _, feasible = _lprobe(weights, num_parts, candidate)
        if not feasible:
            lo = candidate + eps
            continue
        hi = candidate
    return hi


def partition_balanced(weights, num_parts, eps=1e-3):
    # Binary-search based partitioner: look for a bottleneck weight with
    # _rb_partition_balanced, then build the boundaries for it with _lprobe.
    # NOTE(review): no return statement appears inside this span; presumably
    # the `return parts` line that follows the DP implementation further down
    # is shared diff context and closes this body too - confirm.
    num_items = len(weights)
    # First check for the trivial edge case
    if num_items <= num_parts:
        return partition_uniform(num_items, num_parts)

    # presumably an inclusive prefix sum (the helpers read weights_[-1] as the
    # total weight) - verify against prefix_sum_inc
    weights_ = prefix_sum_inc(weights)

    # Find the smallest bottleneck (weight of heaviest partition)
    bottleneck = _rb_partition_balanced(weights_, num_parts, eps=eps)

    # Now compute that partitioning
    parts, success = _lprobe(weights_, num_parts, bottleneck)
    # The bottleneck came out of the search above, so the probe is expected to succeed.
    assert success
def partition_balanced(weights, num_parts):
    """Split ``weights`` into ``num_parts`` contiguous parts of similar total weight.

    Uses dynamic programming to solve the linear partition problem, see
    https://www8.cs.umu.se/kurser/TDBAfl/VT06/algorithms/BOOK/BOOK2/NODE45.HTM

    For every prefix of ``i`` items cut into ``j`` non-empty parts, the tables
    remember the cut with the smallest spread (heaviest part minus lightest
    part) seen so far. The three nested loops cost O(n^2 * num_parts).

    Args:
        weights: sequence of per-item weights.
        num_parts: number of parts to create.

    Returns:
        A list of ``num_parts + 1`` boundaries as plain ``int``; part ``p``
        covers ``weights[parts[p]:parts[p + 1]]``. When there are no more items
        than parts, the result of ``partition_uniform`` is returned as is.

    Raises:
        ValueError: if ``num_parts`` is smaller than 1 while there are more
            items than parts.
    """
    import numpy as np

    n = len(weights)
    m = num_parts

    if n <= m:
        return partition_uniform(n, m)
    if m < 1:
        # Without this guard m == 0 silently yields the meaningless ``[n]``.
        raise ValueError(f"num_parts must be at least 1, got {num_parts}")

    # dp_*[i, j] describe the best known split of the first i items into j parts.
    dp_max = np.full((n + 1, m + 1), np.inf)
    dp_min = np.full((n + 1, m + 1), np.inf)
    dp_cost = np.full((n + 1, m + 1), np.inf)
    position = np.zeros((n + 1, m + 1), dtype=int)
    prefix_sum = np.zeros(n + 1)
    prefix_sum[1:] = np.cumsum(weights)

    # Base case: zero items in zero parts. dp_min[0, 0] deliberately stays at
    # +inf so that the first real part defines the minimum.
    dp_max[0, 0] = 0
    dp_cost[0, 0] = 0
    for i in range(1, n + 1):
        for j in range(1, min(i, m) + 1):
            for k in range(i):
                # The last part covers items k..i-1; the first k items form j - 1 parts.
                segment = prefix_sum[i] - prefix_sum[k]
                max_sum = max(dp_max[k, j - 1], segment)
                min_sum = min(dp_min[k, j - 1], segment)
                cost = max_sum - min_sum
                # ``>=`` keeps the latest cut among equally good candidates.
                if dp_cost[i, j] >= cost:
                    dp_cost[i, j] = cost
                    dp_max[i, j] = max_sum
                    dp_min[i, j] = min_sum
                    position[i, j] = k

    # Walk the recorded cut positions backwards from the full problem.
    parts = [n]
    for j in range(m, 0, -1):
        # Cast so callers get plain ints rather than numpy scalars.
        parts.append(int(position[parts[-1], j]))
    parts.reverse()

    return parts

Expand Down
25 changes: 25 additions & 0 deletions tests/unit/utils/test_partition_balanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.runtime import utils as ds_utils


def check_partition(weights, num_parts, target_diff):
    """Assert that the heaviest and lightest parts differ by exactly ``target_diff``.

    Partitions ``weights`` with ``ds_utils.partition_balanced`` and sums the
    weights between each pair of consecutive boundaries.
    """
    result = ds_utils.partition_balanced(weights=weights, num_parts=num_parts)

    parts_sum = [sum(weights[start:stop]) for start, stop in zip(result[:-1], result[1:])]
    spread = max(parts_sum) - min(parts_sum)

    assert spread == target_diff, (
        f"ds_utils.partition_balanced(weights={weights}, num_parts={num_parts}) return {result}")


def test_partition_balanced():
    """Run ``check_partition`` over a table of (weights, expected spread) cases with 4 parts."""
    cases = (
        ([1, 2, 1], 2),
        ([1, 1, 1, 1], 0),
        ([1, 1, 1, 1, 1], 1),
        ([1, 1, 1, 1, 0, 1], 1),
    )
    for weights, expected_diff in cases:
        check_partition(weights, 4, target_diff=expected_diff)

0 comments on commit 2bdf061

Please sign in to comment.