Skip to content

Commit

Permalink
Fix dtype in parallel Kalman filter likelihood.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 584146012
  • Loading branch information
emilyfertig authored and tensorflower-gardener committed Nov 20, 2023
1 parent 0a66d6f commit 300bfe5
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,8 @@ def _mvn_log_prob(mean, covariance, y):
log_prob = log_prob - 0.5 * linalg.hpsd_logdet(
covariance, cholesky_matrix=cholesky_matrix)
event_dims = ps.shape(mean)[-1]
return log_prob - 0.5 * event_dims * dtype_util.as_numpy_dtype(
mean.dtype)(np.log(2 * np.pi))
return log_prob - dtype_util.as_numpy_dtype(mean.dtype)(
0.5 * event_dims * np.log(2 * np.pi))


def _extract_batch_shape(x, sample_ndims, event_ndims):
Expand Down

0 comments on commit 300bfe5

Please sign in to comment.