diff --git a/crates/polars-expr/src/expressions/sortby.rs b/crates/polars-expr/src/expressions/sortby.rs index 3d9877038adc..b3df9ee2316a 100644 --- a/crates/polars-expr/src/expressions/sortby.rs +++ b/crates/polars-expr/src/expressions/sortby.rs @@ -219,33 +219,58 @@ impl PhysicalExpr for SortByExpr { let nulls_last = prepare_bool_vec(&self.sort_options.nulls_last, self.by.len()); let sorted_idx_f = || { - let s_sort_by = self + let mut needs_broadcast = false; + let mut broadcast_length = 1; + + let mut s_sort_by = self .by .iter() - .map(|e| { - e.evaluate(df, state).map(|s| match s.dtype() { + .enumerate() + .map(|(i, e)| { + let column = e.evaluate(df, state).map(|c| match c.dtype() { #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => s, - _ => s.to_physical_repr(), - }) + DataType::Categorical(_, _) | DataType::Enum(_, _) => c, + _ => c.to_physical_repr(), + })?; + + if column.len() == 1 && broadcast_length != 1 { + polars_ensure!( + e.is_scalar(), + ShapeMismatch: "non-scalar expression produces broadcasting column", + ); + + return Ok(column.new_from_index(0, broadcast_length)); + } + + if broadcast_length != column.len() { + polars_ensure!( + broadcast_length == 1, ShapeMismatch: + "`sort_by` produced different length ({}) than earlier Series' length in `by` ({})", + broadcast_length, column.len() + ); + + needs_broadcast |= i > 0; + broadcast_length = column.len(); + } + + Ok(column) }) .collect::<PolarsResult<Vec<_>>>()?; + if needs_broadcast { + for c in s_sort_by.iter_mut() { + if c.len() != broadcast_length { + *c = c.new_from_index(0, broadcast_length); + } + } + } + let options = self .sort_options .clone() .with_order_descending_multi(descending) .with_nulls_last_multi(nulls_last); - for i in 1..s_sort_by.len() { - polars_ensure!( - s_sort_by[0].len() == s_sort_by[i].len(), - expr = self.expr, ShapeMismatch: - "`sort_by` produced different length ({}) than earlier Series' length in `by` ({})", - s_sort_by[0].len(), 
s_sort_by[i].len() - ); - } - s_sort_by[0] .as_materialized_series() .arg_sort_multiple(&s_sort_by[1..], &options) diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index 6f90f6744b61..dd06195e4fcd 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -1072,11 +1072,12 @@ def test_sort_string_nulls() -> None: ] -@pytest.mark.may_fail_auto_streaming def test_sort_by_unequal_lengths_7207() -> None: - df = pl.DataFrame({"a": [0, 1, 1, 0], "b": [3, 2, 3, 2]}) - with pytest.raises(pl.exceptions.ShapeError): - df.select(pl.col.a.sort_by(["a", 1])) + df = pl.DataFrame({"a": [0, 1, 1, 0]}) + result = df.select(pl.arg_sort_by(["a", 1])) + + expected = pl.DataFrame({"a": [0, 3, 1, 2]}) + assert_frame_equal(result, expected, check_dtypes=False) def test_sort_literals() -> None: