Skip to content

Commit

Permalink
Improve saving
Browse files: browse the repository at this point in the history
  • Loading branch information
Frederike Duembgen committed Jan 17, 2024
1 parent dd46d77 commit 4213f87
Showing 1 changed file with 102 additions and 81 deletions.
183 changes: 102 additions & 81 deletions _scripts/run_time_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from _scripts.run_clique_study import read_saved_learner
from auto_template.sim_experiments import create_newinstance
from lifters.matweight_lifter import MatWeightLocLifter
from utils.plotting_tools import savefig

NOISE_SEED = 0

Expand All @@ -24,6 +25,9 @@
USE_FUSION = False
VERBOSE = False

# USE_METHODS = ["SDP", "dSDP", "ADMM"]
USE_METHODS = ["dSDP"]


def extract_solution(lifter: MatWeightLocLifter, X_list):
x_dim = lifter.node_size()
Expand All @@ -37,7 +41,7 @@ def extract_solution(lifter: MatWeightLocLifter, X_list):
return np.vstack(x_all), np.mean(evr_mean)


def generate_results(lifter: MatWeightLocLifter, n_params_list=[10]):
def generate_results(lifter: MatWeightLocLifter, n_params_list=[10], fname=""):
saved_learner = read_saved_learner(lifter)

df_data = []
Expand All @@ -50,6 +54,8 @@ def generate_results(lifter: MatWeightLocLifter, n_params_list=[10]):
Q, y = new_lifter.get_Q()

theta_gt = new_lifter.get_vec_around_gt(delta=0)

print("solving local...", end="")
t1 = time.time()
theta_est, info, cost = new_lifter.local_solver(theta_gt, lifter.y_)
time_dict["t local"] = time.time() - t1
Expand All @@ -66,80 +72,91 @@ def generate_results(lifter: MatWeightLocLifter, n_params_list=[10]):
time_dict["dim dSDP"] = clique_list[0].Q.shape[0]
time_dict["m dSDP"] = sum(len(c.A_list) for c in clique_list)

print("solving dSDP...", end="")
t1 = time.time()
X_list, info = solve_oneshot(
clique_list,
use_primal=True,
use_fusion=True,
verbose=VERBOSE,
)
time_dict["t dSDP"] = time.time() - t1
time_dict["cost dSDP"] = info["cost"]
if "dSDP" in USE_METHODS:
print("solving dSDP...", end="")
t1 = time.time()
X_list, info = solve_oneshot(
clique_list,
use_primal=True,
use_fusion=True,
verbose=VERBOSE,
)
time_dict["t dSDP"] = time.time() - t1
time_dict["cost dSDP"] = info["cost"]

x_dSDP, evr_mean = extract_solution(new_lifter, X_list)
time_dict["evr dSDP"] = evr_mean
x_dSDP, evr_mean = extract_solution(new_lifter, X_list)
time_dict["evr dSDP"] = evr_mean

print("running ADMM...", end="")
t1 = time.time()
# X0 = []
# for c in clique_list:
# x_clique = new_lifter.get_x(theta=theta_est, var_subset=c.var_dict)
# X0.append(np.outer(x_clique, x_clique))
X0 = None
print("target", info["cost"])
X_list, info = solve_alternating(
clique_list,
X0=X0,
use_fusion=False,
rho_start=1e4,
verbose=True,
early_stop=False,
max_iter=20,
mu_rho=2.0,
tau_rho=2.0,
)
print(info["msg"], end="...")
time_dict["t ADMM"] = time.time() - t1
time_dict["cost ADMM"] = info["cost"]
if "ADMM" in USE_METHODS:
print("running ADMM...", end="")
t1 = time.time()
# X0 = []
# for c in clique_list:
# x_clique = new_lifter.get_x(theta=theta_est, var_subset=c.var_dict)
# X0.append(np.outer(x_clique, x_clique))
X0 = None
print("target", info["cost"])
X_list, info = solve_alternating(
clique_list,
X0=X0,
use_fusion=False,
rho_start=1e4,
verbose=True,
early_stop=False,
max_iter=10,
mu_rho=2.0,
tau_rho=2.0,
)
print(info["msg"], end="...")
time_dict["t ADMM"] = time.time() - t1
time_dict["cost ADMM"] = info["cost"]

x_ADMM, evr_mean = extract_solution(new_lifter, X_list)
time_dict["evr ADMM"] = evr_mean
x_ADMM, evr_mean = extract_solution(new_lifter, X_list)
time_dict["evr ADMM"] = evr_mean

