diff --git a/lib/gpt/core/auto_tune.py b/lib/gpt/core/auto_tune.py
index e4c27018..57504074 100644
--- a/lib/gpt/core/auto_tune.py
+++ b/lib/gpt/core/auto_tune.py
@@ -20,6 +20,24 @@
 import json, hashlib, os, sys
 
 
+runtime_tune_cache = {}
+
+
+def save_cache(fn, data):
+    if g.rank() == 0:
+        os.makedirs(".gpt_auto_tune", exist_ok=True)
+        fout = open(fn, "wt")
+        json.dump(data, fout)
+        fout.close()
+    runtime_tune_cache[fn] = data
+
+
+def load_cache(fn):
+    if fn not in runtime_tune_cache:
+        runtime_tune_cache[fn] = json.load(open(fn, "rt"))
+    return runtime_tune_cache[fn]
+
+
 def auto_tuned_method(method):
     def wrapper(self, *args):
         if not self.at_active:
@@ -50,11 +68,7 @@ def wrapper(self, *args):
             imin = g.broadcast(0, imin)
             self.at_tuned_params = {"tag": self.at_tag, "params": self.at_params[imin], "results": dts}
             g.message(f"Tuning result (use {self.at_tuned_params} will be saved in {self.at_fn}")
-            if g.rank() == 0:
-                os.makedirs(".gpt_auto_tune", exist_ok=True)
-                fout = open(self.at_fn, "wt")
-                json.dump(self.at_tuned_params, fout)
-                fout.close()
+            save_cache(self.at_fn, self.at_tuned_params)
 
             g.barrier()
 
@@ -76,7 +90,7 @@ def __init__(self, tag, params, default_param):
         self.at_fn = f".gpt_auto_tune/{hash_tag}.json"
 
         if os.path.exists(self.at_fn) and self.at_active:
-            self.at_tuned_params = json.load(open(self.at_fn, "rt"))
+            self.at_tuned_params = load_cache(self.at_fn)
             assert self.at_tuned_params["tag"] == tag
             if self.at_verbose:
                 g.message(f"Use tuned results from {self.at_fn}")
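
Review note: the point of the new module-level runtime_tune_cache is that, before this patch, every construction of an auto-tuned object whose tag had already been tuned re-opened and re-parsed the JSON file under .gpt_auto_tune/; the dictionary now memoizes the parsed result by file name, so each process touches the file system at most once per cache file. Note also that save_cache populates the in-memory cache on every rank, even though only rank 0 writes the file. Below is a minimal single-process sketch of the round trip, assuming a rank0 flag as a stand-in for g.rank() == 0 and an invented example payload; it illustrates the caching pattern only and is not part of the patch.

import json, os

runtime_tune_cache = {}  # in-memory cache, keyed by file name


def save_cache(fn, data, rank0=True):
    # rank0 stands in for g.rank() == 0: only the root rank writes to disk
    if rank0:
        os.makedirs(".gpt_auto_tune", exist_ok=True)
        with open(fn, "wt") as fout:
            json.dump(data, fout)
    # every caller keeps the result in memory, writer or not
    runtime_tune_cache[fn] = data


def load_cache(fn):
    # hit the file system only on the first lookup for this file name
    if fn not in runtime_tune_cache:
        with open(fn, "rt") as fin:
            runtime_tune_cache[fn] = json.load(fin)
    return runtime_tune_cache[fn]


fn = ".gpt_auto_tune/example.json"
save_cache(fn, {"tag": "example", "params": {"block_size": 16}, "results": [0.12, 0.34]})
assert load_cache(fn) is runtime_tune_cache[fn]  # served from memory, no second read

One design consequence worth flagging: because load_cache returns the cached dictionary itself rather than a copy, all objects constructed with the same tag in one process share a single mutable at_tuned_params object.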