diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index 48ccdbc29bf6..bc7a782e590c 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -14,7 +14,6 @@
 import psutil
 import gc
 from math import sqrt
-from bisect import bisect_left
 from packaging import version as pkg_version
 
 import torch
@@ -570,67 +569,64 @@ def partition_uniform(num_items, num_parts):
     return parts
 
 
-def _lprobe(weights, num_parts, bottleneck):
-    num_items = len(weights)
-    total_weight = weights[-1]
-
-    # initialize partitioning
-    parts = [0] * (num_parts + 1)
-    for p in range(1, num_parts + 1):
-        parts[p] = num_items
-
-    bsum = bottleneck  # running sum of target weight for pth partition
-    chunksize = num_items // num_parts
-    step = chunksize
-    for p in range(1, num_parts):
-        # Jump to the next bucket
-        while (step < num_items) and (weights[step] < bsum):
-            step += chunksize
-
-        # Find the end index of partition p
-        parts[p] = bisect_left(weights, bsum, lo=step - chunksize, hi=min(step, num_items))
-        # Nothing more to partition, return early
-        if parts[p] == num_items:
-            # See if the current partition is overweight.
-            part_size = weights[-1] - weights[parts[p - 1]]
-            return parts, part_size < bottleneck
-
-        # Next partition target
-        bsum = weights[parts[p] - 1] + bottleneck
-
-    return parts, bsum >= total_weight
-
-
-def _rb_partition_balanced(weights, num_parts, eps):
-    total_weight = weights[-1]
-    lower = total_weight / num_parts  # best case heaviest partition
-    upper = total_weight  # worst case heaviest partition
-
-    # Do a binary search for the best partitioning
-    while upper > lower + eps:
-        mid = lower + ((upper - lower) / 2)
-        parts, success = _lprobe(weights, num_parts, mid)
-        if success:
-            upper = mid
-        else:
-            lower = mid + eps
-    return upper
-
-
-def partition_balanced(weights, num_parts, eps=1e-3):
-    num_items = len(weights)
-    # First check for the trivial edge case
-    if num_items <= num_parts:
-        return partition_uniform(num_items, num_parts)
-
-    weights_ = prefix_sum_inc(weights)
-
-    # Find the smallest bottleneck (weight of heaviest partition)
-    bottleneck = _rb_partition_balanced(weights_, num_parts, eps=eps)
-
-    # Now compute that partitioning
-    parts, success = _lprobe(weights_, num_parts, bottleneck)
-    assert success
+def partition_balanced(weights, num_parts):
+    """Split ``weights`` into ``num_parts`` contiguous, balanced parts.
+
+    Uses dynamic programming to solve `The Linear Partition Problem`, see
+    https://www8.cs.umu.se/kurser/TDBAfl/VT06/algorithms/BOOK/BOOK2/NODE45.HTM
+    Every DP state keeps the split with the smallest gap between its
+    heaviest and its lightest part.
+
+    Args:
+        weights: sequence of per-item weights.
+        num_parts: number of parts to produce.
+
+    Returns:
+        The part boundaries as a list of ``num_parts + 1`` Python ints, so
+        that part ``p`` is ``weights[parts[p]:parts[p + 1]]``. If there are
+        no more items than parts, the result of ``partition_uniform`` is
+        returned unchanged.
+
+    Raises:
+        ValueError: if ``num_parts`` is less than 1 while there are more
+            items than parts.
+    """
+    import numpy as np
+    n = len(weights)
+    m = num_parts
+
+    if n <= m:
+        return partition_uniform(n, m)
+    if m < 1:
+        raise ValueError(f"num_parts must be at least 1, got {num_parts}")
+
+    # dp_max[i, j] / dp_min[i, j]: heaviest / lightest part of the best split
+    # found for weights[:i] into j parts; dp_max is np.inf if unreachable.
+    dp_max = np.full((n + 1, m + 1), np.inf)
+    dp_min = np.full((n + 1, m + 1), np.inf)
+    position = np.zeros((n + 1, m + 1), dtype=int)
+    prefix_sum = np.zeros(n + 1)
+    prefix_sum[1:] = np.cumsum(weights)
+
+    dp_max[0, 0] = 0
+    for i in range(1, n + 1):
+        # tail[k] is the weight of the candidate last part weights[k:i].
+        tail = prefix_sum[i] - prefix_sum[:i]
+        for j in range(1, min(i, m) + 1):
+            max_sum = np.maximum(dp_max[:i, j - 1], tail)
+            min_sum = np.minimum(dp_min[:i, j - 1], tail)
+            cost = max_sum - min_sum
+            # Of all equally good split points keep the largest one.
+            k = i - 1 - int(np.argmin(cost[::-1]))
+            dp_max[i, j] = max_sum[k]
+            dp_min[i, j] = min_sum[k]
+            position[i, j] = k
+
+    # Walk the recorded split points back from the end of the sequence.
+    parts = [n]
+    for j in range(m, 0, -1):
+        parts.append(int(position[parts[-1], j]))
+    parts.reverse()
     return parts
 
 
diff --git a/tests/unit/utils/test_partition_balanced.py b/tests/unit/utils/test_partition_balanced.py
new file mode 100644
index 000000000000..e7285e478c53
--- /dev/null
+++ b/tests/unit/utils/test_partition_balanced.py
@@ -0,0 +1,30 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from deepspeed.runtime import utils as ds_utils
+
+
+def check_partition(weights, num_parts, target_diff):
+    """Check the gap between the heaviest and the lightest part."""
+    result = ds_utils.partition_balanced(weights=weights, num_parts=num_parts)
+
+    parts_sum = [sum(weights[b:e]) for b, e in zip(result[:-1], result[1:])]
+
+    assert max(parts_sum) - min(
+        parts_sum
+    ) == target_diff, f"ds_utils.partition_balanced(weights={weights}, num_parts={num_parts}) return {result}"
+
+
+def test_partition_balanced():
+    check_partition([1, 2, 1], 4, target_diff=2)
+    check_partition([1, 1, 1, 1], 4, target_diff=0)
+    check_partition([1, 1, 1, 1, 1], 4, target_diff=1)
+    check_partition([1, 1, 1, 1, 0, 1], 4, target_diff=1)
+
+
+def test_partition_balanced_returns_python_ints():
+    result = ds_utils.partition_balanced(weights=[1, 2, 3, 4, 5], num_parts=2)
+    assert result == [0, 3, 5]
+    assert all(type(bound) is int for bound in result)