Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mcse plots similar to neff and rhat plots #278

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

S3method("[",mcse_ratio)
S3method("[",neff_ratio)
S3method("[",rhat)
S3method(log_posterior,CmdStanMCMC)
Expand Down Expand Up @@ -63,6 +64,9 @@ export(mcmc_hist)
export(mcmc_hist_by_chain)
export(mcmc_intervals)
export(mcmc_intervals_data)
export(mcmc_mcse)
export(mcmc_mcse_data)
export(mcmc_mcse_hist)
export(mcmc_neff)
export(mcmc_neff_data)
export(mcmc_neff_hist)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
* `mcmc_areas()` and `mcmc_areas_ridges()` gain an argument `border_size` for
controlling the thickness of the ridgelines. (#224)

* New plotting functions `mcmc_mcse()` and `mcmc_mcse_hist()` that are similar
to `mcmc_neff()` and `mcmc_neff_hist()` but for plotting ratios of MCSE to
posterior SD. (#278, @VeenDuco)

* New plotting function `ppc_km_overlay_grouped()`, the grouped variant of
`ppc_km_overlay()`. (#260, @fweber144)

Expand Down
182 changes: 164 additions & 18 deletions R/mcmc-diagnostics.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
#' General MCMC diagnostics
#'
#' Plots of Rhat statistics, ratios of effective sample size to total sample
#' size, and autocorrelation of MCMC draws. See the **Plot Descriptions**
#' section, below, for details. For models fit using the No-U-Turn-Sampler, see
#' also [MCMC-nuts] for additional MCMC diagnostic plots.
#' size, ratios of MCSE to posterior SD, and autocorrelation of MCMC draws. See
#' the **Plot Descriptions** section, below, for details. For models fit using
#' the No-U-Turn-Sampler, see also [MCMC-nuts] for additional MCMC diagnostic
#' plots.
#'
#' @name MCMC-diagnostics
#' @family MCMC
#'
#' @param ratio For effective sample size plots, a vector of *ratios* of
#' effective sample size estimates to total sample sizes (see [neff_ratio()]).
#' For MCSE plots, a vector of *ratios* of Monte Carlo standard errors to
#' posterior standard deviations.
#' @template args-hist
#' @param size An optional value to override [ggplot2::geom_point()]'s
#' default size (for `mcmc_rhat()`, `mcmc_neff()`) or
Expand All @@ -32,9 +37,19 @@
#' histogram. Values are colored using different shades (lighter is better).
#' The chosen thresholds are somewhat arbitrary, but can be useful guidelines
#' in practice.
#' * _light_: between 0.5 and 1 (high)
#' * _mid_: between 0.1 and 0.5 (good)
#' * _dark_: below 0.1 (low)
#' * _light_: between 0.5 and 1 (good)
#' * _mid_: between 0.1 and 0.5 (ok)
#' * _dark_: below 0.1 (too low)
#' }
#'
#' \item{`mcmc_mcse()`, `mcmc_mcse_hist()`}{
#' Ratios of Monte Carlo standard errors to posterior standard deviations as
#' either points or a histogram. Values are colored using different shades
#' (lighter is better). The chosen thresholds are somewhat arbitrary, but can
#' be useful guidelines in practice.
#' * _light_: below 0.05 (good)
#' * _mid_: between 0.05 and 0.1 (ok)
#' * _dark_: above 0.1 (too high)
#' }
#'
#' \item{`mcmc_acf()`, `mcmc_acf_bar()`}{
Expand Down Expand Up @@ -91,6 +106,11 @@
#' mcmc_neff_hist(ratio)
#' mcmc_neff(ratio)
#'
#' # fake mcse ratio values to use for demonstration
#' ratio <- c(runif(100, 0, 1.5))
#' mcmc_mcse_hist(ratio)
#' mcmc_mcse(ratio)
#'
#' \dontrun{
#' # Example using rstanarm model (requires rstanarm package)
#' library(rstanarm)
Expand Down Expand Up @@ -210,8 +230,6 @@ mcmc_rhat_data <- function(rhat, ...) {

#' @rdname MCMC-diagnostics
#' @export
#' @param ratio A vector of *ratios* of effective sample size estimates to
#' total sample size. See [neff_ratio()].
#'
mcmc_neff <- function(ratio, ..., size = NULL) {
check_ignored_arguments(...)
Expand Down Expand Up @@ -294,6 +312,93 @@ mcmc_neff_data <- function(ratio, ...) {
diagnostic_data_frame(ratio)
}

# monte carlo standard error -------------------------------------------

#' @rdname MCMC-diagnostics
#' @export
#'
mcmc_mcse <- function(ratio, ..., size = NULL) {
  # Dot plot of MCSE / posterior-SD ratios: one point per parameter, joined to
  # the left edge of the panel by a segment and colored by its rating.
  #
  # ratio: vector of MCSE-to-posterior-SD ratios (validated and converted by
  #        mcmc_mcse_data()).
  # ...:   currently unused; handed to check_ignored_arguments().
  # size:  optional point size forwarded to diagnostic_points().
  #
  # Returns the ggplot object built below.
  check_ignored_arguments(...)
  data <- mcmc_mcse_data(ratio)

  # Only extend the x-axis breaks past 1 when the largest ratio calls for it.
  max_ratio <- max(ratio, na.rm = TRUE)
  if (max_ratio < 1.25) {
    additional_breaks <- numeric(0)
  } else if (max_ratio < 1.5) {
    additional_breaks <- 1.25
  } else {
    additional_breaks <- seq(1.5, max_ratio, by = 0.5)
  }
  breaks <- c(0, 0.1, 0.25, 0.5, 0.75, 1, additional_breaks)

  ggplot(
    data,
    mapping = aes_(
      x = ~ value,
      y = ~ parameter,
      color = ~ rating,
      fill = ~ rating)) +
    geom_segment(
      aes_(yend = ~ parameter, xend = -Inf),
      na.rm = TRUE) +
    diagnostic_points(size) +
    # Dashed guides sit at the mcse rating thresholds (0.05 and 0.1), i.e. the
    # same cut points used to assign the low/mid/high rating and shown in the
    # legend labels.
    vline_at(
      c(0.05, 0.1),
      color = "gray",
      linetype = 2,
      size = 0.25) +
    labs(y = NULL, x = expression(mcse/sd)) +
    scale_fill_diagnostic("mcse") +
    scale_color_diagnostic("mcse") +
    scale_x_continuous(
      breaks = breaks,
      # as.character truncates trailing zeroes, while ggplot default does not
      labels = as.character(breaks),
      limits = c(0, max(1, max_ratio) + 0.05),
      expand = c(0, 0)) +
    bayesplot_theme_get() +
    yaxis_text(FALSE) +
    yaxis_title(FALSE) +
    yaxis_ticks(FALSE)
}

#' @rdname MCMC-diagnostics
#' @export
mcmc_mcse_hist <- function(ratio, ..., binwidth = NULL, breaks = NULL) {
  # Histogram of MCSE / posterior-SD ratios, with bars colored and filled by
  # rating.
  #
  # ratio:    vector of MCSE-to-posterior-SD ratios.
  # ...:      currently unused; handed to check_ignored_arguments().
  # binwidth: forwarded to geom_histogram().
  # breaks:   forwarded to geom_histogram().
  #
  # Returns the ggplot object built below.
  check_ignored_arguments(...)
  data <- mcmc_mcse_data(ratio)

  rating_mapping <- aes_(
    x = ~ value,
    color = ~ rating,
    fill = ~ rating)

  histogram_layer <- geom_histogram(
    size = .25,
    na.rm = TRUE,
    binwidth = binwidth,
    breaks = breaks)

  # Layer order is kept exactly as before: theme-related layers added later
  # can override earlier ones.
  ggplot(data, mapping = rating_mapping) +
    histogram_layer +
    scale_color_diagnostic("mcse") +
    scale_fill_diagnostic("mcse") +
    labs(x = expression(mcse/sd), y = NULL) +
    dont_expand_y_axis(c(0.005, 0)) +
    yaxis_title(FALSE) +
    yaxis_text(FALSE) +
    yaxis_ticks(FALSE) +
    bayesplot_theme_get()
}

#' @rdname MCMC-diagnostics
#' @export
mcmc_mcse_data <- function(ratio, ...) {
# Build the data used by mcmc_mcse() and mcmc_mcse_hist() from a vector of
# MCSE / posterior-SD ratios. Those plots map the `value`, `parameter` and
# `rating` columns of the result.
# `...` is not used here; it is only passed to check_ignored_arguments().
check_ignored_arguments(...)
# new_mcse_ratio() validates the input and tags it with class "mcse_ratio".
# NOTE(review): drop_NAs_and_warn() is defined elsewhere; presumably it removes
# NA entries and emits a warning -- confirm against its definition.
ratio <- drop_NAs_and_warn(new_mcse_ratio(ratio))
diagnostic_data_frame(ratio)
}



# autocorrelation ---------------------------------------------------------

Expand Down Expand Up @@ -354,7 +459,7 @@ mcmc_acf_bar <-
#'
#' @param x A numeric vector.
#' @param breaks A numeric vector of length two. The resulting factor variable
#' will have three levels ('low', 'ok', and 'high') corresponding to (
#' will have three levels ('low', 'mid', and 'high') corresponding to (
#' `x <= breaks[1]`, `breaks[1] < x <= breaks[2]`, `x > breaks[2]`).
#' @return A factor the same length as `x` with three levels.
#' @noRd
Expand All @@ -364,13 +469,19 @@ diagnostic_factor <- function(x, breaks, ...) {

diagnostic_factor.rhat <- function(x, breaks = c(1.05, 1.1)) {
cut(x, breaks = c(-Inf, breaks, Inf),
labels = c("low", "ok", "high"),
labels = c("low", "mid", "high"),
ordered_result = FALSE)
}

diagnostic_factor.neff_ratio <- function(x, breaks = c(0.1, 0.5)) {
cut(x, breaks = c(-Inf, breaks, Inf),
labels = c("low", "ok", "high"),
labels = c("low", "mid", "high"),
ordered_result = FALSE)
}

diagnostic_factor.mcse_ratio <- function(x, breaks = c(0.05, 0.1)) {
  # Bin mcse ratios into a three-level, unordered factor:
  #   "low":  x <= breaks[1]
  #   "mid":  breaks[1] < x <= breaks[2]
  #   "high": x > breaks[2]
  rating_labels <- c("low", "mid", "high")
  cut_points <- c(-Inf, breaks, Inf)
  cut(
    x,
    breaks = cut_points,
    labels = rating_labels,
    ordered_result = FALSE
  )
}

Expand Down Expand Up @@ -411,17 +522,17 @@ diagnostic_points <- function(size = NULL) {

# Functions wrapping around scale_color_manual() and scale_fill_manual(), used to
# color plot elements by their diagnostic rating (rhat, neff ratio or mcse ratio)
scale_color_diagnostic <- function(diagnostic = c("rhat", "neff")) {
scale_color_diagnostic <- function(diagnostic = c("rhat", "neff", "mcse")) {
d <- match.arg(diagnostic)
diagnostic_color_scale(d, aesthetic = "color")
}

scale_fill_diagnostic <- function(diagnostic = c("rhat", "neff")) {
scale_fill_diagnostic <- function(diagnostic = c("rhat", "neff", "mcse")) {
d <- match.arg(diagnostic)
diagnostic_color_scale(d, aesthetic = "fill")
}

diagnostic_color_scale <- function(diagnostic = c("rhat", "neff_ratio"),
diagnostic_color_scale <- function(diagnostic = c("rhat", "neff_ratio", "mcse_ratio"),
aesthetic = c("color", "fill")) {
diagnostic <- match.arg(diagnostic)
aesthetic <- match.arg(aesthetic)
Expand All @@ -437,14 +548,17 @@ diagnostic_color_scale <- function(diagnostic = c("rhat", "neff_ratio"),
)
}

diagnostic_colors <- function(diagnostic = c("rhat", "neff_ratio"),
diagnostic_colors <- function(diagnostic = c("rhat", "neff_ratio", "mcse_ratio"),
aesthetic = c("color", "fill")) {
diagnostic <- match.arg(diagnostic)
aesthetic <- match.arg(aesthetic)
color_levels <- c("light", "mid", "dark")
if (diagnostic == "neff_ratio") {
color_levels <- rev(color_levels)
}
if (diagnostic == "mcse_ratio") {
color_levels <- color_levels
}
if (aesthetic == "color") {
color_levels <- paste0(color_levels, "_highlight")
}
Expand All @@ -455,19 +569,24 @@ diagnostic_colors <- function(diagnostic = c("rhat", "neff_ratio"),
aesthetic = aesthetic,
color_levels = color_levels,
color_labels = color_labels,
values = set_names(get_color(color_levels), c("low", "ok", "high")))
values = set_names(get_color(color_levels), c("low", "mid", "high")))
}

diagnostic_color_labels <- list(
rhat = c(
low = expression(hat(R) <= 1.05),
ok = expression(hat(R) <= 1.10),
mid = expression(hat(R) <= 1.10),
high = expression(hat(R) > 1.10)
),
neff_ratio = c(
low = expression(N[eff] / N <= 0.1),
ok = expression(N[eff] / N <= 0.5),
mid = expression(N[eff] / N <= 0.5),
high = expression(N[eff] / N > 0.5)
),
mcse_ratio = c(
low = expression(mcse / sd <= 0.05),
mid = expression(mcse / sd <= 0.1),
high = expression(mcse / sd > 0.1)
)
)

Expand Down Expand Up @@ -662,3 +781,30 @@ as_neff_ratio <- function(x) {
as_neff_ratio(NextMethod())
}

new_mcse_ratio <- function(x) {
  # Constructor for "mcse_ratio" objects: normalise, validate, then class.
  # A 1-d array is converted to a plain vector first, because the validator
  # rejects anything for which is.array() is TRUE.
  is_1d_array <- is.array(x) && length(dim(x)) == 1
  if (is_1d_array) {
    x <- as.vector(x)
  }
  checked <- validate_mcse_ratio(x)
  as_mcse_ratio(checked)
}

validate_mcse_ratio <- function(x) {
  # Check that `x` is a plain numeric vector (not a list or array) with no
  # negative entries. NAs are ignored by the sign check. Returns `x` unchanged
  # when valid; otherwise signals an error.
  stopifnot(is.numeric(x), !is.list(x), !is.array(x))
  has_negative <- any(x < 0, na.rm = TRUE)
  if (has_negative) {
    abort("All mcse ratios must be positive.")
  }
  x
}

as_mcse_ratio <- function(x) {
  # Tag `x` with the "mcse_ratio" class, explicitly carrying over its names.
  ratio_class <- c("mcse_ratio", "numeric")
  element_names <- names(x)
  structure(x, class = ratio_class, names = element_names)
}

#' Indexing method -- needed so that sort, etc. don't strip names.
#' @export
#' @keywords internal
#' @noRd
`[.mcse_ratio` <- function (x, i, j, drop = TRUE, ...) {
  # Subset with the next `[` method, then re-apply the "mcse_ratio" class and
  # names so the result is still an mcse_ratio object.
  subset_result <- NextMethod()
  as_mcse_ratio(subset_result)
}
Loading