Skip to content

Commit

Permalink
Working tests for day- and hour-partitioned Iceberg tables
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 committed Dec 21, 2023
1 parent 262296a commit 4086bb6
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 14 deletions.
5 changes: 2 additions & 3 deletions src/daft-core/src/series/ops/partitioning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl Series {
}

pub fn partitioning_days(&self) -> DaftResult<Self> {
let value = match self.data_type() {
match self.data_type() {
DataType::Date => Ok(self.clone()),
DataType::Timestamp(_, None) => {
let ts_array = self.downcast::<TimestampArray>()?;
Expand All @@ -67,8 +67,7 @@ impl Series {
"Can only run partitioning_days() operation on temporal types, got {}",
self.data_type()
))),
}?;
value.cast(&DataType::Int32)
}
}

pub fn partitioning_hours(&self) -> DaftResult<Self> {
Expand Down
75 changes: 64 additions & 11 deletions tests/integration/iceberg/test_table_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import pytest

pyiceberg = pytest.importorskip("pyiceberg")
import itertools
from datetime import date, datetime

import pytz
from pyiceberg.io.pyarrow import schema_to_pyarrow

import daft
Expand Down Expand Up @@ -60,32 +64,81 @@ def test_daft_iceberg_table_collect_correct(table_name, local_iceberg_catalog):

@pytest.mark.integration()
def test_daft_iceberg_table_predicate_pushdown_days(local_iceberg_catalog):
    """Predicate pushdown on a day-partitioned table matches a full pyiceberg scan.

    Filters the day-partitioned table on its partition column `ts` so Daft can
    prune partitions, then compares the collected result against pyiceberg.
    """
    tab = local_iceberg_catalog.load_table("default.test_partitioned_by_days")
    df = daft.read_iceberg(tab)
    # Predicate on the partition column — this is what exercises the pushdown path.
    df = df.where(df["ts"] < date(2023, 3, 6))
    df.collect()

    daft_pandas = df.to_pandas()
    iceberg_pandas = tab.scan().to_arrow().to_pandas()
    # NOTE(review): the reference frame is NOT filtered with the same predicate,
    # unlike the parametrized tests below — presumably intentional only if
    # assert_df_equals tolerates it; confirm, or filter iceberg_pandas["ts"] too.
    assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[])


@pytest.mark.integration()
@pytest.mark.parametrize(
    "predicate, table",
    itertools.product(
        [
            lambda x: x < date(2023, 3, 6),
            lambda x: x == date(2023, 3, 6),
            lambda x: x > date(2023, 3, 6),
            lambda x: x != date(2023, 3, 6),
            lambda x: x == date(2022, 3, 6),
        ],
        [
            "test_partitioned_by_months",
            "test_partitioned_by_years",
        ],
    ),
)
def test_daft_iceberg_table_predicate_pushdown_on_date_column(predicate, table, local_iceberg_catalog):
    """Every comparison predicate on the date partition column `dt` returns the
    same rows through Daft's pushdown as a pyiceberg scan filtered in pandas.

    Parametrized over the cross product of comparison predicates (including one
    that matches no partition, date(2022, 3, 6)) and month/year-partitioned tables.
    """
    tab = local_iceberg_catalog.load_table(f"default.{table}")
    df = daft.read_iceberg(tab)
    df = df.where(predicate(df["dt"]))
    df.collect()

    daft_pandas = df.to_pandas()
    # Reference result: full pyiceberg scan, then the same predicate applied in pandas.
    iceberg_pandas = tab.scan().to_arrow().to_pandas()
    iceberg_pandas = iceberg_pandas[predicate(iceberg_pandas["dt"])]
    assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[])
@pytest.mark.integration()
@pytest.mark.parametrize(
    "predicate, table",
    [
        (pred, tbl)
        for pred in [
            lambda x: x < datetime(2023, 3, 6, tzinfo=pytz.utc),
            lambda x: x == datetime(2023, 3, 6, tzinfo=pytz.utc),
            lambda x: x > datetime(2023, 3, 6, tzinfo=pytz.utc),
            lambda x: x != datetime(2023, 3, 6, tzinfo=pytz.utc),
            lambda x: x == datetime(2022, 3, 6, tzinfo=pytz.utc),
        ]
        for tbl in [
            "test_partitioned_by_days",
            "test_partitioned_by_hours",
        ]
    ],
)
def test_daft_iceberg_table_predicate_pushdown_on_timestamp_column(predicate, table, local_iceberg_catalog):
    """Each comparison predicate on the timestamp partition column `ts` yields the
    same rows via Daft as a pyiceberg scan filtered with that predicate in pandas.

    Covers day- and hour-partitioned tables, including a predicate that selects
    nothing (equality with a 2022 timestamp).
    """
    ice_table = local_iceberg_catalog.load_table(f"default.{table}")
    frame = daft.read_iceberg(ice_table)
    frame = frame.where(predicate(frame["ts"]))
    frame.collect()

    actual = frame.to_pandas()
    # Expected result: unfiltered pyiceberg scan with the same predicate in pandas.
    expected = ice_table.scan().to_arrow().to_pandas()
    expected = expected[predicate(expected["ts"])]
    assert_df_equals(actual, expected, sort_key=[])


@pytest.mark.integration()
def test_daft_iceberg_table_predicate_pushdown_empty_scan(local_iceberg_catalog):
    """A predicate past the data's date range prunes every partition: zero rows."""
    ice_table = local_iceberg_catalog.load_table("default.test_partitioned_by_months")
    frame = daft.read_iceberg(ice_table)
    # 2030 is beyond any data in the fixture table, so nothing should match.
    frame = frame.where(frame["dt"] > date(2030, 1, 1))
    frame.collect()
    assert len(frame.to_arrow()) == 0

0 comments on commit 4086bb6

Please sign in to comment.