Skip to content

Commit

Permalink
Fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkzinzow committed Feb 21, 2024
1 parent 4f9b26d commit 9be424a
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 10 deletions.
3 changes: 2 additions & 1 deletion daft/delta_lake/delta_lake_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Any
from urllib.parse import urlparse

import pyarrow as pa
from deltalake.table import DeltaTable

import daft
Expand Down Expand Up @@ -58,6 +57,8 @@ def multiline_display(self) -> list[str]:
]

def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
import pyarrow as pa

# TODO(Clark): Push limit and filter expressions into deltalake action fetch, to prune the files returned.
add_actions: pa.RecordBatch = self._table.get_add_actions()

Expand Down
3 changes: 3 additions & 0 deletions tests/io/delta_lake/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def local_deltalake_table(request, tmp_path, partition_by) -> deltalake.DeltaTab
"e": [datetime.datetime(2024, 2, 10), datetime.datetime(2024, 2, 11), datetime.datetime(2024, 2, 12)],
}
)
# Delta Lake casts timestamps to microsecond resolution on ingress, so we preemptively cast the Pandas DataFrame here
# to make equality assertions easier later.
base_df["e"] = base_df["e"].astype("datetime64[us]")
dfs = []
for part_idx in range(request.param):
part_df = base_df.copy()
Expand Down
10 changes: 7 additions & 3 deletions tests/io/delta_lake/test_table_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@

deltalake = pytest.importorskip("deltalake")

import pyarrow as pa

import daft
from daft.logical.schema import Schema

# True when the installed pyarrow is older than 8.0.0; non-numeric version
# segments (e.g. "dev0" in "8.0.0.dev0") are dropped before the tuple compare.
# NOTE(review): the name says "LE" (<=) but the comparison is strictly "<" —
# this flag is really "less than 8.0.0"; consider renaming to PYARROW_LT_8_0_0.
PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)
# Module-level pytest marker: skip every test in this file when pyarrow is too
# old for the deltalake package to work.
pytestmark = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="deltalake only supported if pyarrow >= 8.0.0")


@pytest.mark.integration()
def test_deltalake_read_basic(tmp_path):
pd_df = pd.DataFrame(
{
Expand All @@ -24,18 +28,18 @@ def test_deltalake_read_basic(tmp_path):
deltalake.write_deltalake(path, pd_df)
df = daft.read_delta_lake(str(path))
assert df.schema() == Schema.from_pyarrow_schema(deltalake.DeltaTable(path).schema().to_pyarrow())
# Delta Lake casts timestamps to microsecond resolution on ingress.
pd_df["c"] = pd_df["c"].astype("datetime64[us]")
pd.testing.assert_frame_equal(df.to_pandas(), pd_df)


@pytest.mark.integration()
def test_deltalake_read_full(local_deltalake_table):
    """Reading an entire local Delta Lake table yields the original data.

    Checks that the Daft schema matches the Delta table's pyarrow schema and
    that the materialized Pandas frame equals the concatenation of the source
    partition dataframes written by the fixture.
    """
    table_path, source_dfs = local_deltalake_table
    daft_df = daft.read_delta_lake(str(table_path))
    expected_schema = Schema.from_pyarrow_schema(deltalake.DeltaTable(table_path).schema().to_pyarrow())
    assert daft_df.schema() == expected_schema
    expected_pd_df = pd.concat(source_dfs).reset_index(drop=True)
    pd.testing.assert_frame_equal(daft_df.to_pandas(), expected_pd_df)


@pytest.mark.integration()
def test_deltalake_read_show(local_deltalake_table):
path, _ = local_deltalake_table
df = daft.read_delta_lake(str(path))
Expand Down
11 changes: 5 additions & 6 deletions tests/io/delta_lake/test_table_read_pushdowns.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@

deltalake = pytest.importorskip("deltalake")

import pyarrow as pa

import daft
from daft.logical.schema import Schema

# True when the installed pyarrow is older than 8.0.0; non-numeric version
# segments (e.g. "dev0" in "8.0.0.dev0") are dropped before the tuple compare.
# NOTE(review): the name says "LE" (<=) but the comparison is strictly "<" —
# this flag is really "less than 8.0.0"; consider renaming to PYARROW_LT_8_0_0.
PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)
# Module-level pytest marker: skip every test in this file when pyarrow is too
# old for the deltalake package to work.
pytestmark = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="deltalake only supported if pyarrow >= 8.0.0")


@pytest.mark.integration()
def test_deltalake_read_predicate_pushdown_on_data(local_deltalake_table):
path, dfs = local_deltalake_table
df = daft.read_delta_lake(str(path))
Expand All @@ -22,7 +26,6 @@ def test_deltalake_read_predicate_pushdown_on_data(local_deltalake_table):
)


@pytest.mark.integration()
def test_deltalake_read_predicate_pushdown_on_part(local_deltalake_table):
path, dfs = local_deltalake_table
df = daft.read_delta_lake(str(path))
Expand All @@ -33,7 +36,6 @@ def test_deltalake_read_predicate_pushdown_on_part(local_deltalake_table):
)


@pytest.mark.integration()
def test_deltalake_read_predicate_pushdown_on_part_non_eq(local_deltalake_table):
path, dfs = local_deltalake_table
df = daft.read_delta_lake(str(path))
Expand All @@ -44,7 +46,6 @@ def test_deltalake_read_predicate_pushdown_on_part_non_eq(local_deltalake_table)
)


@pytest.mark.integration()
def test_deltalake_read_predicate_pushdown_on_part_and_data(local_deltalake_table):
path, dfs = local_deltalake_table
df = daft.read_delta_lake(str(path))
Expand All @@ -58,7 +59,6 @@ def test_deltalake_read_predicate_pushdown_on_part_and_data(local_deltalake_tabl
)


@pytest.mark.integration()
def test_deltalake_read_predicate_pushdown_on_part_and_data_same_clause(local_deltalake_table):
path, dfs = local_deltalake_table
df = daft.read_delta_lake(str(path))
Expand All @@ -70,7 +70,6 @@ def test_deltalake_read_predicate_pushdown_on_part_and_data_same_clause(local_de
)


@pytest.mark.integration()
def test_deltalake_read_predicate_pushdown_on_part_empty(local_deltalake_table):
path, dfs = local_deltalake_table
df = daft.read_delta_lake(str(path))
Expand Down

0 comments on commit 9be424a

Please sign in to comment.