print("creating constraints...", end="")
t1 = time.time()
if USE_AUTOTEMPLATE:
new_constraints = new_lifter.apply_templates(saved_learner.templates, 0)
Constraints = [(new_lifter.get_A0(), 1.0)] + [
(c.A_sparse_, 0.0) for c in new_constraints
]
else:
if USE_KNOWN:
if "SDP" in USE_METHODS:
print("creating constraints...", end="")
t1 = time.time()
if USE_AUTOTEMPLATE:
new_constraints = new_lifter.apply_templates(saved_learner.templates, 0)
Constraints = [(new_lifter.get_A0(), 1.0)] + [
(A, 0.0) for A in new_lifter.get_A_known()
(c.A_sparse_, 0.0) for c in new_constraints
]
else:
Constraints = [(new_lifter.get_A0(), 1.0)] + [
(A, 0.0) for A in new_lifter.get_A_learned_simple()
]
time_dict["t create constraints"] = time.time() - t1

if True: # n_params <= 18:
time_dict["dim SDP"] = Q.shape[0]
time_dict["m SDP"] = len(Constraints)

print("solving SDP...", end="")
t1 = time.time()
X, info = solve_sdp(
Q, Constraints, adjust=ADJUST, primal=USE_PRIMAL, use_fusion=USE_FUSION
)
time_dict["t SDP"] = time.time() - t1
time_dict["cost SDP"] = info["cost"]

x_SDP, info = rank_project(X, p=1)
time_dict["evr SDP"] = info["EVR"]
if USE_KNOWN:
Constraints = [(new_lifter.get_A0(), 1.0)] + [
(A, 0.0) for A in new_lifter.get_A_known()
]
else:
Constraints = [(new_lifter.get_A0(), 1.0)] + [
(A, 0.0) for A in new_lifter.get_A_learned_simple()
]
time_dict["t create constraints"] = time.time() - t1

if True: # n_params <= 18:
time_dict["dim SDP"] = Q.shape[0]
time_dict["m SDP"] = len(Constraints)

print("solving SDP...", end="")
t1 = time.time()
X, info = solve_sdp(
Q,
Constraints,
adjust=ADJUST,
primal=USE_PRIMAL,
use_fusion=USE_FUSION,
)
time_dict["t SDP"] = time.time() - t1
time_dict["cost SDP"] = info["cost"]

x_SDP, info = rank_project(X, p=1)
time_dict["evr SDP"] = info["EVR"]
print("done.")
df_data.append(time_dict)
if fname != "":
df = pd.DataFrame(df_data)
df.to_pickle(fname)
print("saved intermediate as", fname)
return pd.DataFrame(df_data)


Expand All @@ -149,29 +166,30 @@ def generate_results(lifter: MatWeightLocLifter, n_params_list=[10]):
lifter.ALL_PAIRS = False
lifter.CLIQUE_SIZE = 2

# n_params_list = np.logspace(1, 6, 11).astype(int)
n_params_list = np.logspace(1, 2, 10).astype(int)
# n_params_list = np.arange(10, 101, step=10).astype(int)
# n_params_list = np.arange(10, 19) # runs out of memory at 19!
fname = f"_results/{lifter}_time_test.pkl"
n_params_list = np.logspace(1, 6, 6).astype(int)
# n_params_list = np.logspace(1, 2, 10).astype(int)
fname = f"_results/{lifter}_time_dsdp.pkl"
overwrite = True

try:
assert overwrite is False
df = pd.read_pickle(fname)
except (FileNotFoundError, AssertionError):
df = generate_results(lifter, n_params_list=n_params_list)
df = generate_results(lifter, n_params_list=n_params_list, fname=fname)
df.to_pickle(fname)
print("saved final as", fname)

for label, plot in zip(["t", "cost"], [sns.scatterplot, sns.barplot]):
value_vars = [
f"{label} local",
f"{label} dSDP",
f"{label} SDP",
f"{label} ADMM",
]
value_vars = set(value_vars).intersection(df.columns.unique())
df_long = df.melt(
id_vars=["n params"],
value_vars=[
f"{label} local",
f"{label} dSDP",
f"{label} SDP",
f"{label} ADMM",
],
value_vars=value_vars,
value_name=label,
var_name="solver type",
)
Expand All @@ -181,13 +199,16 @@ def generate_results(lifter: MatWeightLocLifter, n_params_list=[10]):
if label != "cost":
ax.set_xscale("log")
ax.grid("on")
savefig(fig, fname.replace(".pkl", f"_{label}.png"))

fig, ax = plt.subplots()
ax.loglog(df["n params"], df["evr SDP"], label="SDP")
ax.loglog(df["n params"], df["evr dSDP"], label="dSDP")
ax.loglog(df["n params"], df["evr ADMM"], label="ADMM")
value_vars = ["evr SDP", "evr dSDP", "evr ADMM"]
value_vars = set(value_vars).intersection(df.columns.unique())
for v in value_vars:
ax.loglog(df["n params"], df[v], label=v.strip("evr "))
ax.legend()
ax.grid("on")
ax.set_ylabel("EVR")
ax.set_xlabel("n params")
savefig(fig, fname.replace(".pkl", f"_evr.png"))
print("done")

0 comments on commit 4213f87

Please sign in to comment.