Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
drizopoulos committed Apr 10, 2024
1 parent 512889c commit 42ef08d
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 24 deletions.
12 changes: 7 additions & 5 deletions R/basic_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ crisk_setup <- function (data, statusVar, censLevel, nameStrata = "strata",
}

predict.jm <- function (object, newdata = NULL, newdata2 = NULL,
times = NULL, times_per_id = FALSE,
times = NULL, all_times = FALSE, times_per_id = FALSE,
process = c("longitudinal", "event"),
type_pred = c("response", "link"),
type = c("subject_specific", "mean_subject"),
Expand Down Expand Up @@ -740,8 +740,8 @@ predict.jm <- function (object, newdata = NULL, newdata2 = NULL,
n_mcmc, parallel, cores, seed, use_Y)
if (process == "longitudinal") {
predict_Long(object, components_newdata, newdata, newdata2, times,
times_per_id, type, type_pred, level, return_newdata,
return_mcmc)
all_times, times_per_id, type, type_pred, level,
return_newdata, return_mcmc)
} else {
predict_Event(object, components_newdata, newdata, newdata2, times,
times_per_id, level, return_newdata, return_mcmc)
Expand Down Expand Up @@ -1060,7 +1060,7 @@ rc_setup <- function(rc_data, trm_data,
}

predict.jmList <- function (object, weights, newdata = NULL, newdata2 = NULL,
times = NULL, times_per_id = FALSE,
times = NULL, all_times = FALSE, times_per_id = FALSE,
process = c("longitudinal", "event"),
type_pred = c("response", "link"),
type = c("subject_specific", "mean_subject"),
Expand Down Expand Up @@ -1211,6 +1211,7 @@ predict.jmList <- function (object, weights, newdata = NULL, newdata2 = NULL,
preds <-
parallel::mclapply(object, predict, newdata = newdata,
newdata2 = newdata2, times = times,
all_times = all_times,
times_per_id = times_per_id,
process = process, type_pred = type_pred,
type = type, level = level, n_samples = n_samples,
Expand All @@ -1222,6 +1223,7 @@ predict.jmList <- function (object, weights, newdata = NULL, newdata2 = NULL,
preds <-
parallel::parLapply(cl, object, predict, newdata = newdata,
newdata2 = newdata2, times = times,
all_times = all_times,
times_per_id = times_per_id,
process = process, type_pred = type_pred,
type = type, level = level, n_samples = n_samples,
Expand All @@ -1233,7 +1235,7 @@ predict.jmList <- function (object, weights, newdata = NULL, newdata2 = NULL,
preds <-
lapply(object, predict, newdata = newdata,
newdata2 = newdata2, times = times,
times_per_id = times_per_id,
all_times = all_times, times_per_id = times_per_id,
process = process, type_pred = type_pred,
type = type, level = level, n_samples = n_samples,
n_mcmc = n_mcmc, return_newdata = return_newdata,
Expand Down
3 changes: 3 additions & 0 deletions R/help_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ bdiag2 <- function (mlist, off_diag_val = 1e-05, which_independent = NULL) {
out[ind, ind] <- mlist[[i]]
}
if (!is.null(which_independent)) {
if (length(which_independent) == 1 && which_independent == "all") {
which_independent <- t(combn(length(mlist), 2))
}
if (!is.matrix(which_independent) || ncol(which_independent) != 2) {
stop("'which_independent' must a matrix with two columns.\n")
}
Expand Down
17 changes: 12 additions & 5 deletions R/predict_funs.R
Original file line number Diff line number Diff line change
Expand Up @@ -678,8 +678,8 @@ get_components_newdata <- function (object, newdata, n_samples, n_mcmc,
}

predict_Long <- function (object, components_newdata, newdata, newdata2, times,
times_per_id, type, type_pred, level, return_newdata,
return_mcmc) {
all_times, times_per_id, type, type_pred, level,
return_newdata, return_mcmc) {
# Predictions for newdata
newdataL <- if (!is.data.frame(newdata)) newdata[["newdataL"]] else newdata
betas <- components_newdata$mcmc[["betas"]]
Expand Down Expand Up @@ -736,7 +736,7 @@ predict_Long <- function (object, components_newdata, newdata, newdata2, times,
}
t_max <- max(object$model_data$Time_right)
test <- sapply(last_times, function (lt, tt) all(tt <= lt), tt = times)
if (any(test)) {
if (any(test) && !all_times) {
stop("according to the definition of argument 'times', for some ",
"subjects the last available time is\n\t larger than the ",
"maximum time to predict; redefine 'times' accordingly.")
Expand All @@ -747,8 +747,15 @@ predict_Long <- function (object, components_newdata, newdata, newdata2, times,
MoreArgs = list(tm = t_max))

} else {
f <- function (lt, tt, tm) c(lt, sort(tt[tt > lt & tt <= tm]))
times <- lapply(last_times, f, tt = times, tm = t_max)
f <- function (lt, tt, tm, all_times) {
if (all_times) {
sort(tt[tt <= tm])
} else {
c(lt, sort(tt[tt > lt & tt <= tm]))
}
}
times <- lapply(last_times, f, tt = times, tm = t_max,
all_times = all_times)
}
n_times <- sapply(times, length)
newdata2 <- newdataL
Expand Down
3 changes: 1 addition & 2 deletions man/jm.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ tv(x, knots = NULL, ord = 2L)
\item{time_var}{a \code{character} string indicating the time variable in the mixed-effects model(s).}
\item{recurrent}{a \code{character} string indicating "calendar" or "gap" timescale to fit a recurrent event model.}
\item{functional_forms}{a \code{list} of formulas. Each formula corresponds to one longitudinal outcome and specifies the association structure between that outcome and the survival submodel as well as any interaction terms between the components of the longitudinal outcome and the survival submodel. See \bold{Examples}.}
\item{which_independent}{a numeric indicator matrix denoting which outcomes are independent. Only relevant in
joint models with multiple longitudinal outcomes.}
\item{which_independent}{a numeric indicator matrix denoting which outcomes are independent. It can also be the character string \code{"all"} in which case all longitudinal outcomes are assumed independent. Only relevant in joint models with multiple longitudinal outcomes.}
\item{data_Surv}{the \code{data.frame} used to fit the Cox/AFT survival submodel.}
\item{id_var}{a \code{character} string indicating the id variable in the survival submodel.}
\item{priors}{a named \code{list} of user-specified prior parameters:
Expand Down
9 changes: 7 additions & 2 deletions man/predict.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ Predict method for object of class \code{"jm"}.
\usage{

\method{predict}{jm}(object, newdata = NULL, newdata2 = NULL, times = NULL,
times_per_id = FALSE, process = c("longitudinal", "event"),
all_times = FALSE, times_per_id = FALSE,
process = c("longitudinal", "event"),
type_pred = c("response", "link"),
type = c("subject_specific", "mean_subject"),
level = 0.95, return_newdata = FALSE, use_Y = TRUE, return_mcmc = FALSE,
Expand All @@ -36,7 +37,8 @@ Predict method for object of class \code{"jm"}.
\dots)

\method{predict}{jmList}(object, weights, newdata = NULL, newdata2 = NULL,
times = NULL, times_per_id = FALSE, process = c("longitudinal", "event"),
times = NULL, all_times = FALSE, times_per_id = FALSE,
process = c("longitudinal", "event"),
type_pred = c("response", "link"),
type = c("subject_specific", "mean_subject"),
level = 0.95, return_newdata = FALSE,
Expand All @@ -54,6 +56,9 @@ Predict method for object of class \code{"jm"}.

\item{times}{a numeric vector of future times to calculate predictions.}

\item{all_times}{logical; if \code{TRUE} predictions for the longitudinal outcomes are calculated for all the times
given in the \code{times} argumet, not only the ones after the last longitudinal measurement.}.

\item{times_per_id}{logical; if \code{TRUE} the \code{times} argument is a vector of times equal to the number of
subjects in \code{newdata}.}

Expand Down
14 changes: 4 additions & 10 deletions src/JMbayes2_D.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,13 @@ mat propose_L (const mat &L, const vec &scale, const uvec &upper_part,
uword n = L.n_rows;
for (uword j = 0; j < n; ++j) {
vec ll = proposed_L.col(j);
proposed_L(j, j) = sqrt(1 - dot(ll, ll));
proposed_L.at(j, j) = sqrt(1 - dot(ll, ll));
}
//uword n = L.n_rows;
//uword column = upper_part.at(i) / n;
//vec ll = proposed_L.submat(0, column, column - 1, column);
//double ss = dot(ll, ll);
//if (ss > 1) return proposed_L.fill(datum::nan);
//proposed_L.at(column, column) = sqrt(1 - ss);
uword nn = ind_zero_D.n_rows;
for (uword j = 0; j < nn; ++j) {
uword j0 = ind_zero_D(j, 0);
uword j1 = ind_zero_D(j, 1);
proposed_L(j0, j1) = -sum(proposed_L.col(j0) % proposed_L.col(j1)) / proposed_L(j0, j0);
uword j0 = ind_zero_D.at(j, 0);
uword j1 = ind_zero_D.at(j, 1);
proposed_L.at(j0, j1) = - sum(proposed_L.col(j0) % proposed_L.col(j1)) / proposed_L.at(j0, j0);
vec ll = proposed_L.submat(0, j1, j1 - 1, j1);
double ss = dot(ll, ll);
if (ss > 1) return proposed_L.fill(datum::nan);
Expand Down

0 comments on commit 42ef08d

Please sign in to comment.