-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathget_run_score.py
31 lines (25 loc) · 923 Bytes
/
get_run_score.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# python get_run_score.py <ROOT:str> <EXP:str> <RUN:int> <STEP:int> <BATCH_SZ:int> <TYPE:str>
from data_diet.data import load_data
from data_diet.scores import compute_scores
from data_diet.utils import get_fn_params_state, load_args
import sys
import numpy as np
import os
# Compute per-example scores for one checkpoint of one training run and
# save them as a .npy array under the run directory.

# Parse command-line arguments (positional order documented in the usage
# comment at the top of the file). Fail with a readable usage message rather
# than a bare IndexError when arguments are missing.
if len(sys.argv) < 7:
    sys.exit('usage: python get_run_score.py <ROOT:str> <EXP:str> <RUN:int> '
             '<STEP:int> <BATCH_SZ:int> <TYPE:str>')
ROOT = sys.argv[1]
EXP = sys.argv[2]
RUN = int(sys.argv[3])
STEP = int(sys.argv[4])
BATCH_SZ = int(sys.argv[5])
TYPE = sys.argv[6]

# Restore the run's saved arguments and point them at the requested
# checkpoint step so the model loaded below is the one being scored.
run_dir = os.path.join(ROOT, 'exps', EXP, f'run_{RUN}')
args = load_args(run_dir)
args.load_dir = run_dir
args.ckpt = STEP

# Load data and the checkpointed model, then score the examples.
# load_data returns a 6-tuple; only X, Y and the (possibly updated) args are
# used here. NOTE(review): X/Y are presumably the training split — confirm
# against data_diet.data.load_data.
_, X, Y, _, _, args = load_data(args)
fn, params, state = get_fn_params_state(args)
scores = compute_scores(fn, params, state, X, Y, BATCH_SZ, TYPE)

# Output directory is chosen from TYPE: 'l2_error' -> error_l2_norm_scores,
# anything else -> grad_norm_scores.
# NOTE(review): the fallback assumes the only other TYPE is a grad-norm score;
# confirm against compute_scores before adding new score types.
path_name = 'error_l2_norm_scores' if TYPE == 'l2_error' else 'grad_norm_scores'
save_dir = os.path.join(run_dir, path_name)
save_path = os.path.join(save_dir, f'ckpt_{STEP}.npy')
# exist_ok=True avoids the race between an existence check and creation
# (e.g. several scoring jobs for the same run started concurrently).
os.makedirs(save_dir, exist_ok=True)
np.save(save_path, scores)