-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
rforaita
committed
Mar 13, 2024
1 parent
c08bdef
commit 4771663
Showing
13 changed files
with
2,013 additions
and
79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
#' Structure learning from missing data
#'
#' Adaptation of \code{bnlearn::\link[bnlearn]{structural.em}} for our data.
#' The stock imputation step could not impute some missing values of
#' \code{bmi_m.2} (BMI of the mother at follow-up 2; fewer than 10 values), so
#' any values of that variable that are still missing after the imputation
#' step are replaced by the mean of the imputed column.
#'
#' Each iteration imputes the data from the current fitted network, runs the
#' score-based search starting from the current DAG, and refits the parameters
#' on the imputed data. Iteration stops early as soon as the refitted network
#' is \code{all.equal()} to the previous one.
#'
#' @param x Data to learn from; \code{names(x)} are used as node labels and
#'   missing values are allowed.
#' @param maximize Label of the score-based learning algorithm; passed on as
#'   \code{heuristic} to bnlearn's greedy search.
#' @param maximize.args List of further arguments for the greedy search. The
#'   entries \code{x}, \code{heuristic}, \code{start} and \code{debug} are set
#'   by this function and are checked by bnlearn if supplied here.
#' @param fit Label of the parameter estimation method.
#' @param fit.args List of further arguments for the parameter estimation.
#' @param impute Label of the imputation method.
#' @param impute.args List of further arguments for the imputation.
#' @param return.all If \code{TRUE}, return the DAG together with the last
#'   imputed data set and the fitted network.
#' @param start Only inspected by the initial data check (the "all values
#'   missing" check is relaxed when it is a \code{bn.fit} object). The search
#'   itself always starts from an empty graph.
#' @param max.iter Maximum number of EM iterations.
#' @param debug If \code{TRUE}, print the iteration counter and pass
#'   \code{debug = TRUE} to the imputation and search steps.
#'
#' @return Invisibly, the learned \code{bn} object with its \code{learning}
#'   metadata updated (\code{algo = "sem"}); or, if \code{return.all = TRUE},
#'   invisibly a list with elements \code{dag}, \code{imputed} and
#'   \code{fitted}.
#' @export
structural.em.rf <- function(x, maximize = "hc", maximize.args = list(),
                             fit = "mle", fit.args = list(), impute,
                             impute.args = list(), return.all = FALSE,
                             start = NULL, max.iter = 5, debug = FALSE) {
  # Variable that the stock imputation leaves partly missing in our data.
  mean.imputed.node <- "bmi_m.2"

  ntests <- 0

  # Argument validation, delegated to bnlearn's internal checkers.
  data.info <- bnlearn:::check.data(
    x,
    allow.levels = TRUE, allow.missing = TRUE, warn.if.no.missing = TRUE,
    stop.if.all.missing = !inherits(start, "bn.fit")
  )
  max.iter <- bnlearn:::check.max.iter(max.iter)
  bnlearn:::check.logical(debug)
  bnlearn:::check.logical(return.all)
  bnlearn:::check.learning.algorithm(algorithm = maximize, class = "score")

  # These search arguments are controlled here and must not come from the user.
  critical.arguments <- c("x", "heuristic", "start", "debug")
  bnlearn:::check.unused.args(
    intersect(critical.arguments, names(maximize.args)),
    character(0)
  )
  maximize.args[critical.arguments] <- list(
    x = NULL, heuristic = maximize, start = NULL, debug = debug
  )

  bnlearn:::check.fitting.method(method = fit, data = x)
  fit.args <- bnlearn:::check.fitting.args(
    method = fit, network = NULL, data = x, extra.args = fit.args
  )
  impute <- bnlearn:::check.imputation.method(impute, x)
  impute.args <- bnlearn:::check.imputation.extra.args(impute, impute.args)

  # Initial network: empty graph over all variables, fitted on the raw data.
  dag <- bnlearn::empty.graph(nodes = names(x))
  fitted <- bnlearn:::bn.fit.backend(
    dag, data = x, method = fit, extra.args = fit.args, data.info = data.info
  )

  # From here on the fitting step only sees imputed data, so all nodes are
  # flagged as complete.
  data.info$complete.nodes[names(x)] <- TRUE

  for (i in seq_len(max.iter)) {
    if (debug) {
      cat("* iteration", i, "\n")
    }

    # E-step: impute the missing values from the current fitted network.
    complete <- bnlearn:::impute.backend(
      fitted = fitted, data = x, method = impute,
      extra.args = impute.args, debug = debug
    )

    # Project-specific patch: mean-impute whatever is still missing in
    # `mean.imputed.node`. Skipped when the variable is not in the data.
    if (mean.imputed.node %in% names(complete)) {
      still.missing <- is.na(complete[[mean.imputed.node]])
      if (any(still.missing)) {
        complete[[mean.imputed.node]][still.missing] <-
          mean(complete[[mean.imputed.node]], na.rm = TRUE)
      }
    }

    # M-step: structure search from the current DAG, then parameter refit.
    maximize.args$x <- complete
    maximize.args$start <- dag
    dag <- do.call(bnlearn:::greedy.search, maximize.args)
    fitted.new <- bnlearn:::bn.fit.backend(
      dag, data = complete, method = fit,
      extra.args = fit.args, data.info = data.info
    )

    ntests <- ntests + dag$learning$ntests

    # Converged: the refitted network did not change.
    if (isTRUE(all.equal(fitted, fitted.new))) {
      break
    }
    fitted <- fitted.new
  }

  dag$learning$algo <- "sem"
  dag$learning$maximize <- maximize
  dag$learning$impute <- impute
  dag$learning$fit <- fit
  dag$learning$ntests <- ntests

  if (return.all) {
    invisible(list(dag = dag, imputed = complete, fitted = fitted))
  } else {
    invisible(dag)
  }
}
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# ------------------------------------------------------------------------------ | ||
# | ||
# Project: Cohort Causal Graph | ||
# | ||
# Author: R. Foraita | ||
# Date: JUL 2021 | ||
# | ||
# Purpose: RMSEU | ||
# | ||
# ------------------------------------------------------------------------------ | ||
# Bootstrap graphs (list with one fitted graph object per bootstrap sample).
# boot <- readRDS("data_not_load/boot.RDS")
boot <- readRDS("data_not_load/boot100_mi1-graph.RDS")

