From 94b79cefdfa36d24aab178c0155971d43af52ed1 Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Tue, 24 Dec 2024 10:30:05 +0100 Subject: [PATCH 1/3] fix: Properly broadcast in sort_by --- crates/polars-expr/src/expressions/sortby.rs | 51 ++++++++++++++------ py-polars/tests/unit/operations/test_sort.py | 9 ++-- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/crates/polars-expr/src/expressions/sortby.rs b/crates/polars-expr/src/expressions/sortby.rs index 3d9877038adc..0f4cb05057e4 100644 --- a/crates/polars-expr/src/expressions/sortby.rs +++ b/crates/polars-expr/src/expressions/sortby.rs @@ -219,33 +219,54 @@ impl PhysicalExpr for SortByExpr { let nulls_last = prepare_bool_vec(&self.sort_options.nulls_last, self.by.len()); let sorted_idx_f = || { - let s_sort_by = self + let mut needs_broadcast = false; + let mut broadcast_length = 1; + + let mut s_sort_by = self .by .iter() - .map(|e| { - e.evaluate(df, state).map(|s| match s.dtype() { + .enumerate() + .map(|(i, e)| { + let column = e.evaluate(df, state).map(|c| match c.dtype() { #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => s, - _ => s.to_physical_repr(), - }) + DataType::Categorical(_, _) | DataType::Enum(_, _) => c, + _ => c.to_physical_repr(), + })?; + + if column.len() == 1 && broadcast_length != 1 { + return Ok(column.new_from_index(0, broadcast_length)); + } + + if broadcast_length != column.len() { + polars_ensure!( + broadcast_length == 1, + expr = self.expr, ShapeMismatch: + "`sort_by` produced different length ({}) than earlier Series' length in `by` ({})", + broadcast_length, column.len() + ); + + needs_broadcast |= i > 0; + broadcast_length = column.len(); + } + + Ok(column) }) .collect::>>()?; + if needs_broadcast { + for c in s_sort_by.iter_mut() { + if c.len() != broadcast_length { + *c = c.new_from_index(0, broadcast_length); + } + } + } + let options = self .sort_options .clone() .with_order_descending_multi(descending) .with_nulls_last_multi(nulls_last); - for i in 1..s_sort_by.len() { - polars_ensure!( - s_sort_by[0].len() == s_sort_by[i].len(), - expr = self.expr, ShapeMismatch: - "`sort_by` produced different length ({}) than earlier Series' length in `by` ({})", - s_sort_by[0].len(), s_sort_by[i].len() - ); - } - s_sort_by[0] .as_materialized_series() .arg_sort_multiple(&s_sort_by[1..], &options) diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index 6f90f6744b61..dd06195e4fcd 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -1072,11 +1072,12 @@ def test_sort_string_nulls() -> None: ] -@pytest.mark.may_fail_auto_streaming def test_sort_by_unequal_lengths_7207() -> None: - df = pl.DataFrame({"a": [0, 1, 1, 0], "b": [3, 2, 3, 2]}) - with pytest.raises(pl.exceptions.ShapeError): - df.select(pl.col.a.sort_by(["a", 1])) + df = pl.DataFrame({"a": [0, 1, 1, 0]}) + result = df.select(pl.arg_sort_by(["a", 1])) + + expected = pl.DataFrame({"a": [0, 3, 1, 2]}) + assert_frame_equal(result, expected, check_dtypes=False) def test_sort_literals() -> None: From de9b99aafd016d7657854715a15f6451760ccd8e Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Wed, 8 Jan 2025 10:21:46 +0100 Subject: [PATCH 2/3] check e.is_scalar --- crates/polars-expr/src/expressions/sortby.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/polars-expr/src/expressions/sortby.rs b/crates/polars-expr/src/expressions/sortby.rs index 0f4cb05057e4..a60ff81c11c6 100644 --- a/crates/polars-expr/src/expressions/sortby.rs +++ b/crates/polars-expr/src/expressions/sortby.rs @@ -239,7 +239,7 @@ impl PhysicalExpr for SortByExpr { if broadcast_length != column.len() { polars_ensure!( - broadcast_length == 1, + broadcast_length == 1 && e.is_scalar(), expr = self.expr, ShapeMismatch: "`sort_by` produced different length ({}) than earlier Series' length in `by` ({})", broadcast_length, column.len() From 0abe2df7b1ea9831b0b5a15ffe43afd51cea5acc Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Wed, 8 Jan 2025 12:35:17 +0100 Subject: [PATCH 3/3] fix scalar --- crates/polars-expr/src/expressions/sortby.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/crates/polars-expr/src/expressions/sortby.rs b/crates/polars-expr/src/expressions/sortby.rs index a60ff81c11c6..b3df9ee2316a 100644 --- a/crates/polars-expr/src/expressions/sortby.rs +++ b/crates/polars-expr/src/expressions/sortby.rs @@ -234,13 +234,17 @@ impl PhysicalExpr for SortByExpr { })?; if column.len() == 1 && broadcast_length != 1 { + polars_ensure!( + e.is_scalar(), + ShapeMismatch: "non-scalar expression produces broadcasting column", + ); + return Ok(column.new_from_index(0, broadcast_length)); } if broadcast_length != column.len() { polars_ensure!( - broadcast_length == 1 && e.is_scalar(), - expr = self.expr, ShapeMismatch: + broadcast_length == 1, ShapeMismatch: "`sort_by` produced different length ({}) than earlier Series' length in `by` ({})", broadcast_length, column.len() );