Skip to content

Commit

Permalink
cache
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Jun 4, 2024
1 parent 8987723 commit 343d98b
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions lib/gpt/core/auto_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,24 @@
import json, hashlib, os, sys


# In-memory (per-process) cache of tuning results keyed by filename, so a
# tuning file written or read once is not re-parsed from disk on later lookups.
runtime_tune_cache = {}


def save_cache(fn, data):
    """Persist tuning results to disk and mirror them in the runtime cache.

    Parameters
    ----------
    fn : str
        Destination file name (expected to live under ``.gpt_auto_tune/``).
    data : dict
        JSON-serializable tuning results.

    Only MPI rank 0 writes the file (avoids concurrent writes from all
    ranks); every rank records ``data`` in ``runtime_tune_cache`` so later
    ``load_cache`` calls never touch the filesystem.
    """
    if g.rank() == 0:
        os.makedirs(".gpt_auto_tune", exist_ok=True)
        # context manager guarantees the handle is closed even if json.dump raises
        with open(fn, "wt") as fout:
            json.dump(data, fout)
    runtime_tune_cache[fn] = data


def load_cache(fn):
    """Return tuning results for file ``fn``, reading from disk at most once.

    Parameters
    ----------
    fn : str
        Tuning-result file name previously produced by ``save_cache``.

    Returns
    -------
    dict
        The parsed tuning results, served from ``runtime_tune_cache`` on
        repeated calls.
    """
    if fn not in runtime_tune_cache:
        # original used json.load(open(fn)) which left the handle to be
        # closed by the GC; use a context manager to close it deterministically
        with open(fn, "rt") as fin:
            runtime_tune_cache[fn] = json.load(fin)
    return runtime_tune_cache[fn]


def auto_tuned_method(method):
def wrapper(self, *args):
if not self.at_active:
Expand Down Expand Up @@ -50,11 +68,7 @@ def wrapper(self, *args):
imin = g.broadcast(0, imin)
self.at_tuned_params = {"tag": self.at_tag, "params": self.at_params[imin], "results": dts}
g.message(f"Tuning result (use {self.at_tuned_params} will be saved in {self.at_fn}")
if g.rank() == 0:
os.makedirs(".gpt_auto_tune", exist_ok=True)
fout = open(self.at_fn, "wt")
json.dump(self.at_tuned_params, fout)
fout.close()
save_cache(self.at_fn, self.at_tuned_params)

g.barrier()

Expand All @@ -76,7 +90,7 @@ def __init__(self, tag, params, default_param):

self.at_fn = f".gpt_auto_tune/{hash_tag}.json"
if os.path.exists(self.at_fn) and self.at_active:
self.at_tuned_params = json.load(open(self.at_fn, "rt"))
self.at_tuned_params = load_cache(self.at_fn)
assert self.at_tuned_params["tag"] == tag
if self.at_verbose:
g.message(f"Use tuned results from {self.at_fn}")
Expand Down

0 comments on commit 343d98b

Please sign in to comment.