Skip to content

Commit

Permalink
Removing model option as now using SVMs
Browse files Browse the repository at this point in the history
  • Loading branch information
ZekeMarshall committed Nov 4, 2024
1 parent 512b253 commit bbe8b03
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 155 deletions.
51 changes: 42 additions & 9 deletions R/graph_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ create_feature_importance_plot <- function(fe_data, bar_width = 5) {

expl_df <- fe_data |>
dplyr::select(-species) |>
dplyr::group_by(model, variable) |>
dplyr::group_by(#model,
variable) |>
dplyr::mutate("min" = min(dropout_loss, na.rm = TRUE),
"q1" = quantile(dropout_loss, probs = 0.25, na.rm = TRUE),
"median" = median(dropout_loss, na.rm = TRUE),
Expand All @@ -15,20 +16,22 @@ create_feature_importance_plot <- function(fe_data, bar_width = 5) {
# Add an additional column that serve as a baseline
bestFits <- expl_df |>
dplyr::filter(variable == "_full_model_") |>
dplyr::select(model, permutation, dropout_loss)
dplyr::select(#model,
permutation, dropout_loss)

ext_expl_df <- expl_df |>
dplyr::full_join(bestFits, by = c("model", "permutation"))
dplyr::full_join(bestFits, by = c(#"model",
"permutation"))

# Remove rows that starts with _ i.e. _full_model_ and _baseline_
ext_expl_df <- ext_expl_df |>
dplyr::filter(!(variable %in% c("_full_model_", "_baseline_")))

# Order rows
ext_expl_df <- ext_expl_df |>
dplyr::group_by(model) |>
dplyr::arrange(dplyr::desc(dropout_loss.x), .by_group = TRUE) |>
dplyr::ungroup()
# dplyr::group_by(model) |>
dplyr::arrange(dplyr::desc(dropout_loss.x))# , .by_group = TRUE) |>
# dplyr::ungroup()

# facets have fixed space, can be resolved with ggforce https://github.com/tidyverse/ggplot2/issues/2933
pl <- ggplot2::ggplot(data = ext_expl_df) +
Expand All @@ -42,7 +45,7 @@ create_feature_importance_plot <- function(fe_data, bar_width = 5) {
mapping = ggplot2::aes(x = variable, ymin = min, lower = q1, middle = median, upper = q3, ymax = max),
stat = "identity", fill = "#371ea3", color = "#371ea3", width = 0.25) +
ggplot2::coord_flip() +
ggplot2::facet_wrap(~model, ncol = 3, scales = "free_y") +
# ggplot2::facet_wrap(~model, ncol = 3, scales = "free_y") +
ggplot2::theme_minimal() +
ggplot2::theme(legend.position = "none") +
ggplot2::ylab(label = NULL) +
Expand All @@ -56,7 +59,7 @@ create_feature_importance_plot <- function(fe_data, bar_width = 5) {
create_ale_plot <- function(ale_data){

ale_plot <- ggplot2::ggplot(data = ale_data) +
ggplot2::geom_line(mapping = ggplot2::aes(x = x, y = y, color = model)) +
ggplot2::geom_line(mapping = ggplot2::aes(x = x, y = y)) + #, color = model)) +
ggplot2::facet_wrap(~variable, scales = "free_x", ncol = 2) +
ggplot2::theme_minimal() +
ggplot2::theme(legend.position = "right") +
Expand Down Expand Up @@ -135,4 +138,34 @@ plot_break_down <- function(x,

return(pl)

}
}

# create_pairs_plot <- function(pa_plot_metadata, variables, target_name, focal_species){
#
# # data_species <- purrr::map_depth(pa_plot_metadata, 1, purrr::pluck(focal_species)) |>
# # purrr::discard(is.null) |>
# # purrr::pluck(1)
#
# pairs_plot <- GGally::ggpairs(data_species, mapping = ggplot2::aes(color = .data[[target_name]]),
# upper = list(continuous = GGally::wrap("cor",
# size = 2.5),
# combo = GGally::wrap("box_no_facet",
# color = "#000000",
# linewidth = 0.5),
# discrete = "count",
# na = "na"),
# lower = list(continuous = GGally::wrap("points", size = 0.25, alpha = 0.8),
# combo = GGally::wrap("facethist", color = "#000000", linewidth = 0.5),
# discrete = GGally::wrap("facetbar", color = "#000000", linewidth = 0.5),
# na = "na"),
# diag = list(continuous = GGally::wrap("densityDiag", color = "#000000", linewidth = 0.5),
# discrete = GGally::wrap("barDiag", color = "#000000", linewidth = 0.5), na = "naDiag")) +
# ggplot2::scale_fill_viridis_d(end = 0.8, alpha = 0.8, option = "plasma") +
# ggplot2::scale_color_viridis_d(end = 0.8, alpha = 0.8, option = "plasma") +
# ggplot2::theme_minimal() +
# ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 90),
# text = ggplot2::element_text(size = 10))
#
# return(pairs_plot)
#
# }
80 changes: 42 additions & 38 deletions inst/app/app.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,48 +16,52 @@ options(shiny.autoload.r = FALSE)
# DESCRIPTION file

# Load required packages --------------------------------------------------
library(bookdown)
library(bsicons)
library(bslib)
library(dplyr)
library(ggplot2)
library(htmltools)
library(htmlwidgets)
library(kableExtra)
library(knitr)
library(magrittr)
library(markdown)
library(plotly)
library(purrr)
library(reactable)
library(readr)
library(rhandsontable)
library(rmarkdown)
library(shiny)
library(shinybusy)
library(shinyjs)
library(shinyWidgets)
library(stringr)
library(tibble)
library(tidyr)
library(vegan)
library(writexl)
suppressPackageStartupMessages({
library(bookdown)
library(bsicons)
library(bslib)
library(dplyr)
library(ggplot2)
library(htmltools)
library(htmlwidgets)
library(kableExtra)
library(knitr)
library(magrittr)
library(markdown)
library(plotly)
library(purrr)
library(reactable)
library(readr)
library(rhandsontable)
library(rmarkdown)
library(shiny)
library(shinybusy)
library(shinyjs)
library(shinyWidgets)
library(stringr)
library(tibble)
library(tidyr)
library(vegan)
library(writexl)
})

# TEMP FOR DEVELOPMENT ----------------------------------------------------
# Reading the niche model results from targets needs to be replaced with a
# package
library(mlr3)
library(mlr3pipelines)
library(mlr3learners)
library(mlr3extralearners)
library(targets)
library(DBI)
library(dbplyr)
library(duckdb)
library(qs)
library(stats)
library(DALEX)
library(DALEXtra)
suppressPackageStartupMessages({
library(mlr3)
library(mlr3pipelines)
library(mlr3learners)
library(mlr3extralearners)
library(targets)
library(DBI)
library(dbplyr)
library(duckdb)
library(qs)
library(stats)
library(DALEX)
library(DALEXtra)
})

source("./../../R/temp_functions.R", local = TRUE)
source("./../../R/graph_functions.R", local = TRUE)
Expand Down
50 changes: 21 additions & 29 deletions inst/app/modules/niche_models/nmModelDisplay_server.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) {

# Retrieve sidebar options ------------------------------------------------
focalSpecies <- reactiveVal()
selectedModelDisplay <- reactiveVal()
# selectedModelDisplay <- reactiveVal()
selectedVariablesDisplay <- reactiveVal()
selectedMarginalEffectsPlot <- reactiveVal()

observe({

focalSpecies(sidebar_nm_options()$focalSpecies)
selectedModelDisplay(sidebar_nm_options()$selectedModelDisplay)
# selectedModelDisplay(sidebar_nm_options()$selectedModelDisplay)
selectedVariablesDisplay(sidebar_nm_options()$selectedVariablesDisplay)
selectedMarginalEffectsPlot(sidebar_nm_options()$selectedMarginalEffectsPlot)

Expand All @@ -35,7 +35,7 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) {
)

focalSpecies <- focalSpecies()
selectedModelDisplay <- selectedModelDisplay()
# selectedModelDisplay <- selectedModelDisplay()
selectedMarginalEffectsPlot <- selectedMarginalEffectsPlot()
selectedVariablesDisplay <- c("_full_model_", "_baseline_", selectedVariablesDisplay())

Expand All @@ -47,31 +47,23 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) {
# Retrieve measures
measures <- dplyr::tbl(src = con, "AllMeasures") |>
dplyr::filter(species == focalSpecies) |>
dplyr::filter(model %in% selectedModelDisplay) |>
# dplyr::filter(model %in% selectedModelDisplay) |>
dplyr::collect()

# Retrieve marginal effects
if(selectedMarginalEffectsPlot == "ALE"){

meData <- dplyr::tbl(src = con, "AllALEData") |>
meData <- dplyr::tbl(src = con, "AllIMLALEData") |>
dplyr::filter(species == focalSpecies) |>
dplyr::filter(model %in% selectedModelDisplay) |>
# dplyr::filter(model %in% selectedModelDisplay) |>
dplyr::filter(variable %in% selectedVariablesDisplay) |>
dplyr::collect()

} else if(selectedMarginalEffectsPlot == "PDP"){

meData <- dplyr::tbl(src = con, "AllPDPData") |>
meData <- dplyr::tbl(src = con, "AllIMLPDPData") |>
dplyr::filter(species == focalSpecies) |>
dplyr::filter(model %in% selectedModelDisplay) |>
dplyr::filter(variable %in% selectedVariablesDisplay) |>
dplyr::collect()

} else if(selectedMarginalEffectsPlot == "CP"){

meData <- dplyr::tbl(src = con, "AllCDData") |>
dplyr::filter(species == focalSpecies) |>
dplyr::filter(model %in% selectedModelDisplay) |>
# dplyr::filter(model %in% selectedModelDisplay) |>
dplyr::filter(variable %in% selectedVariablesDisplay) |>
dplyr::collect()

Expand All @@ -80,7 +72,7 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) {
# Retrieve feature importance
featureImportance <- dplyr::tbl(src = con, "AllFeatureImportance") |>
dplyr::filter(species == focalSpecies) |>
dplyr::filter(model %in% selectedModelDisplay) |>
# dplyr::filter(model %in% selectedModelDisplay) |>
dplyr::filter(variable %in% selectedVariablesDisplay) |>
dplyr::collect()

Expand All @@ -99,14 +91,14 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) {

}) |>
bindEvent(focalSpecies(),
selectedModelDisplay(),
# selectedModelDisplay(),
selectedVariablesDisplay(),
selectedMarginalEffectsPlot(),
ignoreInit = TRUE)


# Model evaluation metrics ------------------------------------------------
modelEvalMetricsTable_init <- data.frame("Model" = character(),
modelEvalMetricsTable_init <- data.frame(#"Model" = character(),
"Binary Brier" = double(),
"PRAUC" = double(),
"Precision" = double(),
Expand Down Expand Up @@ -154,16 +146,16 @@ nmModelDisplay <- function(input, output, session, sidebar_nm_options) {
measures <- measures_rval()

modelEvalMetricsTable_data <- measures |>
dplyr::mutate(
"model_type" = dplyr::case_when(
model %in% c("WE") ~ "WE",
TRUE ~ "Individual"
),
.before = model
) |>
dplyr::arrange(model_type, dplyr::desc(bacc)) |>
dplyr::select(-model_type) |>
dplyr::select("Model" = "model",
# dplyr::mutate(
# "model_type" = dplyr::case_when(
# model %in% c("WE") ~ "WE",
# TRUE ~ "Individual"
# ),
# .before = model
# ) |>
# dplyr::arrange(model_type, dplyr::desc(bacc)) |>
# dplyr::select(-model_type) |>
dplyr::select(#"Model" = "model",
"Binary Brier" = "bbrier",
# "Log Loss" = "logloss",
# "AUC" = "auc",
Expand Down
18 changes: 9 additions & 9 deletions inst/app/modules/niche_models/nmModelRun_server.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ nmModelRun <- function(input, output, session, sidebar_nm_options, nmDataInput)
runNMAnalysis <- reactiveVal()
focalSpecies <- reactiveVal()
identifyPredDrivers <- reactiveVal()
selectedModelPredict <- reactiveVal()
# selectedModelPredict <- reactiveVal()

observe({
runNMAnalysis(sidebar_nm_options()$runNMAnalysis)
focalSpecies(sidebar_nm_options()$focalSpecies)
identifyPredDrivers(sidebar_nm_options()$identifyPredDrivers)
selectedModelPredict(sidebar_nm_options()$selectedModelPredict)
# selectedModelPredict(sidebar_nm_options()$selectedModelPredict)

}) |>
bindEvent(sidebar_nm_options(),
Expand All @@ -21,9 +21,9 @@ nmModelRun <- function(input, output, session, sidebar_nm_options, nmDataInput)

# Retrieve Predictor Data -------------------------------------------------
predictorData_rval <- reactiveVal(predictors <- tibble::tribble(
~id, ~`F`, ~L, ~N, ~R, ~S, ~DG, ~DS, ~H,
"nvc_1000897", 4, 7.2, 7, 6.4, 0, 0.278, 0.177, 0.200,
"nvc_1000898", 3.86, 7.57, 2.71, 6.57, 0.571, 0.271, 0.219, 0.226
~id, ~`F`, ~N, ~R, ~S, ~DG, ~DS, ~H, ~MAP, ~Tmax07, ~Tmin01,
"nvc_1000897", 6, 4, 7, 0, 0.2, 0.1, 0.200, 1000, 25, 0,
"nvc_1000898", 3.86, 2.71, 6.57, 0.571, 0.271, 0.219, 0.226, 500, 20, 2
))

# observe({
Expand All @@ -40,10 +40,10 @@ nmModelRun <- function(input, output, session, sidebar_nm_options, nmDataInput)
observe({

focalSpecies <- focalSpecies()
selectedModelPredict <- selectedModelPredict()
# selectedModelPredict <- selectedModelPredict()

models <- targets::tar_read(name = "GAMModels", store = tar_store)
explainers <- targets::tar_read(name = "GAMDALEXExplainer", store = tar_store)
models <- targets::tar_read(name = "SVMModels", store = tar_store)
explainers <- targets::tar_read(name = "SVMDALEXExplainer", store = tar_store)

model <- retrieve_nested_element(nested_list = models,
focal_species = focalSpecies)
Expand All @@ -57,7 +57,7 @@ nmModelRun <- function(input, output, session, sidebar_nm_options, nmDataInput)

}) |>
bindEvent(focalSpecies(),
selectedModelPredict(),
# selectedModelPredict(),
ignoreInit = TRUE)


Expand Down
10 changes: 5 additions & 5 deletions inst/app/modules/niche_models/nmSidebar_server.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ nmSidebar <- function(input, output, session) {
nmSidebar_options_list <- list(
"runNMAnalysis" = input$runNMAnalysis,
"focalSpecies" = input$focalSpecies,
"selectedModelDisplay" = input$selectedModelDisplay,
# "selectedModelDisplay" = input$selectedModelDisplay,
"selectedVariablesDisplay" = input$selectedVariablesDisplay,
"selectedMarginalEffectsPlot" = input$selectedMarginalEffectsPlot,
"identifyPredDrivers" = input$identifyPredDrivers,
"selectedModelPredict" = input$selectedModelPredict
"identifyPredDrivers" = input$identifyPredDrivers#,
# "selectedModelPredict" = input$selectedModelPredict

)

Expand All @@ -23,11 +23,11 @@ nmSidebar <- function(input, output, session) {
}) |>
bindEvent(input$runNMAnalysis,
input$focalSpecies,
input$selectedModelDisplay,
# input$selectedModelDisplay,
input$selectedVariablesDisplay,
input$selectedMarginalEffectsPlot,
input$identifyPredDrivers,
input$selectedModelPredict,
# input$selectedModelPredict,
ignoreInit = TRUE)


Expand Down
Loading

0 comments on commit bbe8b03

Please sign in to comment.