From b80c16788c4952c234a4de322ab061f8d9d6627b Mon Sep 17 00:00:00 2001 From: ritchie Date: Wed, 8 Jan 2025 10:42:31 +0100 Subject: [PATCH] fix: Fix order observability of group-by-dyn --- .../src/plans/optimizer/set_order.rs | 10 ++-- .../tests/unit/lazyframe/optimizations.py | 37 ------------- .../lazyframe/test_order_observability.py | 53 +++++++++++++++++++ 3 files changed, 58 insertions(+), 42 deletions(-) create mode 100644 py-polars/tests/unit/lazyframe/test_order_observability.py diff --git a/crates/polars-plan/src/plans/optimizer/set_order.rs b/crates/polars-plan/src/plans/optimizer/set_order.rs index b6d360cc56db..a2f71d1568c1 100644 --- a/crates/polars-plan/src/plans/optimizer/set_order.rs +++ b/crates/polars-plan/src/plans/optimizer/set_order.rs @@ -123,11 +123,6 @@ pub(super) fn set_order_flags( .. } => { debug_assert!(options.slice.is_none()); - if !maintain_order_above && *maintain_order { - *maintain_order = false; - continue; - } - if apply.is_some() || *maintain_order || options.is_rolling() @@ -136,6 +131,11 @@ pub(super) fn set_order_flags( maintain_order_above = true; continue; } + if !maintain_order_above && *maintain_order { + *maintain_order = false; + continue; + } + if all_elementwise(keys, expr_arena) && all_order_independent(aggs, expr_arena, Context::Aggregation) { diff --git a/py-polars/tests/unit/lazyframe/optimizations.py b/py-polars/tests/unit/lazyframe/optimizations.py index 969b2be15c8d..8b33cd6a8967 100644 --- a/py-polars/tests/unit/lazyframe/optimizations.py +++ b/py-polars/tests/unit/lazyframe/optimizations.py @@ -4,33 +4,6 @@ from polars.testing import assert_frame_equal -def test_remove_double_sort() -> None: - assert ( - pl.LazyFrame({"a": [1, 2, 3, 3]}).sort("a").sort("a").explain().count("SORT") - == 1 - ) - - -def test_double_sort_maintain_order_18558() -> None: - df = pl.DataFrame( - { - "col1": [1, 2, 2, 4, 5, 6], - "col2": [2, 2, 0, 0, 2, None], - } - ) - - lf = df.lazy().sort("col2").sort("col1", maintain_order=True) - - expect = pl.DataFrame( - [ - pl.Series("col1", [1, 2, 2, 4, 5, 6], dtype=pl.Int64), - pl.Series("col2", [2, 0, 2, 0, 2, None], dtype=pl.Int64), - ] - ) - - assert_frame_equal(lf.collect(), expect) - - def test_fast_count_alias_18581() -> None: f = io.BytesIO() f.write(b"a,b,c\n1,2,3\n4,5,6") @@ -40,13 +13,3 @@ def test_fast_count_alias_18581() -> None: df = pl.scan_csv(f).select(pl.len().alias("weird_name")).collect() assert_frame_equal(pl.DataFrame({"weird_name": 2}), df) - - -def test_order_observability() -> None: - q = pl.LazyFrame({"a": [1, 2, 3], "b": [1, 2, 3]}).sort("a") - - assert "SORT" not in q.group_by("a").sum().explain(_check_order=True) - assert "SORT" not in q.group_by("a").min().explain(_check_order=True) - assert "SORT" not in q.group_by("a").max().explain(_check_order=True) - assert "SORT" in q.group_by("a").last().explain(_check_order=True) - assert "SORT" in q.group_by("a").first().explain(_check_order=True) diff --git a/py-polars/tests/unit/lazyframe/test_order_observability.py b/py-polars/tests/unit/lazyframe/test_order_observability.py new file mode 100644 index 000000000000..425a4954b0da --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_order_observability.py @@ -0,0 +1,53 @@ +import polars as pl +from polars.testing import assert_frame_equal + + +def test_order_observability() -> None: + q = pl.LazyFrame({"a": [1, 2, 3], "b": [1, 2, 3]}).sort("a") + + assert "SORT" not in q.group_by("a").sum().explain(_check_order=True) + assert "SORT" not in q.group_by("a").min().explain(_check_order=True) + assert "SORT" not in q.group_by("a").max().explain(_check_order=True) + assert "SORT" in q.group_by("a").last().explain(_check_order=True) + assert "SORT" in q.group_by("a").first().explain(_check_order=True) + + +def test_order_observability_group_by_dynamic() -> None: + assert ( + pl.LazyFrame( + {"REGIONID": [1, 23, 4], "INTERVAL_END": [32, 43, 12], "POWER": [12, 3, 1]} + ) + .sort("REGIONID", "INTERVAL_END") + .group_by_dynamic(index_column="INTERVAL_END", every="1i", group_by="REGIONID") + .agg(pl.col("POWER").sum()) + .sort("POWER") + .head() + .explain() + ).count("SORT") == 2 + + +def test_remove_double_sort() -> None: + assert ( + pl.LazyFrame({"a": [1, 2, 3, 3]}).sort("a").sort("a").explain().count("SORT") + == 1 + ) + + +def test_double_sort_maintain_order_18558() -> None: + df = pl.DataFrame( + { + "col1": [1, 2, 2, 4, 5, 6], + "col2": [2, 2, 0, 0, 2, None], + } + ) + + lf = df.lazy().sort("col2").sort("col1", maintain_order=True) + + expect = pl.DataFrame( + [ + pl.Series("col1", [1, 2, 2, 4, 5, 6], dtype=pl.Int64), + pl.Series("col2", [2, 0, 2, 0, 2, None], dtype=pl.Int64), + ] + ) + + assert_frame_equal(lf.collect(), expect)