diff --git a/R/cluster_analysis.R b/R/cluster_analysis.R index 4d10ec0f6..30af7cffd 100644 --- a/R/cluster_analysis.R +++ b/R/cluster_analysis.R @@ -452,16 +452,16 @@ visualisation_recipe.cluster_analysis <- function(x, show_data = "text", ...) { # Check number of columns: if more than 2, display PCs, if less, fail if (ncol(ori_data) <= 2) { - insight::format_error("Less than 2 variables in the dataset. Cannot compute enough principal components to represent clustering.") + insight::format_error("Less than 2 variables in the dataset. Cannot compute enough principal components to represent clustering.") # nolint } # Get 2 PCA Components pca <- principal_components(ori_data, n = 2) - data <- stats::predict(pca) - names(data) <- c("x", "y") - data$Cluster <- as.character(stats::na.omit(attributes(x)$clusters)) + prediction_data <- stats::predict(pca) + names(prediction_data) <- c("x", "y") + prediction_data$Cluster <- as.character(stats::na.omit(attributes(x)$clusters)) - data$label <- row.names(ori_data) + prediction_data$label <- row.names(ori_data) if (!is.null(show_data) && show_data %in% c("label", "text")) { label <- "label" } else { @@ -473,7 +473,7 @@ visualisation_recipe.cluster_analysis <- function(x, show_data = "text", ...) { data_centers$Cluster <- as.character(as.data.frame(x)$Cluster) # Outliers - data$Cluster[data$Cluster == "0"] <- NA + prediction_data$Cluster[prediction_data$Cluster == "0"] <- NA data_centers <- data_centers[data_centers$Cluster != "0", ] layers <- list() @@ -482,7 +482,7 @@ visualisation_recipe.cluster_analysis <- function(x, show_data = "text", ...) { layers[["l1"]] <- list( geom = show_data, - data = data, + data = prediction_data, aes = list(x = "x", y = "y", label = label, color = "Cluster") ) @@ -501,7 +501,7 @@ visualisation_recipe.cluster_analysis <- function(x, show_data = "text", ...) { # Out class(layers) <- c("visualisation_recipe", "see_visualisation_recipe", class(layers)) - attr(layers, "data") <- data + attr(layers, "data") <- prediction_data layers } diff --git a/R/utils_pca_efa.R b/R/utils_pca_efa.R index b8d0211fa..d10759ae7 100644 --- a/R/utils_pca_efa.R +++ b/R/utils_pca_efa.R @@ -150,34 +150,41 @@ predict.parameters_efa <- function(object, ...) { attri <- attributes(object) - if (inherits(attri$model, c("psych", "principal", "psych", "fa"))) { - if (is.null(newdata)) { - if ("scores" %in% names(attri)) { - out <- as.data.frame(attri$scores) - if (isTRUE(keep_na)) { - # Because pre-made scores don't preserve NA - out <- .merge_na(object, out) - } - } else { + # handle if no data is provided + if (is.null(newdata)) { + # check if we have scores attribute - these will be returned directly + if ("scores" %in% names(attri)) { + out <- as.data.frame(attri$scores) + if (isTRUE(keep_na)) { + out <- .merge_na(object, out, verbose) + } + } else { + # if we have data, use that for prediction + if ("dataset" %in% names(attri)) { d <- attri$data_set d <- d[vapply(d, is.numeric, logical(1))] - out <- as.data.frame(stats::predict(attri$model, data = d)) + out <- as.data.frame(stats::predict(attri$model, newdata = d)) + } else { + insight::format_error( + "Could not retrieve data nor model. Please report an issue on {.url https://github.com/easystats/parameters/issues}." # nolint + ) } - } else { - # psych:::predict.principal(object, data) - out <- stats::predict(attri$model, data = newdata) } - } else if (inherits(attri$model, "spca")) { - # https://github.com/erichson/spca/issues/7 - newdata <- newdata[names(attri$model$center)] - if (attri$standardize) { - newdata <- sweep(newdata, MARGIN = 2, STATS = attri$model$center, FUN = "-", check.margin = TRUE) - newdata <- sweep(newdata, MARGIN = 2, STATS = attri$model$scale, FUN = "/", check.margin = TRUE) - } - out <- as.matrix(newdata) %*% as.matrix(attri$model$loadings) - out <- stats::setNames(as.data.frame(out), paste0("Component", seq_len(ncol(out)))) } else { - out <- as.data.frame(stats::predict(attri$model, newdata = attri$dataset, ...)) + if (inherits(attri$model, "spca")) { + # https://github.com/erichson/spca/issues/7 + newdata <- newdata[names(attri$model$center)] + if (attri$standardize) { + newdata <- sweep(newdata, MARGIN = 2, STATS = attri$model$center, FUN = "-", check.margin = TRUE) + newdata <- sweep(newdata, MARGIN = 2, STATS = attri$model$scale, FUN = "/", check.margin = TRUE) + } + out <- as.matrix(newdata) %*% as.matrix(attri$model$loadings) + out <- stats::setNames(as.data.frame(out), paste0("Component", seq_len(ncol(out)))) + } else if (inherits(attri$model, c("psych", "fa", "principal"))) { + out <- as.data.frame(stats::predict(attri$model, newdata = newdata, ...)) + } else { + out <- as.data.frame(stats::predict(attri$model, newdata = newdata, ...)) + } } if (!is.null(names)) { diff --git a/tests/testthat/_snaps/visualisation_recipe.md b/tests/testthat/_snaps/visualisation_recipe.md new file mode 100644 index 000000000..199f0fc48 --- /dev/null +++ b/tests/testthat/_snaps/visualisation_recipe.md @@ -0,0 +1,36 @@ +# vis_recipe.cluster_analysis + + Code + print(out) + Output + Layer 1 + -------- + Geom type: text + data = [150 x 4] + aes_string( + x = 'x' + y = 'y' + label = 'label' + color = 'Cluster' + ) + + Layer 2 + -------- + Geom type: point + data = [4 x 3] + aes_string( + x = 'x' + y = 'y' + color = 'Cluster' + ) + shape = '+' + size = 10 + + Layer 3 + -------- + Geom type: labs + x = 'PCA - 1' + y = 'PCA - 2' + title = 'Clustering Solution' + + diff --git a/tests/testthat/test-visualisation_recipe.R b/tests/testthat/test-visualisation_recipe.R new file mode 100644 index 000000000..8661255df --- /dev/null +++ b/tests/testthat/test-visualisation_recipe.R @@ -0,0 +1,8 @@ +test_that("vis_recipe.cluster_analysis", { + data(iris) + result <- cluster_analysis(iris[, 1:4], n = 4) + out <- visualisation_recipe(result) + expect_named(out, c("l1", "l2", "l3")) + expect_s3_class(out, "visualisation_recipe") + expect_snapshot(print(out)) +})