Skip to content

Commit

Permalink
Support for flexsurv models (#781)
Browse files Browse the repository at this point in the history
* Inital work to support flexsurv models

* Add model to sanity_model.R

* Add models to CSV file

* Allow for multiple predictions per covariate pattern

* Add class prediction types and whitelist valid dot args

* flexsurv: finalize

* bump

---------

Co-authored-by: Vincent Arel-Bundock <[email protected]>
  • Loading branch information
mattwarkentin and vincentarelbundock authored Apr 6, 2024
1 parent dcbd595 commit 61eeb45
Show file tree
Hide file tree
Showing 15 changed files with 105 additions and 16 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ Suggests:
equivalence,
estimatr,
fixest,
flexsurv,
fmeffects,
fontquiver,
future,
Expand Down Expand Up @@ -231,6 +232,7 @@ Collate:
'methods_dataframe.R'
'methods_dbarts.R'
'methods_fixest.R'
'methods_flexsurv.R'
'methods_gamlss.R'
'methods_glmmTMB.R'
'methods_glmx.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ S3method(get_predict,coxph)
S3method(get_predict,crch)
S3method(get_predict,default)
S3method(get_predict,fixest)
S3method(get_predict,flexsurvreg)
S3method(get_predict,gamlss)
S3method(get_predict,glimML)
S3method(get_predict,glm)
Expand Down Expand Up @@ -127,6 +128,7 @@ S3method(set_coef,betareg)
S3method(set_coef,crch)
S3method(set_coef,data.frame)
S3method(set_coef,default)
S3method(set_coef,flexsurvreg)
S3method(set_coef,gamlss)
S3method(set_coef,glimML)
S3method(set_coef,glm)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ Breaking changes:
* `datagrid()` no longer includes the response variable by default when it is not explicitly specified by the user. Use the new `response` argument to include it.
* `datagrid(grid_type="balanced")` returns factors ordered by level rather than by order of appearance in the original data.

New modeling packages supported:

* `flexsurv`: Thanks to @mattwarkentin for code contributions in PR #781. https://cran.r-project.org/web/packages/flexsurv/index.html

New:

* `wts=TRUE` tries to retrieves weights used in a weighted fit such as `lm()` with the `weights` argument or a model fitted using the `survey` package. Thanks to @ngreifer for feature request
Expand Down
46 changes: 46 additions & 0 deletions R/methods_flexsurv.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#' @rdname set_coef
#' @export
set_coef.flexsurvreg <- function(model, coefs, ...) {
out <- model
out$res[, 1] <- coefs
out$coefficients <- coefs
return(out)
}


#' @rdname get_predict
#' @export
get_predict.flexsurvreg <- function(model, newdata, type, ...) {
preds <- stats::predict(
object = model,
newdata = newdata,
type = type,
...
)

if (ncol(preds) == 1L) {
if (names(preds) == '.pred') {
gp <- unlist(lapply(preds$.pred, function(x) {x[, 1, drop = TRUE]}))
val <- unlist(lapply(preds$.pred, function(x) {x[, 2, drop = TRUE]}))

out <- data.frame(
rowid = seq_len(nrow(preds)),
group = as.vector(gp),
estimate = as.vector(val)
)
return(out)
}
out <- data.frame(
rowid = seq_len(nrow(preds)),
estimate = as.vector(preds[, 1, drop = TRUE])
)
return(out)
}

out <- data.frame(
rowid = seq_len(nrow(preds)),
group = as.vector(preds[, 1, drop = TRUE]),
estimate = as.vector(preds[, 2, drop = TRUE])
)
out
}
3 changes: 3 additions & 0 deletions R/sanity_dots.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ sanity_dots <- function(model, calling_function = NULL, ...) {
valid[["bife"]] <- c("alpha_new", "corrected") # nlme::lme
valid[["process_error"]] <- # mvgam::mvgam

# flexsurv
valid[["flexsurvreg"]] <- c("times", "p", "start")

white_list <- c(
"conf.int", "modeldata", "internal_call", "df",
"transform", "comparison", "side", "delta", "null", "equivalence", "draw",
Expand Down
1 change: 1 addition & 0 deletions R/sanity_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ sanity_model_supported_class <- function(model) {
c("clmm2", "clm2"),
"coxph",
"crch",
"flexsurvreg", # package: flexsurv
"fixest",
"flic",
"flac",
Expand Down
9 changes: 9 additions & 0 deletions R/type_dictionary.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ hxlr,scale
hxlr,density
ivpml,pr
ivpml,xb
flexsurvreg,response,
flexsurvreg,mean,
flexsurvreg,link,
flexsurvreg,lp,
flexsurvreg,linear,
flexsurvreg,rmst,
flexsurvreg,survival,
flexsurvreg,hazard,
flexsurvreg,cumhaz,
fixest,invlink(link)
fixest,response
fixest,link
Expand Down
6 changes: 4 additions & 2 deletions data-raw/supported_models.csv
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ DCchoice,oohbchoice,TRUE,TRUE,,,,,,
estimatr,lm_lin,TRUE,TRUE,,,,,,
estimatr,lm_robust,TRUE,TRUE,TRUE,TRUE,TRUE,U,TRUE,TRUE
estimatr,iv_robust,TRUE,TRUE,TRUE,TRUE,U,U,U,U
flexsurv,flexsurvreg,TRUE,TRUE,,,,,,
flexsurv,flexsurvspline,TRUE,TRUE,,,,,,
fixest,feols,TRUE,TRUE,TRUE,TRUE,U,U,U,U
fixest,feglm,TRUE,TRUE,,,U,U,U,U
fixest,fenegbin,TRUE,TRUE,,,U,U,U,U
Expand All @@ -46,8 +48,8 @@ MASS,glmmPQL,TRUE,TRUE,,,U,U,TRUE,TRUE
MASS,glm.nb,TRUE,TRUE,TRUE,TRUE,TRUE,TRUE,TRUE,TRUE
MASS,polr,TRUE,TRUE,TRUE,TRUE,FALSE,FALSE,TRUE,TRUE
MASS,rlm,TRUE,TRUE,,,TRUE,TRUE,TRUE,TRUE
mclogit,mblogit,TRUE,TRUE,,,U,U,U,U
mclogit,mclogit,TRUE,TRUE,,,U,U,U,U
mclogit,mblogit,TRUE,TRUE,,,U,U,U,U
mclogit,mclogit,TRUE,TRUE,,,U,U,U,U
MCMCglmm,MCMCglmm,TRUE,TRUE,U,U,U,U,U,U
mgcv,gam,TRUE,TRUE,,,U,U,TRUE,TRUE
mgcv,bam,TRUE,TRUE,,,U,U,TRUE,FALSE
Expand Down
12 changes: 12 additions & 0 deletions inst/tinytest/test-pkg-flexsurv.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
source("helpers.R")
if (!requiet("flexsurv")) exit_file("flexsurv")


mod <- flexsurvreg(formula = Surv(futime, fustat) ~ age + ecog.ps, data = ovarian, dist= "gengamma")
x <- avg_slopes(mod)
expect_inherits(x, "slopes")
x <- predictions(mod)
expect_inherits(x, "predictions")
x <- comparisons(mod)
expect_inherits(x, "comparisons")

2 changes: 1 addition & 1 deletion man/comparisons.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 10 additions & 6 deletions man/get_predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/plot_predictions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/predictions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 8 additions & 4 deletions man/set_coef.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/slopes.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 61eeb45

Please sign in to comment.