diff --git a/R/check_dag.R b/R/check_dag.R index 6e513612a..45227f9d1 100644 --- a/R/check_dag.R +++ b/R/check_dag.R @@ -74,7 +74,7 @@ check_dag <- function(..., attr(dag, "check_direct") <- insight::compact_list(checks[[1]]) attr(dag, "check_total") <- insight::compact_list(checks[[2]]) - class(dag) <- c("check_dag", class(dag)) + class(dag) <- c(c("check_dag", "see_check_dag"), class(dag)) dag } @@ -167,7 +167,7 @@ print.check_dag <- function(x, ...) { #' @export -plot.check_dag <- function(x, size_point = 15, which = "all", ...) { +plot.check_dag <- function(x, size_point = 15, colors = NULL, which = "all", ...) { .data <- NULL insight::check_if_installed(c("ggdag", "ggplot2", "see")) which <- match.arg(which, choices = c("all", "current", "required")) @@ -184,7 +184,13 @@ plot.check_dag <- function(x, size_point = 15, which = "all", ...) { p2$data$type[p2$data$name == attributes(x)$outcome] <- "outcome" p2$data$type[p2$data$name %in% attributes(x)$exposure] <- "exposure" - point_colors <- see::metro_colors(c("red", "teal", "yellow", "blue grey")) + if (is.null(colors)) { + point_colors <- see::see_colors(c("red", "cyan", "yellow", "blue grey")) + } else if (length(colors) != 4) { + insight::format_error("`colors` must be a character vector with four color-values.") + } else { + point_colors <- colors + } names(point_colors) <- c("unadjusted", "exposure", "outcome", "adjusted") plot1 <- ggplot2::ggplot(p1$data, ggplot2::aes(x = .data$x, y = .data$y)) +