diff --git a/R/plot.check_dag.R b/R/plot.check_dag.R index 8d91f5185..74bd3cc06 100644 --- a/R/plot.check_dag.R +++ b/R/plot.check_dag.R @@ -8,6 +8,8 @@ #' hex-format) for different types of variables. #' @param which Character string indicating which plot to show. Can be either #' `"all"`, `"current"` or `"required"`. +#' @param check_colliders Logical indicating whether to highlight colliders. +#' Set to `FALSE` if the algorithm to detect colliders is very slow. #' @param ... Not used. #' #' @return A ggplot2-object. @@ -27,23 +29,39 @@ #' # plot only model with required adjustments #' plot(dag, which = "required") #' @export -plot.see_check_dag <- function(x, size_point = 15, colors = NULL, which = "all", ...) { +plot.see_check_dag <- function(x, + size_point = 15, + colors = NULL, + which = "all", + check_colliders = TRUE, + ...) { .data <- NULL insight::check_if_installed(c("ggdag", "ggplot2")) which <- match.arg(which, choices = c("all", "current", "required")) - p1 <- suppressWarnings(ggdag::ggdag_adjust(x, stylized = TRUE)) - p2 <- suppressWarnings(ggdag::ggdag_adjustment_set(x, shadow = TRUE, stylized = TRUE)) + # get plot data + p1 <- p2 <- suppressWarnings(ggdag::dag_adjustment_sets(x)) + adjusted_for <- attributes(x)$adjusted + + # for current plot, we need to update the "adjusted" column + p1$data$adjusted <- "unadjusted" + if (!is.null(adjusted_for)) { + p1$data$adjusted[p1$data$name %in% adjusted_for] <- "adjusted" + } # tweak data p1$data$type <- as.character(p1$data$adjusted) - p1$data$type[vapply(p1$data$name, ggdag::is_collider, logical(1), .dag = x)] <- "collider" + if (check_colliders) { + p1$data$type[vapply(p1$data$name, ggdag::is_collider, logical(1), .dag = x)] <- "collider" + } p1$data$type[p1$data$name == attributes(x)$outcome] <- "outcome" p1$data$type[p1$data$name %in% attributes(x)$exposure] <- "exposure" p1$data$type <- factor(p1$data$type, levels = c("outcome", "exposure", "adjusted", "unadjusted", "collider")) p2$data$type <- as.character(p2$data$adjusted) - p2$data$type[vapply(p2$data$name, ggdag::is_collider, logical(1), .dag = x)] <- "collider" + if (check_colliders) { + p2$data$type[vapply(p2$data$name, ggdag::is_collider, logical(1), .dag = x)] <- "collider" + } p2$data$type[p2$data$name == attributes(x)$outcome] <- "outcome" p2$data$type[p2$data$name %in% attributes(x)$exposure] <- "exposure" p2$data$type <- factor(p2$data$type, levels = c("outcome", "exposure", "adjusted", "unadjusted", "collider")) diff --git a/man/plot.see_check_dag.Rd b/man/plot.see_check_dag.Rd index 5866cba81..0b8a0d1b4 100644 --- a/man/plot.see_check_dag.Rd +++ b/man/plot.see_check_dag.Rd @@ -4,7 +4,14 @@ \alias{plot.see_check_dag} \title{Plot method for check DAGs} \usage{ -\method{plot}{see_check_dag}(x, size_point = 15, colors = NULL, which = "all", ...) +\method{plot}{see_check_dag}( + x, + size_point = 15, + colors = NULL, + which = "all", + check_colliders = TRUE, + ... +) } \arguments{ \item{x}{A \code{check_dag} object.} @@ -17,6 +24,9 @@ hex-format) for different types of variables.} \item{which}{Character string indicating which plot to show. Can be either \code{"all"}, \code{"current"} or \code{"required"}.} +\item{check_colliders}{Logical indicating whether to highlight colliders. +Set to \code{FALSE} if the algorithm to detect colliders is very slow.} + \item{...}{Not used.} } \value{