diff --git a/DESCRIPTION b/DESCRIPTION index 801cfdf37..c6d203c43 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: marginaleffects Title: Predictions, Comparisons, Slopes, Marginal Means, and Hypothesis Tests -Version: 0.15.1 +Version: 0.15.1.9000 Authors@R: c(person(given = "Vincent", family = "Arel-Bundock", @@ -171,6 +171,7 @@ Collate: 'conformal.R' 'datagrid.R' 'equivalence.R' + 'estimates.R' 'get_averages.R' 'get_coef.R' 'get_contrast_data.R' diff --git a/NAMESPACE b/NAMESPACE index 3a26c783d..a8b022f63 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -156,6 +156,7 @@ export(comparisons) export(datagrid) export(datagridcf) export(deltamethod) +export(estimates) export(expect_marginal_means) export(expect_margins) export(expect_predictions) diff --git a/NEWS.md b/NEWS.md index 0b9b108c4..fd390589c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,11 @@ # News +## dev + +New: + +* `estimates()` function provides a unified interface to fit a large number of models. All models are compatible with `marginaleffects` functions, so users can easily switch between modeling strategies by changing a single argument. + ## 0.15.1 * `hypotheses()`: The `FUN` argument handles `group` columns gracefully. diff --git a/R/estimates.R b/R/estimates.R new file mode 100644 index 000000000..2f38e2d3b --- /dev/null +++ b/R/estimates.R @@ -0,0 +1,230 @@ +estimates_dictionary_build <- function() { +text <- +'Model,Description,Package,Function +lm,Linear Model,stats,lm +nls,Nonlinear Least Squares,stats,nls +logit,Logistic,stats,glm +probit,Probit,stats,glm +beta,Beta Regression,betareg,betareg +betabinomial,Beta-Binomial,aod,betabin +glm,Generalized Linear Model,stats,glm +ologit,Ordered Logistic,MASS,polr +oprobit,Ordered Probit,MASS,polr +ologlog,Ordered Log-Log,MASS,polr +ocloglog,Ordered Complementary Log-Log,MASS,polr +ocauchit,Ordered Log-Log,MASS,polr +robustlm,Robust Linear,robustbase,lmrob +robustglm,Robust Generalized Linear,robustbase,glmrob +multinom,Multinomial Log-Linear,nnet,multinom +negbin,Negative Binomial,MASS,glm.nb +felm,Fixed Effects Linear Model,fixest,feols +fepoisson,Fixed Effects Poisson,fixest,fepois +feglm,Fixed Effects GLM,fixest,feglm +2sls,Two-Stage Least Squares,ivreg,ivreg +firth_logit,Firth Logitistic,logistf,logistf +firth_flic,Firth Logitistic with Intercept Correction,logistf,flac +firth_flac,Firth Logitistic with Added Covariate,logistf,flac +melm,Mixed-Effects Linear,lme4,lmer +meglm,Mixed-Effects Generalized Linear,lme4,glmer +truncreg,Truncated Gaussian Response,truncreg,truncreg +coxph,Cox Proportional Hazards,survival,coxph +quantreg,Quantile Regression,quantreg,rq +poisson0,Zero-Inflated Poisson,pscl,zeroinfl +negbin0,Zero-Inflated Negative Binomial,pscl,zeroinfl +geometric0,Zero-Inflated Geometric,pscl,zeroinfl +heckman,Heckman-Style Selection and Treatment Effect,sampleSelection,selection +heckit,Heckman-Style Selection and Treatment Effect,sampleSelection,heckit +gam,Generalized Additive Model,mgcv,gam +2sls_robust,Two-Stage Least Squares with Robust SEs,estimatr,iv_robust +' +out <- utils::read.csv( + text = text, + colClasses = c("character", "character", "character", "character")) +colnames(out) <- gsub("\\.$", "", colnames(out)) +for (i in 1:4) { + out[[i]] <- trimws(out[[i]]) +} +out <- out[order(out$Description), ] +class(out) <- c("estimates_dictionary", "data.frame") +row.names(out) <- NULL +return(out) +} + + +#' estimates dictionary +#' +#' @noRd +estimates_dictionary <- estimates_dictionary_build() + + +get_function_args <- function(name, pkg) { + insight::check_if_installed(pkg) + args <- names(formals(methods::getFunction(name, where = asNamespace(pkg)))) + args <- setdiff(args, c("formula", "data", "model", "fml", "...")) + args <- unique(args) + return(args) +} + + +print.estimates_dictionary <- function(x, ...) { + flag <- insight::check_if_installed("knitr", quietly = TRUE) + if (isTRUE(flag)) { + cat("\nAvailable models:") + print(knitr::kable(x, row.names = FALSE)) + } else { + print(x) + } +} + + +check_required_argument <- function(arg, pkg, fun, ...) { + if (!arg %in% names(list(...))) { + insight::format_error( + sprintf("The `%s` argument is required. Please read the documentation:", arg), + sprintf("?%s::%s ", pkg, fun) + ) + } +} + + +#' Fit statistical models to obtain parameter estimates +#' +#' This function offers a single point of entry for fitting many different statistical models. It provides a unified user interface for various model fitting functions, making it easier to switch between models and compare results. If the `model` argument is missing, the function returns a list of available models and the functions used under the hood to fit each model. If the `formula` argument is missing, a description of the model is printed with a list of extra arguments which can be passed to `estimates()` to change the model or estimation procedure. +#' +#' @param formula a formula specifying the model to fit +#' @param data a data frame containing the variables in the formula +#' @param model a character string specifying the model to fit. If missing, returns a list of available models and their corresponding functions. +#' @param ... additional arguments to be passed to the model fitting function +#' +#' @return a model object. These objects can differ from model to model, but they are all supported by `marginaleffects` functions like `predictions()`, `slopes()`, `comparisons()`, and `hypotheses()`. These objects can also be summarized in nice tables using the `modelsummary` package. +#' +#' @examples +#' estimates(gear ~ wt, data = mtcars, model = "lm") +#' +#' estimates(gear ~ wt, data = mtcars, model = "oprobit") +#' +#' estimates(gear ~ wt + (1 | cyl), data = mtcars, model = "melm") +#' +#' estimates(gear ~ wt, data = mtcars, model = "oprobit") |> +#' avg_slopes() +#' +#' @export +#' +estimates <- function(formula, data, model, ...) { + call_save <- utils::match.call() + # what models are available? + if (missing(model)) { + return(estimates_dictionary) + } else { + checkmate::assert_choice(model, estimates_dictionary$Model) + } + + fun_name <- estimates_dictionary[estimates_dictionary$Model == model, , drop = FALSE] + funargs <- get_function_args(fun_name$Function, fun_name$Package) + funargs <- c("formula", "data", "model", funargs) + + # what arguments are available? + if (missing(formula)) { + insight::format_error("The `formula` argument is required.") + } + + # fit the model and extract estimates + estimates <- do.call(fun_name$Function, c(list(formula = formula, data = data), list(...))) + estimates <- extract_estimates(estimates, model) + + return(estimates) +} +estimates <- function(formula, data, model, ...) { + call_save <- match.call() + # what models are available? + if (missing(model)) { + return(estimates_dictionary) + } else { + checkmate::assert_choice(model, estimates_dictionary$Model) + } + + fun_name <- subset(estimates_dictionary, Model == model) + funargs <- get_function_args(fun_name$Function, fun_name$Package) + funargs <- c("formula", "data", "model", funargs) + + # what arguments are available? + if (missing(formula)) { + msg <- sprintf(" +Model: %s +Package: %s +Function: %s +Documentation: ?%s::%s +Arguments: %s +", + fun_name$Description, fun_name$Package, fun_name$Function, fun_name$Package, fun_name$Function, paste(funargs, collapse = ", ")) + message(msg) + return(invisible(estimates_dictionary)) + } + + checkmate::assert_formula(formula) + + # missing data + if (missing(data) || !isTRUE(checkmate::check_data_frame(data, null.ok = FALSE))) { + insight::format_error("`data` must be a data.frame.") + } + + # fit arguments + args <- list(formula = formula, data = data) + args <- c(args, list(...)) + + # standardize argument names + if (model %in% c("felm", "fepoisson", "feglm")) { + args$fml <- args$formula + args$formula <- NULL + } else if (model == "logit") { + args$family = stats::binomial(link = "logit") + } else if (model == "probit") { + args$family = stats::binomial(link = "probit") + } else if (model == "poisson") { + args$family = stats::poisson() + } else if (model == "ologit") { + args$method = "logistic" + } else if (model == "oprobit") { + args$method = "probit" + } else if (model == "ologlog") { + args$method = "loglog" + } else if (model == "ocloglog") { + args$method = "cloglog" + } else if (model == "ocauchit") { + args$method = "cauchit" + } else if (model == "2sls") { + check_required_argument("instruments", "ivreg", "ivreg", ...) + } else if (model == "quantreg") { + check_required_argument("tau", "quantreg", "rq", ...) + } else if (model == "negbin0") { + args$dist <- "negbin" + } else if (model == "geometric") { + args$dist <- "geometric" + } else if (model == "heckman") { + check_required_argument("outcome", "sampleSelection", "selection") + args$outcome <- args$formula + args$formula <- NULL + } + + # convenience: ordinal responses must be factor + if (model %in% c("ologit", "oprobit", "ologlog", "ocloglog", "ocauchit")) { + dv <- as.character(as.list(formula)[[2]]) + if (!dv %in% colnames(data)) { + insight::format_error(sprintf("The dependent variable `%s` is not in the data.", dv)) + } + if (!is.factor(data[[dv]])) { + data[[dv]] <- factor(data[[dv]]) + args$data <- data + } + } + + FUN <- methods::getFunction(fun_name$Function, where = asNamespace(fun_name$Package)) + out <- do.call(FUN, args) + + if ("call" %in% names(out)) { + out$call <- call_save + } else if ("call" %in% names(attributes(out))) { + attributes(out)$call <- call_save + } + return(out) +} \ No newline at end of file diff --git a/data-raw/supported_models.csv b/data-raw/supported_models.csv index b74a4a8b8..ca04fa08f 100644 --- a/data-raw/supported_models.csv +++ b/data-raw/supported_models.csv @@ -61,7 +61,6 @@ plm,plm,TRUE,TRUE,TRUE,TRUE,TRUE,TRUE,U,U phylolm,phylolm,TRUE,TRUE,,,,,, phylolm,phyloglm,TRUE,TRUE,,,,,, pscl,hurdle,TRUE,TRUE,,,TRUE,U,TRUE,FALSE -pscl,hurdle,TRUE,TRUE,,,TRUE,U,TRUE,FALSE pscl,zeroinfl,TRUE,TRUE,TRUE,TRUE,TRUE,U,TRUE,TRUE quantreg,rq,TRUE,TRUE,TRUE,TRUE,U,U,TRUE,TRUE Rchoice,hetprob,TRUE,TRUE,,,,,, diff --git a/man/estimates.Rd b/man/estimates.Rd new file mode 100644 index 000000000..8711c4600 --- /dev/null +++ b/man/estimates.Rd @@ -0,0 +1,34 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/estimates.R +\name{estimates} +\alias{estimates} +\title{Fit statistical models to obtain parameter estimates} +\usage{ +estimates(formula, data, model, ...) +} +\arguments{ +\item{formula}{a formula specifying the model to fit} + +\item{data}{a data frame containing the variables in the formula} + +\item{model}{a character string specifying the model to fit. If missing, returns a list of available models and their corresponding functions.} + +\item{...}{additional arguments to be passed to the model fitting function} +} +\value{ +a model object. These objects can differ from model to model, but they are all supported by \code{marginaleffects} functions like \code{predictions()}, \code{slopes()}, \code{comparisons()}, and \code{hypotheses()}. These objects can also be summarized in nice tables using the \code{modelsummary} package. +} +\description{ +This function offers a single point of entry for fitting many different statistical models. It provides a unified user interface for various model fitting functions, making it easier to switch between models and compare results. If the \code{model} argument is missing, the function returns a list of available models and the functions used under the hood to fit each model. If the \code{formula} argument is missing, a description of the model is printed with a list of extra arguments which can be passed to \code{estimates()} to change the model or estimation procedure. +} +\examples{ +estimates(gear ~ wt, data = mtcars, model = "lm") + +estimates(gear ~ wt, data = mtcars, model = "oprobit") + +estimates(gear ~ wt + (1 | cyl), data = mtcars, model = "melm") + +estimates(gear ~ wt, data = mtcars, model = "oprobit") |> + avg_slopes() + +}