Skip to content

Commit

Permalink
auto-tune fix on multiple ranks
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Sep 20, 2024
1 parent 1c3f891 commit 34ee5a7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
19 changes: 13 additions & 6 deletions lib/gpt/core/auto_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,18 @@ def __init__(self, tag, params, default_param):
hash_tag = str(hashlib.sha256(tag.encode("utf-8")).hexdigest())

self.at_fn = f".gpt_auto_tune/{hash_tag}.json"
if os.path.exists(self.at_fn) and self.at_active:
self.at_tuned_params = load_cache(self.at_fn)
assert self.at_tuned_params["tag"] == tag
if self.at_verbose:
g.message(f"Use tuned results from {self.at_fn}")
if g.rank() == 0 and os.path.exists(self.at_fn) and self.at_active:
try:
self.at_tuned_params = load_cache(self.at_fn)
assert self.at_tuned_params["tag"] == tag
if self.at_verbose:
g.message(f"Use tuned results from {self.at_fn}")
except:

Check failure on line 106 in lib/gpt/core/auto_tune.py

View workflow job for this annotation

GitHub Actions / lint

E722:do not use bare 'except'
self.at_tuned_params = {}
else:
self.at_tuned_params = {}

self.at_tuned_params = g.broadcast(0, self.at_tuned_params)

if len(self.at_tuned_params) == 0:
self.at_tuned_params = None
# in future versions allow masking by tags
2 changes: 2 additions & 0 deletions lib/gpt/core/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def broadcast(root, data):
return broadcast(root, data.encode("utf-8")).decode("utf-8")
elif isinstance(data, np.ndarray):
return pickle.loads(broadcast(root, pickle.dumps(data)))
elif isinstance(data, dict):
return pickle.loads(broadcast(root, pickle.dumps(data)))
return cgpt.broadcast(root, data)


Expand Down

0 comments on commit 34ee5a7

Please sign in to comment.