diff --git a/tests/integration/iceberg/test_partition_pruning.py b/tests/integration/iceberg/test_partition_pruning.py index 2865858ddc..fafff41624 100644 --- a/tests/integration/iceberg/test_partition_pruning.py +++ b/tests/integration/iceberg/test_partition_pruning.py @@ -6,9 +6,11 @@ import itertools from datetime import date, datetime +import pandas as pd import pytz import daft +from daft.expressions import Expression from tests.conftest import assert_df_equals @@ -27,6 +29,15 @@ def test_daft_iceberg_table_predicate_pushdown_days(local_iceberg_catalog): assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[]) +def is_null(obj): + if isinstance(obj, Expression): + return obj.is_null() + elif isinstance(obj, pd.Series): + return obj.isnull() + else: + raise NotImplementedError() + + @pytest.mark.integration() @pytest.mark.parametrize( "predicate, table, limit", @@ -40,6 +51,7 @@ def test_daft_iceberg_table_predicate_pushdown_days(local_iceberg_catalog): lambda x: date(2023, 3, 6) == x, lambda x: date(2023, 3, 6) < x, lambda x: date(2023, 3, 6) != x, + is_null, ], [ "test_partitioned_by_months", @@ -77,6 +89,7 @@ def test_daft_iceberg_table_predicate_pushdown_on_date_column(predicate, table, lambda x: datetime(2023, 3, 6, tzinfo=pytz.utc) == x, lambda x: datetime(2023, 3, 6, tzinfo=pytz.utc) < x, lambda x: datetime(2023, 3, 6, tzinfo=pytz.utc) != x, + is_null, ], [ "test_partitioned_by_days", @@ -118,6 +131,7 @@ def test_daft_iceberg_table_predicate_pushdown_on_timestamp_column(predicate, ta lambda x: "d" < x, lambda x: "d" != x, lambda x: "z" == x, + is_null, ], [ "test_partitioned_by_truncate", @@ -157,6 +171,7 @@ def test_daft_iceberg_table_predicate_pushdown_on_letter(predicate, table, limit lambda x: 4 < x, lambda x: 4 != x, lambda x: 100 == x, + is_null, ], [ "test_partitioned_by_bucket",