diff --git a/neural_tangents/tests/predict_test.py b/neural_tangents/tests/predict_test.py index b0720190..561fba69 100644 --- a/neural_tangents/tests/predict_test.py +++ b/neural_tangents/tests/predict_test.py @@ -217,8 +217,8 @@ def _get_inputs(cls, out_logits, test_shape, train_shape): for out_logits in OUTPUT_LOGITS for name, fn in KERNELS.items() for momentum in [None, 0.9] - for learning_rate in [0.00001] - for t in [100] + for learning_rate in [0.0001] + for t in [10] for loss in ['mse_analytic', 'mse']) ) def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits, @@ -653,7 +653,7 @@ def testTrainedEnsemblePredCov(self, train_shape, test_shape, network, out_logits): training_steps = 1000 learning_rate = 0.1 - ensemble_size = 2048 + ensemble_size = 1024 init_fn, apply_fn, kernel_fn = stax.serial( stax.Dense(128, W_std=1.2, b_std=0.05), stax.Erf(), @@ -686,7 +686,7 @@ def train_network(key): return get_params(opt_state) params = vmap(train_network)(ensemble_key) - rtol = 0.05 + rtol = 0.08 for x in [None, 'x_test']: with self.subTest(x=x): @@ -741,49 +741,49 @@ def kernel_fn(x1, x2, get): 'nngp', 'ntk', ('nngp',), ('ntk',), ('nngp', 'ntk'), ('ntk', 'nngp')]: - k_dd = kernel_fn(x_train, None, get) - - gp_inference = predict.gp_inference(k_dd, y_train, diag_reg=reg) - gd_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn, - x_train, - y_train, - diag_reg=reg) - for x_test in [None, 'x_test']: - x_test = None if x_test is None else random.normal(key, (8, 2)) - k_td = None if x_test is None else kernel_fn(x_test, x_train, get) - - for compute_cov in [True, False]: - with self.subTest(kernel_fn_is_analytic=kernel_fn_is_analytic, - get=get, - x_test=x_test if x_test is None else 'x_test', - compute_cov=compute_cov): - if compute_cov: - nngp_tt = (True if x_test is None else - kernel_fn(x_test, None, 'nngp')) - else: - nngp_tt = None - - out_ens = gd_ensemble(None, x_test, get, compute_cov) - out_ens_inf = gd_ensemble(np.inf, x_test, get, compute_cov) - self._assertAllClose(out_ens_inf, out_ens, 0.08) - - if (get is not None and - 'nngp' not in get and - compute_cov and - k_td is not None): - with self.assertRaises(ValueError): - out_gp_inf = gp_inference(get=get, k_test_train=k_td, - nngp_test_test=nngp_tt) - else: + k_dd = kernel_fn(x_train, None, get) + + gp_inference = predict.gp_inference(k_dd, y_train, diag_reg=reg) + gd_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn, + x_train, + y_train, + diag_reg=reg) + for x_test in [None, 'x_test']: + x_test = None if x_test is None else random.normal(key, (8, 2)) + k_td = None if x_test is None else kernel_fn(x_test, x_train, get) + + for compute_cov in [True, False]: + with self.subTest(kernel_fn_is_analytic=kernel_fn_is_analytic, + get=get, + x_test=x_test if x_test is None else 'x_test', + compute_cov=compute_cov): + if compute_cov: + nngp_tt = (True if x_test is None else + kernel_fn(x_test, None, 'nngp')) + else: + nngp_tt = None + + out_ens = gd_ensemble(None, x_test, get, compute_cov) + out_ens_inf = gd_ensemble(np.inf, x_test, get, compute_cov) + self._assertAllClose(out_ens_inf, out_ens, 0.08) + + if (get is not None and + 'nngp' not in get and + compute_cov and + k_td is not None): + with self.assertRaises(ValueError): out_gp_inf = gp_inference(get=get, k_test_train=k_td, nngp_test_test=nngp_tt) - self.assertAllClose(out_ens, out_gp_inf) + else: + out_gp_inf = gp_inference(get=get, k_test_train=k_td, + nngp_test_test=nngp_tt) + self.assertAllClose(out_ens, out_gp_inf) def testPredictOnCPU(self): - x_train = random.normal(random.PRNGKey(1), (10, 4, 5, 3)) - x_test = random.normal(random.PRNGKey(1), (8, 4, 5, 3)) + x_train = random.normal(random.PRNGKey(1), (4, 4, 4, 2)) + x_test = random.normal(random.PRNGKey(1), (8, 4, 4, 2)) - y_train = random.uniform(random.PRNGKey(1), (10, 7)) + y_train = random.uniform(random.PRNGKey(1), (4, 2)) _, _, kernel_fn = stax.serial( stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1)) @@ -814,7 +814,7 @@ def testPredictOnCPU(self): self.assertEqual(on_cpu, utils.is_on_cpu(predict_none)) def testPredictND(self): - n_chan = 7 + n_chan = 6 key = random.PRNGKey(1) im_shape = (5, 4, 3) n_train = 2