Skip to content

Commit

Permalink
Squeeze a flaky test to make it run in 15 minutes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 318207007
  • Loading branch information
romanngg committed Jun 25, 2020
1 parent 689b0be commit 1948b32
Showing 1 changed file with 43 additions and 43 deletions.
86 changes: 43 additions & 43 deletions neural_tangents/tests/predict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ def _get_inputs(cls, out_logits, test_shape, train_shape):
for out_logits in OUTPUT_LOGITS
for name, fn in KERNELS.items()
for momentum in [None, 0.9]
for learning_rate in [0.00001]
for t in [100]
for learning_rate in [0.0001]
for t in [10]
for loss in ['mse_analytic', 'mse'])
)
def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits,
Expand Down Expand Up @@ -653,7 +653,7 @@ def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
out_logits):
training_steps = 1000
learning_rate = 0.1
ensemble_size = 2048
ensemble_size = 1024

init_fn, apply_fn, kernel_fn = stax.serial(
stax.Dense(128, W_std=1.2, b_std=0.05), stax.Erf(),
Expand Down Expand Up @@ -686,7 +686,7 @@ def train_network(key):
return get_params(opt_state)

params = vmap(train_network)(ensemble_key)
rtol = 0.05
rtol = 0.08

for x in [None, 'x_test']:
with self.subTest(x=x):
Expand Down Expand Up @@ -741,49 +741,49 @@ def kernel_fn(x1, x2, get):
'nngp', 'ntk',
('nngp',), ('ntk',),
('nngp', 'ntk'), ('ntk', 'nngp')]:
k_dd = kernel_fn(x_train, None, get)

gp_inference = predict.gp_inference(k_dd, y_train, diag_reg=reg)
gd_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
x_train,
y_train,
diag_reg=reg)
for x_test in [None, 'x_test']:
x_test = None if x_test is None else random.normal(key, (8, 2))
k_td = None if x_test is None else kernel_fn(x_test, x_train, get)

for compute_cov in [True, False]:
with self.subTest(kernel_fn_is_analytic=kernel_fn_is_analytic,
get=get,
x_test=x_test if x_test is None else 'x_test',
compute_cov=compute_cov):
if compute_cov:
nngp_tt = (True if x_test is None else
kernel_fn(x_test, None, 'nngp'))
else:
nngp_tt = None

out_ens = gd_ensemble(None, x_test, get, compute_cov)
out_ens_inf = gd_ensemble(np.inf, x_test, get, compute_cov)
self._assertAllClose(out_ens_inf, out_ens, 0.08)

if (get is not None and
'nngp' not in get and
compute_cov and
k_td is not None):
with self.assertRaises(ValueError):
out_gp_inf = gp_inference(get=get, k_test_train=k_td,
nngp_test_test=nngp_tt)
else:
k_dd = kernel_fn(x_train, None, get)

gp_inference = predict.gp_inference(k_dd, y_train, diag_reg=reg)
gd_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
x_train,
y_train,
diag_reg=reg)
for x_test in [None, 'x_test']:
x_test = None if x_test is None else random.normal(key, (8, 2))
k_td = None if x_test is None else kernel_fn(x_test, x_train, get)

for compute_cov in [True, False]:
with self.subTest(kernel_fn_is_analytic=kernel_fn_is_analytic,
get=get,
x_test=x_test if x_test is None else 'x_test',
compute_cov=compute_cov):
if compute_cov:
nngp_tt = (True if x_test is None else
kernel_fn(x_test, None, 'nngp'))
else:
nngp_tt = None

out_ens = gd_ensemble(None, x_test, get, compute_cov)
out_ens_inf = gd_ensemble(np.inf, x_test, get, compute_cov)
self._assertAllClose(out_ens_inf, out_ens, 0.08)

if (get is not None and
'nngp' not in get and
compute_cov and
k_td is not None):
with self.assertRaises(ValueError):
out_gp_inf = gp_inference(get=get, k_test_train=k_td,
nngp_test_test=nngp_tt)
self.assertAllClose(out_ens, out_gp_inf)
else:
out_gp_inf = gp_inference(get=get, k_test_train=k_td,
nngp_test_test=nngp_tt)
self.assertAllClose(out_ens, out_gp_inf)

def testPredictOnCPU(self):
x_train = random.normal(random.PRNGKey(1), (10, 4, 5, 3))
x_test = random.normal(random.PRNGKey(1), (8, 4, 5, 3))
x_train = random.normal(random.PRNGKey(1), (4, 4, 4, 2))
x_test = random.normal(random.PRNGKey(1), (8, 4, 4, 2))

y_train = random.uniform(random.PRNGKey(1), (10, 7))
y_train = random.uniform(random.PRNGKey(1), (4, 2))

_, _, kernel_fn = stax.serial(
stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))
Expand Down Expand Up @@ -814,7 +814,7 @@ def testPredictOnCPU(self):
self.assertEqual(on_cpu, utils.is_on_cpu(predict_none))

def testPredictND(self):
n_chan = 7
n_chan = 6
key = random.PRNGKey(1)
im_shape = (5, 4, 3)
n_train = 2
Expand Down

0 comments on commit 1948b32

Please sign in to comment.