Skip to content

Commit

Permalink
https://github.com/easystats/easystats/issues/404
Browse files Browse the repository at this point in the history
  • Loading branch information
strengejacke committed May 16, 2024
1 parent d74a225 commit ea9486e
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 16 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: datawizard
Title: Easy Data Wrangling and Statistical Transformations
Version: 0.10.0.3
Version: 0.10.0.4
Authors@R: c(
person("Indrajeet", "Patil", , "[email protected]", role = "aut",
comment = c(ORCID = "0000-0003-1995-6531", Twitter = "@patilindrajeets")),
Expand Down
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# datawizard 0.10.1

BREAKING CHANGES

* Arguments named `group` or `group_by` will be deprecated in a future release.
Please use `by` instead. This affects following functions in *datawizard*.

* `data_partition()`

CHANGES

* `recode_into()` is more relaxed regarding checking the type of `NA` values.
Expand Down
23 changes: 15 additions & 8 deletions R/data_partition.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@
#'
#' Creates data partitions (for instance, a training and a test set) based on a
#' data frame that can also be stratified (i.e., evenly spread a given factor)
#' using the `group` argument.
#' using the `by` argument.
#'
#' @inheritParams data_rename
#' @param proportion Scalar (between 0 and 1) or numeric vector, indicating the
#' proportion(s) of the training set(s). The sum of `proportion` must not be
#' greater than 1. The remaining part will be used for the test set.
#' @param group A character vector indicating the name(s) of the column(s) used
#' @param by A character vector indicating the name(s) of the column(s) used
#' for stratified partitioning.
#' @param seed A random number generator seed. Enter an integer (e.g. 123) so
#' that the random sampling will be the same each time you run the function.
#' @param row_id Character string, indicating the name of the column that
#' contains the row-id's.
#' @param verbose Toggle messages and warnings.
#' @param group Deprecated. Use `by` instead.
#'
#' @return A list of data frames. The list includes one training set per given
#' proportion and the remaining data as test set. List elements of training
Expand All @@ -28,7 +29,7 @@
#' nrow(out$p_0.9)
#'
#' # Stratify by group (equal proportions of each species)
#' out <- data_partition(iris, proportion = 0.9, group = "Species")
#' out <- data_partition(iris, proportion = 0.9, by = "Species")
#' out$test
#'
#' # Create multiple partitions
Expand All @@ -38,21 +39,27 @@
#' # Create multiple partitions, stratified by group - 30% equally sampled
#' # from species in first training set, 50% in second training set and
#' # remaining 20% equally sampled from each species in test set.
#' out <- data_partition(iris, proportion = c(0.3, 0.5), group = "Species")
#' out <- data_partition(iris, proportion = c(0.3, 0.5), by = "Species")
#' lapply(out, function(i) table(i$Species))
#'
#' @inherit data_rename seealso
#' @export
data_partition <- function(data,
proportion = 0.7,
group = NULL,
by = NULL,
seed = NULL,
row_id = ".row_id",
verbose = TRUE,
group = NULL,
...) {
# validation checks
data <- .coerce_to_dataframe(data)

## TODO: deprecate later
if (!is.null(group)) {
by <- group
}

if (sum(proportion) > 1) {
insight::format_error("Sum of `proportion` cannot be higher than 1.")
}
Expand Down Expand Up @@ -91,12 +98,12 @@ data_partition <- function(data,

# Create list of data groups. We generally lapply over list of
# sampled row-id's by group, thus, we even create a list if not grouped.
if (is.null(group)) {
if (is.null(by)) {
indices_list <- list(seq_len(nrow(data)))
} else {
# else, split by group(s) and extract row-ids per group
indices_list <- lapply(
split(data, data[group]),
split(data, data[by]),
data_extract,
select = row_id,
as_data_frame = FALSE
Expand Down Expand Up @@ -130,7 +137,7 @@ data_partition <- function(data,
})

# we need to move all list elements one level higher.
if (is.null(group)) {
if (is.null(by)) {
training_sets <- training_sets[[1]]
} else {
# for grouped training sets, we need to row-bind all sampled training
Expand Down
13 changes: 8 additions & 5 deletions man/data_partition.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/test-data_partition.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ test_that("data_partition works as expected", {
data(iris)
expect_snapshot(str(data_partition(iris, proportion = 0.7, seed = 123)))
expect_snapshot(str(data_partition(iris, proportion = c(0.2, 0.5), seed = 123)))
expect_snapshot(str(data_partition(iris, proportion = 0.7, group = "Species", seed = 123)))
expect_snapshot(str(data_partition(iris, proportion = c(0.2, 0.5), group = "Species", seed = 123)))
expect_snapshot(str(data_partition(iris, proportion = 0.7, by = "Species", seed = 123)))
expect_snapshot(str(data_partition(iris, proportion = c(0.2, 0.5), by = "Species", seed = 123)))
})

test_that("data_partition warns if no testing set", {
Expand Down

0 comments on commit ea9486e

Please sign in to comment.