Skip to content

Commit

Permalink
Bug fix in MultiTask GP Regression Model related to caching Cholesky factors when observations are missing.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 563520139
  • Loading branch information
emilyfertig authored and tensorflower-gardener committed Sep 7, 2023
1 parent b23a82b commit adbcc1d
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -613,48 +613,46 @@ def precompute_regression_model(
if _precomputed_divisor_matrix_cholesky is not None:
observation_scale = _scale_from_precomputed(
_precomputed_divisor_matrix_cholesky, kernel)
elif observations_is_missing is not None:
# If observations are missing, there's nothing we can do to preserve the
# operator structure, so densify.

observation_covariance = kernel.matrix_over_all_tasks(
observation_index_points, observation_index_points).to_dense()

if observation_noise_variance is not None:
broadcast_shape = distribution_util.get_broadcast_shape(
observation_covariance, observation_noise_variance[
..., tf.newaxis, tf.newaxis])
observation_covariance = tf.broadcast_to(observation_covariance,
broadcast_shape)
observation_covariance = _add_diagonal_shift(
observation_covariance, observation_noise_variance)
vec_observations_is_missing = _vec(observations_is_missing)
observation_covariance = tf.linalg.LinearOperatorFullMatrix(
psd_kernels_util.mask_matrix(
observation_covariance,
is_missing=vec_observations_is_missing),
is_non_singular=True,
is_positive_definite=True)
observation_scale = cholesky_util.cholesky_from_fn(
observation_covariance, cholesky_fn)
solve_on_observations = _precomputed_solve_on_observation
else:
observation_scale = mtgp._compute_flattened_scale( # pylint:disable=protected-access
kernel=kernel,
index_points=observation_index_points,
cholesky_fn=cholesky_fn,
observation_noise_variance=observation_noise_variance)

# Note that the conditional mean is
# k(x, o) @ (k(o, o) + sigma**2)^-1 obs. We can precompute the latter
# term since it won't change per iteration.
vec_diff = _vec(observations - mean_fn(observation_index_points))

if observations_is_missing is not None:
vec_diff = tf.where(vec_observations_is_missing,
tf.zeros([], dtype=vec_diff.dtype),
vec_diff)
solve_on_observations = _precomputed_solve_on_observation
if solve_on_observations is None:
# Note that the conditional mean is
# k(x, o) @ (k(o, o) + sigma**2)^-1 obs. We can precompute the latter
# term since it won't change per iteration.
vec_diff = _vec(observations - mean_fn(observation_index_points))

if observations_is_missing is not None:
# If observations are missing, there's nothing we can do to preserve
# the operator structure, so densify.
vec_observations_is_missing = _vec(observations_is_missing)
vec_diff = tf.where(vec_observations_is_missing,
tf.zeros([], dtype=vec_diff.dtype),
vec_diff)

observation_covariance = kernel.matrix_over_all_tasks(
observation_index_points, observation_index_points).to_dense()

if observation_noise_variance is not None:
broadcast_shape = distribution_util.get_broadcast_shape(
observation_covariance, observation_noise_variance[
..., tf.newaxis, tf.newaxis])
observation_covariance = tf.broadcast_to(observation_covariance,
broadcast_shape)
observation_covariance = _add_diagonal_shift(
observation_covariance, observation_noise_variance)
observation_covariance = tf.linalg.LinearOperatorFullMatrix(
psd_kernels_util.mask_matrix(
observation_covariance,
is_missing=vec_observations_is_missing),
is_non_singular=True,
is_positive_definite=True)
observation_scale = cholesky_util.cholesky_from_fn(
observation_covariance, cholesky_fn)
else:
observation_scale = mtgp._compute_flattened_scale( # pylint:disable=protected-access
kernel=kernel,
index_points=observation_index_points,
cholesky_fn=cholesky_fn,
observation_noise_variance=observation_noise_variance)
solve_on_observations = observation_scale.solvevec(
observation_scale.solvevec(vec_diff), adjoint=True)

Expand All @@ -678,6 +676,7 @@ def flattened_conditional_mean_fn(x):
observation_noise_variance=observation_noise_variance,
predictive_noise_variance=predictive_noise_variance,
cholesky_fn=cholesky_fn,
observations_is_missing=observations_is_missing,
_flattened_conditional_mean_fn=flattened_conditional_mean_fn,
_observation_scale=observation_scale,
validate_args=validate_args,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,16 +474,26 @@ def testMeanVarianceJit(self):
tf.function(jit_compile=True)(mtgprm.mean)()
tf.function(jit_compile=True)(mtgprm.variance)()

def testMeanVarianceAndCovariancePrecomputed(self):
@parameterized.parameters(True, False)
def testMeanVarianceAndCovariancePrecomputed(self, has_missing_observations):
num_tasks = 3
num_obs = 7
amplitude = np.array([1., 2.], np.float64).reshape([2, 1])
length_scale = np.array([.1, .2, .3], np.float64).reshape([1, 3])
observation_noise_variance = np.array([1e-9], np.float64)

observation_index_points = (
np.random.uniform(-1., 1., (1, 1, 7, 2)).astype(np.float64))
np.random.uniform(-1., 1., (1, 1, num_obs, 2)).astype(np.float64))
observations = np.linspace(
-20., 20., 7 * num_tasks).reshape(7, num_tasks).astype(np.float64)
-20., 20., num_obs * num_tasks).reshape(
num_obs, num_tasks).astype(np.float64)

if has_missing_observations:
observations_is_missing = np.stack(
[np.random.randint(2, size=(num_obs,))] * num_tasks, axis=-1
).astype(np.bool_)
else:
observations_is_missing = None

index_points = np.random.uniform(-1., 1., (6, 2)).astype(np.float64)

Expand All @@ -497,6 +507,7 @@ def testMeanVarianceAndCovariancePrecomputed(self):
observation_index_points=observation_index_points,
observations=observations,
observation_noise_variance=observation_noise_variance,
observations_is_missing=observations_is_missing,
validate_args=True)

precomputed_mtgprm = mtgprm_lib.MultiTaskGaussianProcessRegressionModel.precompute_regression_model(
Expand All @@ -505,6 +516,7 @@ def testMeanVarianceAndCovariancePrecomputed(self):
observation_index_points=observation_index_points,
observations=observations,
observation_noise_variance=observation_noise_variance,
observations_is_missing=observations_is_missing,
validate_args=True)

mock_cholesky_fn = mock.Mock(return_value=None)
Expand All @@ -514,6 +526,7 @@ def testMeanVarianceAndCovariancePrecomputed(self):
observation_index_points=observation_index_points,
observations=observations,
observation_noise_variance=observation_noise_variance,
observations_is_missing=observations_is_missing,
_precomputed_divisor_matrix_cholesky=precomputed_mtgprm._precomputed_divisor_matrix_cholesky,
_precomputed_solve_on_observation=precomputed_mtgprm._precomputed_solve_on_observation,
cholesky_fn=mock_cholesky_fn,
Expand Down

0 comments on commit adbcc1d

Please sign in to comment.