# Build one adjacency matrix per bootstrap graph.
amat <- lapply(boot, function(x) {
  tmp <- wgtMatrix(getGraph(x), transpose = FALSE)
  wm2 <- tmp + t(tmp)
  # Cells whose symmetric sum exceeds 1 (edge present in both directions, i.e.
  # undirected) are set to 0.5, so that both cells together count as 1.
  tmp[which(wm2 > 1)] <- 0.5
  tmp
})

# Average the adjacency matrices over all bootstrap graphs.
g.avg <- Reduce("+", amat) / length(amat)

# Collapse both directions of every node pair into one value per pair.
# Both vectors enumerate the pairs in the same order.
t1 <- t(g.avg)[lower.tri(t(g.avg))] # upper triangle of g.avg
t2 <- g.avg[lower.tri(g.avg)]       # lower triangle of g.avg
t3 <- t1 + t2

# select_all_but_diag <- function(x){
#   apply(cbind(t(x)[lower.tri(x, diag = F)], x[lower.tri(x, diag = F)]), 1, sum)
# }
# a <- select_all_but_diag(g.avg)

# Variants in which pair frequencies below a cut-off are set to zero.
t44 <- t50 <- t75 <- t3
t44[t44 < 0.44] <- 0
t50[t50 < 0.50] <- 0
t75[t75 < 0.75] <- 0

