From 34ee5a7faba4e30c820bf51b998f22f9784347b9 Mon Sep 17 00:00:00 2001 From: Christoph Lehner Date: Fri, 20 Sep 2024 20:06:29 +0300 Subject: [PATCH] auto-tune fix on multiple ranks --- lib/gpt/core/auto_tune.py | 19 +++++++++++++------ lib/gpt/core/mpi.py | 2 ++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/lib/gpt/core/auto_tune.py b/lib/gpt/core/auto_tune.py index d5f5224f..7a1b97ff 100644 --- a/lib/gpt/core/auto_tune.py +++ b/lib/gpt/core/auto_tune.py @@ -97,11 +97,18 @@ def __init__(self, tag, params, default_param): hash_tag = str(hashlib.sha256(tag.encode("utf-8")).hexdigest()) self.at_fn = f".gpt_auto_tune/{hash_tag}.json" - if os.path.exists(self.at_fn) and self.at_active: - self.at_tuned_params = load_cache(self.at_fn) - assert self.at_tuned_params["tag"] == tag - if self.at_verbose: - g.message(f"Use tuned results from {self.at_fn}") + if g.rank() == 0 and os.path.exists(self.at_fn) and self.at_active: + try: + self.at_tuned_params = load_cache(self.at_fn) + assert self.at_tuned_params["tag"] == tag + if self.at_verbose: + g.message(f"Use tuned results from {self.at_fn}") + except: + self.at_tuned_params = {} else: + self.at_tuned_params = {} + + self.at_tuned_params = g.broadcast(0, self.at_tuned_params) + + if len(self.at_tuned_params) == 0: self.at_tuned_params = None - # in future versions allow masking by tags diff --git a/lib/gpt/core/mpi.py b/lib/gpt/core/mpi.py index 437b7904..ff9a926e 100644 --- a/lib/gpt/core/mpi.py +++ b/lib/gpt/core/mpi.py @@ -38,6 +38,8 @@ def broadcast(root, data): return broadcast(root, data.encode("utf-8")).decode("utf-8") elif isinstance(data, np.ndarray): return pickle.loads(broadcast(root, pickle.dumps(data))) + elif isinstance(data, dict): + return pickle.loads(broadcast(root, pickle.dumps(data))) return cgpt.broadcast(root, data)