Skip to content

Commit

Permalink
Merge branch 'feature/restructure' of https://github.com/Boehringer-I…
Browse files Browse the repository at this point in the history
…ngelheim/BayesianMCPMod into feature/restructure
  • Loading branch information
Xyarz committed Oct 9, 2023
2 parents 6b2beb0 + 13ebcb2 commit 489e049
Show file tree
Hide file tree
Showing 5 changed files with 351 additions and 78 deletions.
58 changes: 32 additions & 26 deletions R/modelling.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ getModelFitSimple <- function (
resp = summary.postList(posterior)[, 1],
S = diag(summary.postList(posterior)[, 2]^2),
model = model,
type = doFit_ID$VALUES$FIT_TYPE,
type = "general",
bnds = DoseFinding::defBnds(mD = max(dose_levels))[[model]])

pred_vals <- stats::predict(fit, predType = "ls-means")
Expand Down Expand Up @@ -190,31 +190,37 @@ predictModelFit <- function (

pred_vals <- switch (
model_fit$model,
"emax" = {DoseFinding::emax(doses,
model_fit$coeffs["e0"],
model_fit$coeffs["eMax"],
model_fit$coeffs["ed50"])},
"sigEmax" = {DoseFinding::sigEmax(doses,
model_fit$coeffs["e0"],
model_fit$coeffs["eMax"],
model_fit$coeffs["ed50"],
model_fit$coeffs["h"])},
"exponential" = {DoseFinding::exponential(doses,
model_fit$coeffs["e0"],
model_fit$coeffs["e1"],
model_fit$coeffs["delta"])},
"quadratic" = {DoseFinding::quadratic(doses,
model_fit$coeffs["e0"],
model_fit$coeffs["b1"],
model_fit$coeffs["b2"])},
"linear" = {DoseFinding::linear(doses,
model_fit$coeffs["e0"],
model_fit$coeffs["delta"])},
"logistic" = {DoseFinding::logistic(doses,
model_fit$coeffs["e0"],
model_fit$coeffs["eMax"],
model_fit$coeffs["ed50"],
model_fit$coeffs["delta"])},
"emax" = {DoseFinding::emax(
doses,
model_fit$coeffs["e0"],
model_fit$coeffs["eMax"],
model_fit$coeffs["ed50"])},
"sigEmax" = {DoseFinding::sigEmax(
doses,
model_fit$coeffs["e0"],
model_fit$coeffs["eMax"],
model_fit$coeffs["ed50"],
model_fit$coeffs["h"])},
"exponential" = {DoseFinding::exponential(
doses,
model_fit$coeffs["e0"],
model_fit$coeffs["e1"],
model_fit$coeffs["delta"])},
"quadratic" = {DoseFinding::quadratic(
doses,
model_fit$coeffs["e0"],
model_fit$coeffs["b1"],
model_fit$coeffs["b2"])},
"linear" = {DoseFinding::linear(
doses,
model_fit$coeffs["e0"],
model_fit$coeffs["delta"])},
"logistic" = {DoseFinding::logistic(
doses,
model_fit$coeffs["e0"],
model_fit$coeffs["eMax"],
model_fit$coeffs["ed50"],
model_fit$coeffs["delta"])},
{stop(GENERAL$ERROR$MODEL_OPTIONS)})

return (pred_vals)
Expand Down
144 changes: 101 additions & 43 deletions R/plot.R
Original file line number Diff line number Diff line change
@@ -1,53 +1,111 @@
plot.estMod <- function (
plot.modelFits <- function (

est_mod
# dose_levels
# posteriors = posterior_linear[[2]]
model_fits,
CrI = FALSE,
gAIC = TRUE,
avg_fit = TRUE

) {

model <- est_mod$model
theta <- est_mod$fit$solution

dose <- seq(min(dose_levels), max(dose_levels), length.out = 1e3)

switch(model,
"emax" = {
resp_expr <- quote(theta[1] + (theta[2] * dose) / (theta[3] + dose))},
"sigEmax" = {
resp_expr <- quote(theta[1] + (theta[2] * dose^theta[4]) / (theta[3]^theta[4] + dose^theta[4]))},
"exponential" = {
resp_expr <- quote(theta[1] + theta[2] * (exp(dose / theta[3]) - 1))},
"quadratic" = {
resp_expr <- quote(theta[1] + theta[2] * dose + theta[3] * dose^2)},
"linear" = {
resp_expr <- quote(theta[1] + theta[2] * dose)},
"logistic" = {
resp_expr <- quote(theta[1] + theta[2] / (1 + exp((theta[3] - dose) / theta[4])))},
{
stop(GENERAL$ERROR$MODEL_OPTIONS)}
)

df <- data.frame(dose = dose, response = eval(resp_expr))

data.frame(dose = dose_levels, obs = )

plt <- ggplot2::ggplot(data = df) +
ggplot2::geom_line(ggplot2::aes(dose, response)) +
ggplot2::geom_point()

return(plt)

}

plot.estMods <- function (
plot_resolution <- 1e3

dose_levels <- model_fits[[1]]$dose_levels
post_summary <- summary.postList(attr(model_fits, "posterior"))
doses <- seq(from = min(dose_levels),
to = max(dose_levels), length.out = plot_resolution)

preds_models <- sapply(model_fits, predictModelFit, doses = doses)
model_names <- names(model_fits)

if (avg_fit) {

mod_weigts <- sapply(model_fits, function (x) x$model_weight)
avg_mod <- preds_models %*% mod_weigts

preds_models <- cbind(preds_models, avg_mod)
model_names <- c(model_names, "averageModel")

est_mods
}

) {
gg_data <- data.frame(
dose_levels = rep(doses, length(model_names)),
fits = as.vector(preds_models),
models = rep(factor(model_names,
levels = c("linear", "emax", "exponential",
"sigEmax", "logistic", "quadratic",
"averageModel")),
each = plot_resolution))

plts <- lapply(est_mods, plot)
if (gAIC) {

g_AICs <- sapply(model_fits, function (x) x$gAIC)
label_gAUC <- paste("AIC:", round(g_AICs, digits = 1))

if (avg_fit) {

mod_weigts <- sort(mod_weigts, decreasing = TRUE)
paste_names <- names(mod_weigts) |>
gsub("exponential", "exp", x = _) |>
gsub("quadratic", "quad", x = _) |>
gsub("linear", "lin", x = _) |>
gsub("logistic", "log", x = _) |>
gsub("sigEmax", "sigE", x = _)

label_avg <- paste0(paste_names, "=", round(mod_weigts, 1),
collapse = ", ")
label_gAUC <- c(label_gAUC, label_avg)

}

}

plts <- ggplot2::ggplot() +
## Layout etc.
ggplot2::theme_bw() +
ggplot2::labs(x = "Dose",
y = "Model Fits") +
ggplot2::theme(panel.grid.major = ggplot2::element_blank(),
panel.grid.minor = ggplot2::element_blank()) +
## gAIC
{if (gAIC) {

ggplot2::geom_text(
data = data.frame(
models = unique(gg_data$models),
label = label_gAUC),
mapping = ggplot2::aes(label = label_gAUC),
x = -Inf, y = Inf, hjust = "inward", vjust = "inward",
size = 3)

}
} +
## Posterior Credible Intervals
{if (CrI) {

ggplot2::geom_errorbar(
data = data.frame(x = dose_levels,
ymin = post_summary[, 3],
ymax = post_summary[, 5]),
mapping = ggplot2::aes(x = x,
ymin = ymin,
ymax = ymax),
width = 0, alpha = 0.5)

}
} +
## Posterior Medians
ggplot2::geom_point(
data = data.frame(dose_levels = dose_levels,
fits = post_summary[, 4]),
mapping = ggplot2::aes(dose_levels, fits),
size = 2) +
## Fitted Models
ggplot2::geom_line(
data = gg_data,
mapping = ggplot2::aes(dose_levels, fits)) +
## Faceting
ggplot2::facet_wrap(~ models)

return (plts)

}
}
49 changes: 40 additions & 9 deletions R/posterior.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,56 @@
getPosterior <- function(

data,
prior_list
prior_list,
mu_hat = NULL,
sd_hat = NULL

) {

lapply(split(data, data$simulation), getPosteriorI, prior_list = prior_list)
posterior_list <- lapply(split(data, data$simulation), getPosteriorI,
prior_list = prior_list,
mu_hat = mu_hat,
sd_hat = sd_hat)

if (length(posterior_list) == 1) {

posterior_list <- posterior_list[[1]]

}

return (posterior_list)

}

getPosteriorI <- function(

data_i,
prior_list
prior_list,
mu_hat = NULL,
sd_hat = NULL

) {

anova_res <- lm(data_i$response ~ factor(data_i$dose) - 1) # take mean & var out in separate function
anova_mean <- summary(anova_res)$coefficients[, 1] #
anova_se <- summary(anova_res)$coefficients[, 2] #
if (is.null(mu_hat) && is.null(sd_hat)) {

anova_res <- lm(data_i$response ~ factor(data_i$dose) - 1)
mu_hat <- summary(anova_res)$coefficients[, 1]
sd_hat <- summary(anova_res)$coefficients[, 2]

} else if (!is.null(mu_hat) && !is.null(sd_hat)) {

stopifnot("m_hat length must match number of dose levels" =
length(prior_list) == length(mu_hat),
"sd_hat length must match number of dose levels" =
length(prior_list) == nrow(sd_hat))

} else {

stop ("Both mu_hat and S_hat must be provided.")

}

post_list <- mapply(RBesT::postmix, prior_list, m = anova_mean, se = anova_se)
post_list <- mapply(RBesT::postmix, prior_list, m = mu_hat, se = sd_hat)

names(post_list) <- c("Ctr", paste0("DG_", seq_along(post_list[-1])))
class(post_list) <- "postList"
Expand Down Expand Up @@ -55,11 +85,12 @@ getPostCombsI <- function (

summary.postList <- function (

post_list
post_list,
...

) {

summary_list <- lapply(post_list, summary)
summary_list <- lapply(post_list, summary, ...)
names(summary_list) <- names(post_list)
summary_tab <- do.call(rbind, summary_list)

Expand Down
2 changes: 2 additions & 0 deletions vignettes/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.html
*.R
Loading

0 comments on commit 489e049

Please sign in to comment.