diff --git a/R/posterior.R b/R/posterior.R index df37f5e..04a88d4 100644 --- a/R/posterior.R +++ b/R/posterior.R @@ -5,9 +5,8 @@ #' @param dose_names prior_list #' @param robustify_weight tbd #' -#' @export getPriorList <- function ( - + hist_data, dose_levels, dose_names = NULL, @@ -66,21 +65,34 @@ getPriorList <- function ( #' @param data tbd #' @param prior_list prior_list #' @param mu_hat tbd -#' @param sd_hat tbd +#' @param se_hat tbd #' #' @export getPosterior <- function( - data, + prior_list, + data = NULL, mu_hat = NULL, - sd_hat = NULL + se_hat = NULL ) { - posterior_list <- lapply(split(data, data$simulation), getPosteriorI, - prior_list = prior_list, - mu_hat = mu_hat, - sd_hat = sd_hat) + if (!is.null(mu_hat) && !is.null(se_hat) && is.null(data)) { + + posterior_list <- getPosteriorI( + data_i = NULL, + prior_list = prior_list, + mu_hat = mu_hat, + se_hat = se_hat) + + } else { + + posterior_list <- lapply(split(data, data$simulation), getPosteriorI, + prior_list = prior_list, + mu_hat = mu_hat, + se_hat = se_hat) + + } if (length(posterior_list) == 1) { @@ -94,33 +106,33 @@ getPosterior <- function( getPosteriorI <- function( - data_i, + data_i = NULL, prior_list, mu_hat = NULL, - sd_hat = NULL + se_hat = NULL ) { - if (is.null(mu_hat) && is.null(sd_hat)) { + if (is.null(mu_hat) && is.null(se_hat)) { anova_res <- stats::lm(data_i$response ~ factor(data_i$dose) - 1) mu_hat <- summary(anova_res)$coefficients[, 1] - sd_hat <- summary(anova_res)$coefficients[, 2] + se_hat <- summary(anova_res)$coefficients[, 2] - } else if (!is.null(mu_hat) && !is.null(sd_hat)) { + } else if (!is.null(mu_hat) && !is.null(se_hat)) { stopifnot("m_hat length must match number of dose levels" = length(prior_list) == length(mu_hat), - "sd_hat length must match number of dose levels" = - length(prior_list) == length(sd_hat)) + "se_hat length must match number of dose levels" = + length(prior_list) == length(se_hat)) } else { - stop ("Both mu_hat and sd_hat must be provided.") + stop ("Both mu_hat and se_hat must be provided.") } - post_list <- mapply(RBesT::postmix, prior_list, m = mu_hat, se = sd_hat) + post_list <- mapply(RBesT::postmix, prior_list, m = mu_hat, se = se_hat) if (is.null(names(prior_list))) { diff --git a/vignettes/analysis_normal.Rmd b/vignettes/analysis_normal.Rmd index cdd3535..94542e7 100644 --- a/vignettes/analysis_normal.Rmd +++ b/vignettes/analysis_normal.Rmd @@ -101,9 +101,14 @@ data_emax <- simulateData( mods = mods, true_model = "emax") -posterior <- getPosterior(prior=prior_list,data=data_emax, - mu_hat = new_trial$rslt, - sd_hat = new_trial$se) +posterior <- getPosterior(prior = prior_list, + data = data_emax) + +posterior <- getPosterior(prior = prior_list, + mu_hat = new_trial$rslt, + se_hat = new_trial$se) + +summary(posterior) ```