diff --git a/R/dplyr-rows.R b/R/dplyr-rows.R index ee8349a..8bfe177 100644 --- a/R/dplyr-rows.R +++ b/R/dplyr-rows.R @@ -46,7 +46,7 @@ timbr_rows <- function(f, x, y, by, ...) { tibble::as_tibble() |> dplyr::ungroup() |> f(y, by, ...) |> - dplyr::select(!dplyr::all_of(by)) + dplyr::select(!dplyr::all_of(c(group_vars(x), by))) x$graph <- x$graph |> tidygraph::activate("nodes") |> diff --git a/tests/testthat/test-dplyr.R b/tests/testthat/test-dplyr.R index 9aff859..d7643aa 100644 --- a/tests/testthat/test-dplyr.R +++ b/tests/testthat/test-dplyr.R @@ -48,6 +48,8 @@ test_that("dplyr", { }) test_that("rows_update", { + set.seed(1234) + library(dplyr) fr <- vec_expand_grid(key1 = letters[1:3], @@ -88,4 +90,25 @@ test_that("rows_update", { rows_patch(df, by = c("key1", "key2")) %>% children()) + + fr <- vec_expand_grid(key1 = letters[1:3], + key2 = letters[1:3], + key3 = letters[1:3]) %>% + mutate(value = row_number()) %>% + forest_by(key1, key2, key3) %>% + summarise(value = sum(value)) + df <- vec_expand_grid(key2 = c("c", "a"), + key3 = c("a", "b", "c")) %>% + mutate(value = sample(1:9, n())) + fr <- fr %>% + rows_update(df, + by = c("key2", "key3")) %>% + climb(key2, key3) %>% + as_tibble() %>% + rename(value_object = value) %>% + inner_join(df %>% + rename(value_expected = value), + by = join_by(key2, key3)) + + expect_equal(fr$value_object, fr$value_expected) })