# Uncertainty measures for the raw and the thresholded frequencies; the outer
# parentheses print each result.
(rmseu0 <- gum(t3, vertices = nrow(g.avg), threshold = 0.5))
(rmseu44 <- gum(t44, vertices = nrow(g.avg), threshold = 0.5))
(rmseu50 <- gum(t50, vertices = nrow(g.avg), threshold = 0.5))
(rmseu75 <- gum(t75, vertices = nrow(g.avg), threshold = 0.5))
|
||
# Reference calculations: gum() (defined outside this file) is evaluated on | ||
# constant edge-frequency vectors so that its $rmseu / $meu / $mee components | ||
# can be compared with the values quoted in the comments. Results are only | ||
# printed, not stored. | ||
# compare - RMSEU = 0.183 | ||
gum(rep(0.9085, times = choose(51,2)), vertices = 51, threshold = 0.5)$rmseu | ||
1-0.9085 | ||
# MEU 0.0443451 | ||
gum(rep(0.005, times = choose(51,2)), vertices = 51, threshold = 0.5)$meu | ||
# NOTE(review): the next line repeats 1-0.9085 although the call above used 0.005 - confirm | ||
1-0.9085 | ||
|
||
# compare - RMSEU = 0.246 | ||
gum(rep(0.877, times = choose(51,2)), vertices = 51, threshold = 0.5) | ||
1-0.877 | ||
# compare - MEE = 0.059 | ||
gum(rep(0.0069, times = choose(51,2)), vertices = 51, threshold = 0.5)$mee | ||
1-0.0069 | ||
|
||
|
||
# compare - 0.183 for edges that were at least selected once (here 570) | ||
# NOTE(review): choose(34, 2) is 561, not 570 - confirm that vertices = 34 is intended | ||
gum(rep(0.9085, times = 570), vertices = 34, threshold = 0.5)$rmseu | ||
gum(rep(0.877, times = 570), vertices = 34, threshold = 0.5)$rmseu | ||
|
||
|
||
|
||
|
||
|
||
# compare - 0.267481 | ||
gum(rep(0.8665, times = choose(51,2)), vertices = 51, threshold = 0.5)$rmseu | ||
gum(rep(1-0.8665, times = choose(51,2)), vertices = 51, threshold = 0.5)$rmseu | ||
# The next two calls pass a full 51 x 51 matrix and no `vertices` argument. | ||
gum(matrix(rep(0.0936, times = 51*51), nrow = 51, ncol = 51), threshold = 0.5)$rmseu | ||
gum(matrix(rep(0.35, times = 51*51), nrow = 51, ncol = 51), threshold = 0.5) | ||
|
||
# summary(t3): mean = 0.09917, rmseu = 0.1048 | ||
# The second positional argument (51) is presumably `vertices` - confirm against gum()'s signature. | ||
gum(rep(0.2457, times = choose(51,2)), 51, threshold = 0.75)$rmseu | ||
gum(rep(0.2457, times = choose(51,2)), 51, threshold = 0.25)$rmseu | ||
gum(rep(0.09917, times = choose(51,2)), 51, threshold = 0.05)$rmseu | ||
gum(rep(0.09917, times = choose(51,2)), 51, threshold = 0.05)$mee | ||
|
||
# --- save ----------------------------------------------- | ||
# Writes the four gum() results computed earlier in this script to one .RData file. | ||
save(rmseu0, rmseu44, rmseu50, rmseu75, file = "results/rmseu.RData") | ||
|
||
|
||
# Signed difference between the target MEU (0.0443451) and the MEU that gum()
# reports when all choose(51, 2) node pairs share the same edge frequency `x`.
# `x` must be a single number.
mymeu <- function(x) {
  0.0443451 - gum(rep(x, times = choose(51, 2)), vertices = 51, threshold = 0.5)$meu
}

# Search [0.01, 0.15] for the edge frequency whose MEU is closest to the
# target. The squared difference is minimised so that the optimum is where the
# difference is zero; a one-dimensional search is used because `mymeu` takes a
# single number.
optimize(function(x) mymeu(x)^2, interval = c(0.01, 0.15))
|
||
|
||
|
||
|
||
# Uncertainty check on a minimal example (original note, partly German: "10 nodes -> 45 edges"). | ||
# NOTE(review): the code below uses a single edge frequency (rep(0.5) is just 0.5) with | ||
# vertices = 2, not 10 nodes / 45 edges - confirm which was intended. | ||
medges <- rep(0.5) | ||
gum(medges, vertices = 2, threshold = 0.05)$mee |
Oops, something went wrong.