Skip to content

Commit

Permalink
strike a compromise for greta-dev#736
Browse files Browse the repository at this point in the history
  • Loading branch information
njtierney committed Nov 7, 2024
1 parent 80192ed commit 6d0f9fd
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 5 deletions.
26 changes: 24 additions & 2 deletions R/inference.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ greta_stash$numerical_messages <- c(
#' argument `trace_batch_size` can be modified to trade-off speed against
#' memory usage.
#'
#' @note to set a seed with MCMC you must use [tensorflow::set_random_seed()].
#' This is due to an internal API with tensorflow. See \url{https://github.com/greta-dev/greta/issues/559} for a thread exploring this.
#' @note to set a seed with MCMC you can use [set.seed()], or
#' [tensorflow::set_random_seed()]. They both given identical results. See
#' examples below.
#'
#' @return `mcmc`, `stashed_samples` & `extra_samples` - a
#' `greta_mcmc_list` object that can be analysed using functions from the
Expand Down Expand Up @@ -183,6 +184,27 @@ greta_stash$numerical_messages <- c(
#' m3 <- model(params)
#' o <- opt(m3, hessian = TRUE)
#' o$hessian
#'
#' # using set.seed or tensorflow::set_random_seed to set RNG for MCMC
#' a <- normal(0, 1)
#' y <- normal(a, 1)
#' m <- model(y)
#'
#' set.seed(12345)
#' one <- mcmc(m, n_samples = 1, chains = 1)
#' set.seed(12345)
#' two <- mcmc(m, n_samples = 1, chains = 1)
#' # same
#' all.equal(as.numeric(one), as.numeric(two))
#' tensorflow::set_random_seed(12345)
#' one_tf <- mcmc(m, n_samples = 1, chains = 1)
#' tensorflow::set_random_seed(12345)
#' two_tf <- mcmc(m, n_samples = 1, chains = 1)
#' # same
#' all.equal(as.numeric(one_tf), as.numeric(two_tf))
#' # different
#' all.equal(as.numeric(one), as.numeric(one_tf))
#'
#' }
mcmc <- function(
model,
Expand Down
26 changes: 24 additions & 2 deletions man/inference.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 20 additions & 1 deletion tests/testthat/test_seed.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ test_that("calculate samples are the same when the R seed is the same", {
)
})

test_that("mcmc samples are the same when the R seed is the same", {
test_that("mcmc samples are the same when the R seed is the same, also with tf set seed", {
skip_if_not(check_tf_version())
a <- normal(0, 1)
y <- normal(a, 1)
Expand All @@ -123,6 +123,25 @@ test_that("mcmc samples are the same when the R seed is the same", {
as.numeric(one),
as.numeric(two)
)

tensorflow::set_random_seed(12345)
one_tf <- mcmc(m, warmup = 10, n_samples = 1, chains = 1)
tensorflow::set_random_seed(12345)
two_tf <- mcmc(m, warmup = 10, n_samples = 1, chains = 1)

expect_equal(
as.numeric(one_tf),
as.numeric(two_tf)
)

# but these are not (always) equal to each other
mcmc_matches_tf_one <- identical(as.numeric(one),as.numeric(one_tf))
mcmc_matches_tf_two <- identical(as.numeric(two),as.numeric(two_tf))

expect_false(mcmc_matches_tf_one)

expect_false(mcmc_matches_tf_two)

})

test_that("simulate uses the local RNG seed", {
Expand Down

0 comments on commit 6d0f9fd

Please sign in to comment.