diff --git a/tools/rebalance-corenrn-data.py b/tools/rebalance-corenrn-data.py index eafef0fd..5d42a6af 100755 --- a/tools/rebalance-corenrn-data.py +++ b/tools/rebalance-corenrn-data.py @@ -12,6 +12,7 @@ import heapq import itertools import logging +import math import os import sys @@ -113,12 +114,17 @@ def batch(iterable, first=0): yield group + [CORENRN_SKIP_MARK] * (ranks_per_machine - len(group)) break yield group - first, last = last, last + 40 + first, last = last, last + ranks_per_machine group = iterable[first:last] + # compute max number of cell groups per rank so we know the n_files in the header + max_len = max(len(m) for m in buckets) + max_groups_rank = math.ceil(max_len / ranks_per_machine) + total_entries = max_groups_rank * ranks_per_machine * len(buckets) + with open(output_file, "w") as out: print(infos["version"], file=out) - print(infos["n_files"], file=out) + print(total_entries, file=out) for buckets in itertools.zip_longest(*[batch(m) for m in buckets]): for entries in buckets: