Skip to content

Commit

Permalink
unnormalize() with grouped data (#415)
Browse files Browse the repository at this point in the history
* working example

* fix small bug

* some comments

* fix order of attributes

* refactor unnormalize for groups

* test error message

* fix test

* fix args of unnormalize

* clean some lints

* start same for `unstandardize()` [skip ci]

* typo

* styler

* fix unstandardize() for groups

* remove former test

* ensure that both functions return a grouped dataframe

* lintr, styler

* lintr, styler, minor doc

* add some extra checks

* fix

* unnormalize: from warning to error

* bump news

---------

Co-authored-by: Daniel <[email protected]>
  • Loading branch information
etiennebacher and strengejacke authored Sep 12, 2023
1 parent ae7df24 commit ad96b50
Show file tree
Hide file tree
Showing 12 changed files with 391 additions and 38 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ S3method(to_numeric,logical)
S3method(to_numeric,numeric)
S3method(unnormalize,data.frame)
S3method(unnormalize,default)
S3method(unnormalize,grouped_df)
S3method(unnormalize,numeric)
S3method(unstandardize,array)
S3method(unstandardize,character)
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ CHANGES

* `datawizard` moves from the GPL-3 license to the MIT license.

* `unnormalize()` and `unstandardize()` now work with grouped data (#415).

* `unnormalize()` now errors instead of emitting a warning if it doesn't have the
necessary info (#415).

BUG FIXES

* Fixed issue in `labels_to_levels()` when values of labels were not in sorted
Expand Down
24 changes: 21 additions & 3 deletions R/normalize.R
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,15 @@ normalize.grouped_df <- function(x,
}

x <- as.data.frame(x)
for (rows in grps) {
x[rows, ] <- normalize(
x[rows, , drop = FALSE],

# create column(s) to store dw_transformer attributes
for (i in select) {
info$groups[[paste0("attr_", i)]] <- rep(NA, length(grps))
}

for (rows in seq_along(grps)) {
tmp <- normalize(
x[grps[[rows]], , drop = FALSE],
select = select,
exclude = exclude,
include_bounds = include_bounds,
Expand All @@ -225,9 +231,21 @@ normalize.grouped_df <- function(x,
add_transform_class = FALSE,
...
)

# store dw_transformer_attributes
for (i in select) {
info$groups[rows, paste0("attr_", i)][[1]] <- list(unlist(attributes(tmp[[i]])))
}

x[grps[[rows]], ] <- tmp
}

# last column of "groups" attributes must be called ".rows"
info$groups <- data_relocate(info$groups, ".rows", after = -1)

# set back class, so data frame still works with dplyr
attributes(x) <- utils::modifyList(info, attributes(x))
class(x) <- c("grouped_df", class(x))
x
}

Expand Down
37 changes: 28 additions & 9 deletions R/standardize.R
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ standardize.matrix <- function(x, ...) {
x_out <- do.call(cbind, xz)
dimnames(x_out) <- dimnames(x)

attr(x_out, "center") <- sapply(xz, attr, "center")
attr(x_out, "scale") <- sapply(xz, attr, "scale")
attr(x_out, "robust") <- sapply(xz, attr, "robust")[1]
attr(x_out, "center") <- vapply(xz, attr, "center", FUN.VALUE = numeric(1L))
attr(x_out, "scale") <- vapply(xz, attr, "scale", FUN.VALUE = numeric(1L))
attr(x_out, "robust") <- vapply(xz, attr, "robust", FUN.VALUE = logical(1L))[1]
class(x_out) <- c("dw_transformer", class(x_out))

x_out
Expand Down Expand Up @@ -300,9 +300,12 @@ standardize.data.frame <- function(x,
)
}


attr(x, "center") <- sapply(x[args$select], function(z) attributes(z)$center)
attr(x, "scale") <- sapply(x[args$select], function(z) attributes(z)$scale)
attr(x, "center") <- unlist(lapply(x[args$select], function(z) {
attributes(z)$center
}))
attr(x, "scale") <- unlist(lapply(x[args$select], function(z) {
attributes(z)$scale
}))
attr(x, "robust") <- robust
x
}
Expand Down Expand Up @@ -341,9 +344,14 @@ standardize.grouped_df <- function(x,
reference, weights, keep_factors = force
)

for (rows in args$grps) {
args$x[rows, ] <- standardize(
args$x[rows, , drop = FALSE],
# create column(s) to store dw_transformer attributes
for (i in select) {
args$info$groups[[paste0("attr_", i)]] <- rep(NA, length(args$grps))
}

for (rows in seq_along(args$grps)) {
tmp <- standardize(
args$x[args$grps[[rows]], , drop = FALSE],
select = args$select,
exclude = NULL,
robust = robust,
Expand All @@ -358,7 +366,18 @@ standardize.grouped_df <- function(x,
add_transform_class = FALSE,
...
)

# store dw_transformer_attributes
for (i in select) {
args$info$groups[rows, paste0("attr_", i)][[1]] <- list(unlist(attributes(tmp[[i]])))
}

args$x[args$grps[[rows]], ] <- tmp
}

# last column of "groups" attributes must be called ".rows"
args$info$groups <- data_relocate(args$info$groups, ".rows", after = -1)

# set back class, so data frame still works with dplyr
attributes(args$x) <- args$info
args$x
Expand Down
105 changes: 99 additions & 6 deletions R/unnormalize.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,32 @@ unnormalize.default <- function(x, ...) {
#' @export
unnormalize.numeric <- function(x, verbose = TRUE, ...) {
## TODO implement algorithm include_bounds = FALSE
include_bounds <- attr(x, "include_bounds")
min_value <- attr(x, "min_value")
range_difference <- attr(x, "range_difference")
to_range <- attr(x, "to_range")

# if function called from the "grouped_df" method, we use the dw_transformer
# attributes that were recovered in the "grouped_df" method

dots <- match.call(expand.dots = FALSE)[["..."]]
grp_attr_dw <- eval(dots$grp_attr_dw, envir = parent.frame(1L))

if (!is.null(grp_attr_dw)) {
names(grp_attr_dw) <- gsub(".*\\.", "", names(grp_attr_dw))
include_bounds <- grp_attr_dw["include_bounds"]
min_value <- grp_attr_dw["min_value"]
range_difference <- grp_attr_dw["range_difference"]
to_range <- grp_attr_dw["to_range"]
if (is.na(to_range)) {
to_range <- NULL
}
} else {
include_bounds <- attr(x, "include_bounds")
min_value <- attr(x, "min_value")
range_difference <- attr(x, "range_difference")
to_range <- attr(x, "to_range")
}

if (is.null(min_value) || is.null(range_difference)) {
if (verbose) {
insight::format_warning("Can't unnormalize variable. Information about range and/or minimum value is missing.")
insight::format_error("Can't unnormalize variable. Information about range and/or minimum value is missing.")
}
return(x)
}
Expand Down Expand Up @@ -54,7 +72,82 @@ unnormalize.data.frame <- function(x,
regex = regex,
verbose = verbose
)
x[select] <- lapply(x[select], unnormalize, verbose = verbose)

# if function called from the "grouped_df" method, we use the dw_transformer
# attributes that were recovered in the "grouped_df" method

dots <- match.call(expand.dots = FALSE)[["..."]]

if (!is.null(dots$grp_attr_dw)) {
grp_attr_dw <- eval(dots$grp_attr_dw, envir = parent.frame(1L))
} else {
grp_attr_dw <- NULL
}

for (i in select) {
var_attr <- grep(paste0("^attr\\_", i, "\\."), names(grp_attr_dw))
attrs <- grp_attr_dw[var_attr]
x[[i]] <- unnormalize(x[[i]], verbose = verbose, grp_attr_dw = attrs)
}

x
}

#' @rdname normalize
#' @export
unnormalize.grouped_df <- function(x,
select = NULL,
exclude = NULL,
ignore_case = FALSE,
regex = FALSE,
verbose = TRUE,
...) {
# evaluate select/exclude, may be select-helpers
select <- .select_nse(select,
x,
exclude,
ignore_case,
regex = regex,
remove_group_var = TRUE,
verbose = verbose
)

info <- attributes(x)
# works only for dplyr >= 0.8.0
grps <- attr(x, "groups", exact = TRUE)[[".rows"]]

x <- as.data.frame(x)

for (i in select) {
if (is.null(info$groups[[paste0("attr_", i)]])) {
insight::format_error(
paste(
"Couldn't retrieve the necessary information to unnormalize",
text_concatenate(i, enclose = "`")
)
)
}
}
for (rows in seq_along(grps)) {
# get the dw_transformer attributes for this group
raw_attrs <- unlist(info$groups[rows, startsWith(names(info$groups), "attr")])
if (length(select) == 1L) {
names(raw_attrs) <- paste0("attr_", select, ".", names(raw_attrs))
}

tmp <- unnormalize(
x[grps[[rows]], , drop = FALSE],
select = select,
exclude = exclude,
ignore_case = ignore_case,
regex = regex,
verbose = verbose,
grp_attr_dw = raw_attrs
)
x[grps[[rows]], ] <- tmp
}
# set back class, so data frame still works with dplyr
attributes(x) <- utils::modifyList(info, attributes(x))
class(x) <- c("grouped_df", class(x))
x
}
79 changes: 77 additions & 2 deletions R/unstandardize.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,23 @@ unstandardize.data.frame <- function(x,
verbose = verbose
)

if (!is.null(reference)) {
dots <- match.call(expand.dots = FALSE)[["..."]]

if (!is.null(dots$grp_attr_dw)) {
grp_attr_dw <- eval(dots$grp_attr_dw, envir = parent.frame(1L))
} else {
grp_attr_dw <- NULL
}

if (!is.null(grp_attr_dw)) {
center <- vapply(cols, function(x) {
grp_attr_dw[grep(paste0("^attr\\_", x, "\\.center"), names(grp_attr_dw))]
}, FUN.VALUE = numeric(1L))
scale <- vapply(cols, function(x) {
grp_attr_dw[grep(paste0("^attr\\_", x, "\\.scale"), names(grp_attr_dw))]
}, FUN.VALUE = numeric(1L))
i <- vapply(x[, cols, drop = FALSE], is.numeric, FUN.VALUE = logical(1L))
} else if (!is.null(reference)) {
i <- vapply(x[, cols, drop = FALSE], is.numeric, FUN.VALUE = logical(1L))
i <- i[i]
reference <- reference[names(i)]
Expand Down Expand Up @@ -143,8 +159,67 @@ unstandardize.grouped_df <- function(x,
reference = NULL,
robust = FALSE,
two_sd = FALSE,
select = NULL,
exclude = NULL,
ignore_case = FALSE,
regex = FALSE,
verbose = TRUE,
...) {
insight::format_error("Cannot (yet) unstandardize a `grouped_df`.")
# evaluate select/exclude, may be select-helpers
select <- .select_nse(select,
x,
exclude,
ignore_case,
regex = regex,
remove_group_var = TRUE,
verbose = verbose
)

info <- attributes(x)

# works only for dplyr >= 0.8.0
grps <- attr(x, "groups", exact = TRUE)[[".rows"]]

x <- as.data.frame(x)

for (i in select) {
if (is.null(info$groups[[paste0("attr_", i)]])) {
insight::format_error(
paste(
"Couldn't retrieve the necessary information to unstandardize",
text_concatenate(i, enclose = "`")
)
)
}
}

for (rows in seq_along(grps)) {
# get the dw_transformer attributes for this group
raw_attrs <- unlist(info$groups[rows, startsWith(names(info$groups), "attr")])
if (length(select) == 1L) {
names(raw_attrs) <- paste0("attr_", select, ".", names(raw_attrs))
}

tmp <- unstandardise(
x[grps[[rows]], , drop = FALSE],
center = center,
scale = scale,
reference = reference,
robust = robust,
two_sd = two_sd,
select = select,
exclude = exclude,
ignore_case = ignore_case,
regex = regex,
verbose = verbose,
grp_attr_dw = raw_attrs
)
x[grps[[rows]], ] <- tmp
}
# set back class, so data frame still works with dplyr
attributes(x) <- utils::modifyList(info, attributes(x))
class(x) <- c("grouped_df", class(x))
x
}

#' @export
Expand Down
13 changes: 3 additions & 10 deletions man/describe_distribution.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions man/normalize.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit ad96b50

Please sign in to comment.