Skip to content

Commit

Permalink
Issue #1277 get_draws (#1283)
Browse files Browse the repository at this point in the history
* issue #1277 posterior_draws() -> get_draws()

* minor
  • Loading branch information
vincentarelbundock authored Nov 24, 2024
1 parent 0fc9b90 commit 0b1d3c3
Show file tree
Hide file tree
Showing 23 changed files with 428 additions and 436 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ Collate:
'get_contrast_data_logical.R'
'get_contrast_data_numeric.R'
'get_contrasts.R'
'get_draws.R'
'get_group_names.R'
'get_hypothesis.R'
'get_jacobian.R'
Expand Down Expand Up @@ -284,7 +285,6 @@ Collate:
'plot_comparisons.R'
'plot_predictions.R'
'plot_slopes.R'
'posterior_draws.R'
'predictions.R'
'print.R'
'recall.R'
Expand Down
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ export(expect_margins)
export(expect_predictions)
export(expect_slopes)
export(get_coef)
export(get_draws)
export(get_group_names)
export(get_model_matrix)
export(get_predict)
Expand All @@ -213,7 +214,6 @@ export(plot_comparisons)
export(plot_predictions)
export(plot_slopes)
export(posterior_draws)
export(posteriordraws)
export(predictions)
export(set_coef)
export(slopes)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ Misc:
* Be less strict about combining columns of different types. This allows us to handle types like `haven_labelled`. Thanks to @mwindzio for report #1238.
* In `lme4` and `glmmTMB` models, warnings are now silenced when the user specifically passes `re.form=NULL`. Thanks to @mattansb for the feature request.
* New startup message appears once per 24hr period and can be suppressed using `options(marginaleffects_startup_message = FALSE)`.
* `posterior_draws()` is renamed `get_draws()` because it also applies to bootstrap and simulation-based inference draws.
* `get_coef()` and `get_vcov()` are now documented on the main website, as they are useful helper functions.

## 0.23.0

Expand Down
20 changes: 10 additions & 10 deletions R/get_coef.R
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
#' Get a named vector of coefficients from a model object (internal function)
#'
#' Get a named vector of coefficients from a model object
#'
#' @inheritParams slopes
#' @return A named vector of coefficients. The names must match those of the variance matrix.
#' @rdname get_coef
#' @keywords internal
#' @export
get_coef <- function (model, ...) {
UseMethod("get_coef", model)
get_coef <- function(model, ...) {
UseMethod("get_coef", model)
}

#' @rdname get_coef
#' @export
get_coef.default <- function(model, ...) {
## faster
# out <- stats::coef(model)
## faster
# out <- stats::coef(model)

# more general
out <- insight::get_parameters(model, component = "all")
# more general
out <- insight::get_parameters(model, component = "all")

out <- stats::setNames(out$Estimate, out$Parameter)
return(out)
out <- stats::setNames(out$Estimate, out$Parameter)
return(out)
}
134 changes: 134 additions & 0 deletions R/get_draws.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
#' Extract Posterior Draws or Bootstrap Resamples from `marginaleffects` Objects
#'
#' @param x An object produced by a `marginaleffects` package function, such as `predictions()`, `avg_slopes()`, `hypotheses()`, etc.
#' @param shape string indicating the shape of the output format:
#' * "long": long format data frame
#' * "DxP": Matrix with draws as rows and parameters as columns
#' * "PxD": Matrix with draws as rows and parameters as columns
#' * "rvar": Random variable datatype (see `posterior` package documentation).
#' @return A data.frame with `drawid` and `draw` columns.
#' @export
get_draws <- function(x, shape = "long") {
checkmate::assert_choice(shape, choices = c("long", "DxP", "PxD", "rvar"))

# tidy.comparisons() sometimes already saves draws in a nice long format
draws <- attr(x, "posterior_draws")
if (inherits(draws, "posterior_draws")) {
return(draws)
}

if (is.null(attr(x, "posterior_draws"))) {
warning('This object does not include a "posterior_draws" attribute. The `posterior_draws` function only supports bayesian models produced by the `marginaleffects` or `predictions` functions of the `marginaleffects` package.',
call. = FALSE)
return(x)
}

if (nrow(draws) != nrow(x)) {
stop("The number of parameters in the object does not match the number of parameters for which posterior draws are available.", call. = FALSE)
}

if (shape %in% c("PxD", "DxP")) {
row.names(draws) <- paste0("b", seq_len(nrow(draws)))
colnames(draws) <- paste0("draw", seq_len(ncol(draws)))
}

if (shape == "PxD") {
return(draws)
}

if (shape == "DxP") {
return(t(draws))
}

if (shape == "rvar") {
insight::check_if_installed("posterior")
draws <- t(draws)
if (!is.null(attr(x, "nchains"))) {
x[["rvar"]] <- posterior::rvar(draws, nchains = attr(x, "nchains"))
} else {
x[["rvar"]] <- posterior::rvar(draws)
}
return(x)
}

if (shape == "long") {
draws <- data.table(draws)
setnames(draws, as.character(seq_len(ncol(draws))))
for (v in colnames(x)) {
draws[[v]] <- x[[v]]
}
out <- melt(
draws,
id.vars = colnames(x),
variable.name = "drawid",
value.name = "draw")
cols <- unique(c("drawid", "draw", "rowid", colnames(out)))
cols <- intersect(cols, colnames(out))
setcolorder(out, cols)
data.table::setDF(out)
return(out)
}
}


average_draws <- function(data, index, draws, byfun = NULL) {
insight::check_if_installed("collapse", minimum_version = "1.9.0")

w <- data[["marginaleffects_wts_internal"]]
if (all(is.na(w))) {
w <- NULL
}

if (is.null(index)) {
index <- intersect(colnames(data), "type")
}

if (length(index) > 0) {
g <- collapse::GRP(data, by = index)

if (is.null(byfun)) {
draws <- collapse::fmean(
draws,
g = g,
w = w,
drop = FALSE)
} else {
draws <- collapse::BY(
draws,
g = g,
FUN = byfun,
drop = FALSE)
}
out <- data.table(
g[["groups"]],
average = collapse::dapply(draws, MARGIN = 1, FUN = collapse::fmedian))
} else {
if (is.null(byfun)) {
draws <- collapse::fmean(
draws,
w = w,
drop = FALSE)
} else {
draws <- collapse::BY(
draws,
g = g,
FUN = byfun,
drop = FALSE)
}
out <- data.table(average = collapse::dapply(draws, MARGIN = 1, FUN = collapse::fmedian))
}

setnames(out, old = "average", new = "estimate")
attr(out, "posterior_draws") <- draws
return(out)
}




#' alias to `get_draws()` for backward compatibility with JJSS
#'
#' @inherit posterior_draws
#' @keywords internal
#' @export
posterior_draws <- get_draws
2 changes: 1 addition & 1 deletion R/get_hypothesis.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ get_hypothesis <- function(

if (is.function(hypothesis)) {
if (!is.null(draws)) {
msg <- "The `hypothesis` argument does not support function for models with draws. You can use `posterior_draws()` to extract draws and manipulate them directly instead."
msg <- "The `hypothesis` argument does not support function for models with draws. You can use `get_draws()` to extract draws and manipulate them directly instead."
insight::format_error(msg)
}
if ("rowid" %in% colnames(x) && "rowid" %in% colnames(newdata)) {
Expand Down
Loading

0 comments on commit 0b1d3c3

Please sign in to comment.