Skip to content

Commit

Permalink
Add tests for new size_by argument
Browse files Browse the repository at this point in the history
 - While the implementation and tests are done, I think we still want to revisit the NA handling, which also impacts the "start" node size.
  • Loading branch information
barrettk committed Oct 4, 2024
1 parent 55f6d44 commit ca905f7
Showing 1 changed file with 87 additions and 0 deletions.
87 changes: 87 additions & 0 deletions tests/testthat/test-model-tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ context("Model tree diagram")
skip_if_not_ci_or_metworx("test-model-tree")
skip_if_tree_missing_deps()

# These two functions ignore the 'start' node, as we are only comparing
# to the run_log

# Count how many nodes appear in the model tree for each model
count_nodes <- function(tree_list) {
if(length(tree_list) == 0) return(0)
# Iterate through each element in the list
Expand All @@ -17,6 +21,26 @@ count_nodes <- function(tree_list) {
return(total_nodes)
}

# Collect the values of one attribute across every node of a model tree.
#
# `tree_list` is a list of nodes; each node is itself a list that may carry the
# attribute named by `attr` and may carry a `children` list of further nodes.
# Nodes are visited depth-first: a node's own value is emitted before the
# values found in its children. Nodes lacking `attr` contribute nothing, and
# factor values are converted to character before being collected.
#
# Returns a single vector of the collected values (`numeric(0)` when
# `tree_list` is empty or no node carries the attribute).
get_node_attribute <- function(tree_list, attr = 'SizeOfNode') {
  if (length(tree_list) == 0) return(numeric(0))

  values_per_node <- lapply(seq_along(tree_list), function(idx) {
    node <- tree_list[[idx]]

    # This node's own value; NULL when the attribute is absent
    own_value <- node[[attr]]
    if (is.factor(own_value)) own_value <- as.character(own_value)

    # Values gathered recursively from the subtree below this node
    descendants <- node$children
    nested_values <- if (length(descendants) > 0) {
      get_node_attribute(descendants, attr)
    } else {
      NULL
    }

    # Lead with numeric(0) so an empty contribution stays a numeric vector
    c(numeric(0), own_value, nested_values)
  })

  # Concatenate the per-node results in traversal order
  do.call(c, values_per_node)
}

withr::with_options(list(bbr.bbi_exe_path = read_bbi_path()), {


Expand Down Expand Up @@ -348,6 +372,69 @@ withr::with_options(list(bbr.bbi_exe_path = read_bbi_path()), {
)
})

it("size_tree_by()", {
  clean_test_enviroment(create_tree_models)

  # Rendered size of each model node when the tree is sized by `size_by`.
  # Only the children of the root are inspected, i.e. the 'start' node is
  # not part of the returned vector.
  rendered_node_sizes <- function(.log_df, size_by) {
    pl_tree <- model_tree(.log_df, add_summary = FALSE, size_by = size_by)
    get_node_attribute(pl_tree$x$data$children, attr = "SizeOfNode")
  }

  # TRUE only when *every* value in `larger` exceeds *every* value in
  # `smaller`. A plain `larger > smaller` would recycle the shorter vector and
  # compare arbitrary pairs; `outer()` compares all pairs (and is vacuously
  # TRUE if either side is empty).
  all_pairs_greater <- function(larger, smaller) {
    all(outer(larger, smaller, ">"))
  }

  log_df <- run_log(MODEL_DIR) %>% dplyr::mutate(
    size_col = as.numeric(run)
  )

  # Checks that the size increases with each node (like size_col, i.e. run number)
  node_sizes <- rendered_node_sizes(log_df, size_by = "size_col")
  expect_true(all(diff(node_sizes) > 0))

  ### Data checks ###
  # Test logical size_by
  true_indices <- which(log_df$star)
  false_indices <- which(!log_df$star)
  node_sizes <- rendered_node_sizes(log_df, size_by = "star")

  tree_data <- make_tree_data(log_df, add_summary = FALSE)
  tree_data_star <- size_tree_by(tree_data, size_by = "star")
  # Drop the first row so the sizes line up with log_df
  # (presumably the 'start' node -- TODO confirm against make_tree_data())
  data_sizes <- tree_data_star$node_size[-1]

  # Checks that the TRUE values are larger than FALSE values
  #  - Checks the underlying data, and rendered node size
  expect_true(
    all_pairs_greater(node_sizes[true_indices], node_sizes[false_indices])
  )
  expect_true(
    all_pairs_greater(data_sizes[true_indices], data_sizes[false_indices])
  )

  # Check if all the same value (works the same if TRUE or NA)
  log_df2 <- log_df
  log_df2$star <- FALSE
  false_indices <- which(!log_df2$star)
  node_sizes <- rendered_node_sizes(log_df2, size_by = "star")

  tree_data <- make_tree_data(log_df2, add_summary = FALSE)
  tree_data_star <- size_tree_by(tree_data, size_by = "star")
  data_sizes <- tree_data_star$node_size[-1]

  # Checks that all values are the same size
  #  - Checks the underlying data, and rendered node size
  expect_equal(dplyr::n_distinct(node_sizes[false_indices]), 1)
  expect_equal(dplyr::n_distinct(data_sizes[false_indices]), 1)


  # Test numeric size_by (gradient sizing) - mimics objective function
  #  - The draw is seeded so the test data (and any failure) is reproducible
  log_df <- log_df %>% dplyr::mutate(
    size_col = withr::with_seed(
      1234,
      abs(rnorm(nrow(log_df), mean = 1500, sd = 800))
    )
  )
  size_col_vals <- log_df$size_col
  node_sizes <- rendered_node_sizes(log_df, size_by = "size_col")

  tree_data <- make_tree_data(log_df, add_summary = FALSE, size_by = "size_col")
  tree_data_size <- size_tree_by(tree_data, size_by = "size_col")
  data_sizes <- tree_data_size$node_size[-1]

  # Checks that the ordering is consistent
  #  - Checks the underlying data, and rendered node size
  expect_equal(order(size_col_vals), order(node_sizes))
  expect_equal(order(size_col_vals), order(data_sizes))
})

it("static plot", {
skip_if_tree_missing_deps(static = TRUE)
clean_test_enviroment(create_tree_models)
Expand Down

0 comments on commit ca905f7

Please sign in to comment.