Skip to content

Commit

Permalink
Fix predict for cluster centers in vis-recipe
Browse files (browse the repository at this point in the history)
  • Loading branch information
strengejacke committed Dec 6, 2023
1 parent c28bdd0 commit 189cf82
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 31 deletions.
16 changes: 8 additions & 8 deletions R/cluster_analysis.R
Original file line number Diff line number Diff line change
Expand Up @@ -452,16 +452,16 @@ visualisation_recipe.cluster_analysis <- function(x, show_data = "text", ...) {

# Check number of columns: if more than 2, display PCs, if less, fail
if (ncol(ori_data) <= 2) {
insight::format_error("Less than 2 variables in the dataset. Cannot compute enough principal components to represent clustering.")
insight::format_error("Less than 2 variables in the dataset. Cannot compute enough principal components to represent clustering.") # nolint
}

# Get 2 PCA Components
pca <- principal_components(ori_data, n = 2)
data <- stats::predict(pca)
names(data) <- c("x", "y")
data$Cluster <- as.character(stats::na.omit(attributes(x)$clusters))
prediction_data <- stats::predict(pca)
names(prediction_data) <- c("x", "y")
prediction_data$Cluster <- as.character(stats::na.omit(attributes(x)$clusters))

data$label <- row.names(ori_data)
prediction_data$label <- row.names(ori_data)
if (!is.null(show_data) && show_data %in% c("label", "text")) {
label <- "label"
} else {
Expand All @@ -473,7 +473,7 @@ visualisation_recipe.cluster_analysis <- function(x, show_data = "text", ...) {
data_centers$Cluster <- as.character(as.data.frame(x)$Cluster)

# Outliers
data$Cluster[data$Cluster == "0"] <- NA
prediction_data$Cluster[prediction_data$Cluster == "0"] <- NA
data_centers <- data_centers[data_centers$Cluster != "0", ]

layers <- list()
Expand All @@ -482,7 +482,7 @@ visualisation_recipe.cluster_analysis <- function(x, show_data = "text", ...) {

layers[["l1"]] <- list(
geom = show_data,
data = data,
data = prediction_data,
aes = list(x = "x", y = "y", label = label, color = "Cluster")
)

Expand All @@ -501,7 +501,7 @@ visualisation_recipe.cluster_analysis <- function(x, show_data = "text", ...) {

# Out
class(layers) <- c("visualisation_recipe", "see_visualisation_recipe", class(layers))
attr(layers, "data") <- data
attr(layers, "data") <- prediction_data
layers
}

Expand Down
53 changes: 30 additions & 23 deletions R/utils_pca_efa.R
Original file line number Diff line number Diff line change
Expand Up @@ -150,34 +150,41 @@ predict.parameters_efa <- function(object,
...) {
attri <- attributes(object)

if (inherits(attri$model, c("psych", "principal", "psych", "fa"))) {
if (is.null(newdata)) {
if ("scores" %in% names(attri)) {
out <- as.data.frame(attri$scores)
if (isTRUE(keep_na)) {
# Because pre-made scores don't preserve NA
out <- .merge_na(object, out)
}
} else {
# handle if no data is provided
if (is.null(newdata)) {
# check if we have scores attribute - these will be returned directly
if ("scores" %in% names(attri)) {
out <- as.data.frame(attri$scores)
if (isTRUE(keep_na)) {
out <- .merge_na(object, out, verbose)
}
} else {
# if we have data, use that for prediction
if ("dataset" %in% names(attri)) {
d <- attri$data_set
d <- d[vapply(d, is.numeric, logical(1))]
out <- as.data.frame(stats::predict(attri$model, data = d))
out <- as.data.frame(stats::predict(attri$model, newdata = d))
} else {
insight::format_error(
"Could not retrieve data nor model. Please report an issue on {.url https://github.com/easystats/parameters/issues}." # nolint
)
}
} else {
# psych:::predict.principal(object, data)
out <- stats::predict(attri$model, data = newdata)
}
} else if (inherits(attri$model, "spca")) {
# https://github.com/erichson/spca/issues/7
newdata <- newdata[names(attri$model$center)]
if (attri$standardize) {
newdata <- sweep(newdata, MARGIN = 2, STATS = attri$model$center, FUN = "-", check.margin = TRUE)
newdata <- sweep(newdata, MARGIN = 2, STATS = attri$model$scale, FUN = "/", check.margin = TRUE)
}
out <- as.matrix(newdata) %*% as.matrix(attri$model$loadings)
out <- stats::setNames(as.data.frame(out), paste0("Component", seq_len(ncol(out))))
} else {
out <- as.data.frame(stats::predict(attri$model, newdata = attri$dataset, ...))
if (inherits(attri$model, "spca")) {
# https://github.com/erichson/spca/issues/7
newdata <- newdata[names(attri$model$center)]
if (attri$standardize) {
newdata <- sweep(newdata, MARGIN = 2, STATS = attri$model$center, FUN = "-", check.margin = TRUE)
newdata <- sweep(newdata, MARGIN = 2, STATS = attri$model$scale, FUN = "/", check.margin = TRUE)
}
out <- as.matrix(newdata) %*% as.matrix(attri$model$loadings)
out <- stats::setNames(as.data.frame(out), paste0("Component", seq_len(ncol(out))))
} else if (inherits(attri$model, c("psych", "fa", "principal"))) {
out <- as.data.frame(stats::predict(attri$model, newdata = newdata, ...))
} else {
out <- as.data.frame(stats::predict(attri$model, newdata = newdata, ...))
}
}

if (!is.null(names)) {
Expand Down
36 changes: 36 additions & 0 deletions tests/testthat/_snaps/visualisation_recipe.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# vis_recipe.cluster_analysis

Code
print(out)
Output
Layer 1
--------
Geom type: text
data = [150 x 4]
aes_string(
x = 'x'
y = 'y'
label = 'label'
color = 'Cluster'
)
Layer 2
--------
Geom type: point
data = [4 x 3]
aes_string(
x = 'x'
y = 'y'
color = 'Cluster'
)
shape = '+'
size = 10
Layer 3
--------
Geom type: labs
x = 'PCA - 1'
y = 'PCA - 2'
title = 'Clustering Solution'

8 changes: 8 additions & 0 deletions tests/testthat/test-visualisation_recipe.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Snapshot-backed check of the visualisation recipe built from a cluster
# analysis of the iris measurements.
test_that("vis_recipe.cluster_analysis", {
  data(iris)

  # Cluster the first four iris columns into four groups and build the recipe.
  clustering <- cluster_analysis(iris[, 1:4], n = 4)
  out <- visualisation_recipe(clustering)

  # The recipe carries its S3 class and exposes exactly three named layers.
  expect_s3_class(out, "visualisation_recipe")
  expect_named(out, c("l1", "l2", "l3"))

  # The snapshot file records the literal `print(out)` call, so the variable
  # name `out` and the test title must stay unchanged.
  expect_snapshot(print(out))
})

0 comments on commit 189cf82

Please sign in to comment.