From 06b00a6e977f6d093e2ad86f94d14366ff463330 Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Fri, 24 May 2024 15:40:02 -0700 Subject: [PATCH] [FEAT] Delta Lake Writer (non-public API) (#2304) Continuation of work by @siddharth-gulia in #2073 on our Delta Lake writer --------- Co-authored-by: siddharth kumar --- daft/context.py | 2 + daft/daft.pyi | 21 ++++ daft/dataframe/dataframe.py | 111 +++++++++++++++++- .../delta_lake/delta_lake_storage_function.py | 78 ++++++++++++ daft/execution/execution_step.py | 36 ++++++ daft/execution/physical_plan.py | 34 +++++- daft/execution/rust_physical_plan_shim.py | 16 +++ daft/logical/builder.py | 19 +++ daft/table/table_io.py | 89 ++++++++++++++ requirements-dev.txt | 1 - src/common/daft-config/src/lib.rs | 8 ++ src/common/daft-config/src/python.rs | 36 ++++++ src/daft-plan/src/builder.rs | 51 ++++++++ src/daft-plan/src/logical_ops/sink.rs | 4 + .../src/physical_ops/deltalake_write.rs | 33 ++++++ src/daft-plan/src/physical_ops/mod.rs | 4 + src/daft-plan/src/physical_plan.rs | 50 ++++++++ .../src/physical_planner/translate.rs | 8 ++ src/daft-plan/src/sink_info.rs | 52 ++++++++ tests/io/delta_lake/conftest.py | 12 +- tests/io/delta_lake/test_table_read.py | 19 +-- .../delta_lake/test_table_read_pushdowns.py | 13 +- tests/io/delta_lake/test_table_write.py | 66 +++++++++++ 23 files changed, 748 insertions(+), 15 deletions(-) create mode 100644 daft/delta_lake/delta_lake_storage_function.py create mode 100644 src/daft-plan/src/physical_ops/deltalake_write.rs create mode 100644 tests/io/delta_lake/test_table_write.py diff --git a/daft/context.py b/daft/context.py index f24a23b86a..7c765eb068 100644 --- a/daft/context.py +++ b/daft/context.py @@ -292,6 +292,8 @@ def set_execution_config( ctx = get_context() with ctx._lock: old_daft_execution_config = ctx._daft_execution_config if config is None else config + + # TODO: Re-addd Parquet configs when we are ready to support Delta Lake writes new_daft_execution_config = old_daft_execution_config.with_config_values( scan_tasks_min_size_bytes=scan_tasks_min_size_bytes, scan_tasks_max_size_bytes=scan_tasks_max_size_bytes, diff --git a/daft/daft.pyi b/daft/daft.pyi index 1a0dc5e20d..9f4f3bfb7d 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -1406,6 +1406,15 @@ class LogicalPlanBuilder: catalog_columns: list[str], io_config: IOConfig | None = None, ) -> LogicalPlanBuilder: ... + def delta_write( + self, + path: str, + columns_name: list[str], + mode: str, + current_version: int, + large_dtypes: bool, + io_config: IOConfig | None = None, + ) -> LogicalPlanBuilder: ... def schema(self) -> PySchema: ... def optimize(self) -> LogicalPlanBuilder: ... def to_physical_plan_scheduler(self, cfg: PyDaftExecutionConfig) -> PhysicalPlanScheduler: ... @@ -1426,6 +1435,10 @@ class PyDaftExecutionConfig: num_preview_rows: int | None = None, parquet_target_filesize: int | None = None, parquet_target_row_group_size: int | None = None, + parquet_max_open_files: int | None = None, + parquet_max_rows_per_file: int | None = None, + parquet_min_rows_per_group: int | None = None, + parquet_max_rows_per_group: int | None = None, parquet_inflation_factor: float | None = None, csv_target_filesize: int | None = None, csv_inflation_factor: float | None = None, @@ -1450,6 +1463,14 @@ class PyDaftExecutionConfig: @property def parquet_target_row_group_size(self) -> int: ... @property + def parquet_max_open_files(self) -> int: ... + @property + def parquet_max_rows_per_file(self) -> int: ... 
+ @property + def parquet_min_rows_per_group(self) -> int: ... + @property + def parquet_max_rows_per_group(self) -> int: ... + @property def parquet_inflation_factor(self) -> float: ... @property def csv_target_filesize(self) -> int: ... diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 7ff6715a52..4c5ad3ae72 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -27,9 +27,20 @@ from daft.api_annotations import DataframePublicAPI from daft.context import get_context from daft.convert import InputListType -from daft.daft import FileFormat, IOConfig, JoinStrategy, JoinType, ResourceRequest +from daft.daft import ( + FileFormat, + IOConfig, + JoinStrategy, + JoinType, + NativeStorageConfig, + ResourceRequest, + StorageConfig, +) from daft.dataframe.preview import DataFramePreview from daft.datatype import DataType +from daft.delta_lake.delta_lake_storage_function import ( + _storage_config_to_storage_options, +) from daft.errors import ExpressionTypeError from daft.expressions import Expression, ExpressionsProjection, col, lit from daft.logical.builder import LogicalPlanBuilder @@ -534,6 +545,104 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> # This is due to the fact that the logical plan of the write_iceberg returns datafiles but we want to return the above data return with_operations + def write_delta( + self, + path: str, + mode: str = "append", + io_config: Optional[IOConfig] = None, + ) -> None: + import deltalake + import pyarrow as pa + from deltalake.schema import _convert_pa_schema_to_delta + from deltalake.writer import ( + try_get_table_and_table_uri, + write_deltalake_pyarrow, + ) + from packaging.version import parse + + if mode not in ["append"]: + raise ValueError(f"Mode {mode} is not supported. 
Only 'append' mode is supported") + + if parse(deltalake.__version__) < parse("0.14.0"): + raise ValueError(f"Write delta lake is only supported on deltalake>=0.14.0, found {deltalake.__version__}") + + io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config + storage_config = StorageConfig.native(NativeStorageConfig(False, io_config)) + storage_options = _storage_config_to_storage_options(storage_config, path) + table, table_uri = try_get_table_and_table_uri(path, storage_options) + if table is not None: + storage_options = table._storage_options or {} + storage_options.update(storage_options or {}) + + table.update_incremental() + + fields = [f for f in self.schema()] + pyarrow_fields = [pa.field(f.name, f.dtype.to_arrow_dtype()) for f in fields] + pyarrow_schema = pa.schema(pyarrow_fields) + + delta_schema = _convert_pa_schema_to_delta(pyarrow_schema, large_dtypes=True) + if table: + if delta_schema != table.schema().to_pyarrow(as_large_types=True): + raise ValueError( + "Schema of data does not match table schema\n" + f"Data schema:\n{delta_schema}\nTable Schema:\n{table.schema().to_pyarrow(as_large_types=True)}" + ) + if mode == "error": + raise AssertionError("DeltaTable already exists.") + elif mode == "ignore": + return + + current_version = table.version() + + else: + current_version = -1 + + builder = self._builder.write_delta( + path=path, + mode=mode, + current_version=current_version, + large_dtypes=True, + io_config=io_config, + ) + write_df = DataFrame(builder) + write_df.collect() + + write_result = write_df.to_pydict() + assert "data_file" in write_result + data_files = write_result["data_file"] + add_action = [] + + operations = [] + respath = [] + size = [] + + for data_file in data_files: + operations.append("ADD") + respath.append(data_file.path) + size.append(data_file.size) + add_action.append(data_file) + + if table is None: + write_deltalake_pyarrow( + table_uri, + delta_schema, + add_action, + mode, + [], + storage_options=storage_options, + ) + else: + table._table.create_write_transaction( + add_action, + mode, + [], + delta_schema, + None, + ) + table.update_incremental() + + return None + ### # DataFrame operations ### diff --git a/daft/delta_lake/delta_lake_storage_function.py b/daft/delta_lake/delta_lake_storage_function.py new file mode 100644 index 0000000000..f902cdcb08 --- /dev/null +++ b/daft/delta_lake/delta_lake_storage_function.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import Any +from urllib.parse import urlparse + +from daft.daft import ( + AzureConfig, + GCSConfig, + IOConfig, + NativeStorageConfig, + S3Config, + StorageConfig, +) + + +def _storage_config_to_storage_options(storage_config: StorageConfig, table_uri: str) -> dict[str, str] | None: + """ + Converts the Daft storage config to a storage options dict that deltalake/object_store + understands. 
+ """ + config = storage_config.config + assert isinstance(config, NativeStorageConfig) + io_config = config.io_config + return _io_config_to_storage_options(io_config, table_uri) + + +def _io_config_to_storage_options(io_config: IOConfig, table_uri: str) -> dict[str, str] | None: + scheme = urlparse(table_uri).scheme + if scheme == "s3" or scheme == "s3a": + return _s3_config_to_storage_options(io_config.s3) + elif scheme == "gcs" or scheme == "gs": + return _gcs_config_to_storage_options(io_config.gcs) + elif scheme == "az" or scheme == "abfs": + return _azure_config_to_storage_options(io_config.azure) + else: + return None + + +def _s3_config_to_storage_options(s3_config: S3Config) -> dict[str, str]: + storage_options: dict[str, Any] = {} + if s3_config.region_name is not None: + storage_options["region"] = s3_config.region_name + if s3_config.endpoint_url is not None: + storage_options["endpoint_url"] = s3_config.endpoint_url + if s3_config.key_id is not None: + storage_options["access_key_id"] = s3_config.key_id + if s3_config.session_token is not None: + storage_options["session_token"] = s3_config.session_token + if s3_config.access_key is not None: + storage_options["secret_access_key"] = s3_config.access_key + if s3_config.use_ssl is not None: + storage_options["allow_http"] = "false" if s3_config.use_ssl else "true" + if s3_config.verify_ssl is not None: + storage_options["allow_invalid_certificates"] = "false" if s3_config.verify_ssl else "true" + if s3_config.connect_timeout_ms is not None: + storage_options["connect_timeout"] = str(s3_config.connect_timeout_ms) + "ms" + if s3_config.anonymous: + raise ValueError( + "Reading from DeltaLake does not support anonymous mode! Please supply credentials via your S3Config." + ) + return storage_options + + +def _azure_config_to_storage_options(azure_config: AzureConfig) -> dict[str, str]: + storage_options = {} + if azure_config.storage_account is not None: + storage_options["account_name"] = azure_config.storage_account + if azure_config.access_key is not None: + storage_options["access_key"] = azure_config.access_key + if azure_config.endpoint_url is not None: + storage_options["endpoint"] = azure_config.endpoint_url + if azure_config.use_ssl is not None: + storage_options["allow_http"] = "false" if azure_config.use_ssl else "true" + return storage_options + + +def _gcs_config_to_storage_options(_: GCSConfig) -> dict[str, str]: + return {} diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index cbc7056f94..6b280eaee4 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -408,6 +408,42 @@ def _handle_file_write(self, input: MicroPartition) -> MicroPartition: ) +@dataclass(frozen=True) +class WriteDeltaLake(SingleOutputInstruction): + base_path: str + large_dtypes: bool + current_version: int + io_config: IOConfig | None + + def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: + return self._write_deltalake(inputs) + + def _write_deltalake(self, inputs: list[MicroPartition]) -> list[MicroPartition]: + [input] = inputs + partition = self._handle_file_write( + input=input, + ) + return [partition] + + def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]: + assert len(input_metadatas) == 1 + return [ + PartialPartitionMetadata( + num_rows=None, # we can write more than 1 file per partition + size_bytes=None, + ) + ] + + def _handle_file_write(self, input: MicroPartition) -> MicroPartition: + return 
table_io.write_deltalake(
+            input,
+            large_dtypes=self.large_dtypes,
+            base_path=self.base_path,
+            current_version=self.current_version,
+            io_config=self.io_config,
+        )
+
+
 @dataclass(frozen=True)
 class Filter(SingleOutputInstruction):
     predicate: ExpressionsProjection
diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py
index d0019d9f70..bb4efc4eb5 100644
--- a/daft/execution/physical_plan.py
+++ b/daft/execution/physical_plan.py
@@ -19,7 +19,15 @@
 import math
 import pathlib
 from collections import deque
-from typing import TYPE_CHECKING, Generator, Generic, Iterable, Iterator, TypeVar, Union
+from typing import (
+    TYPE_CHECKING,
+    Generator,
+    Generic,
+    Iterable,
+    Iterator,
+    TypeVar,
+    Union,
+)
 
 from daft.context import get_context
 from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest
@@ -130,6 +138,30 @@ def iceberg_write(
     )
 
 
+def deltalake_write(
+    child_plan: InProgressPhysicalPlan[PartitionT],
+    base_path: str,
+    large_dtypes: bool,
+    current_version: int,
+    io_config: IOConfig | None,
+) -> InProgressPhysicalPlan[PartitionT]:
+    """Write the results of `child_plan` into Delta Lake data files at `base_path`."""
+
+    yield from (
+        step.add_instruction(
+            execution_step.WriteDeltaLake(
+                base_path=base_path,
+                large_dtypes=large_dtypes,
+                current_version=current_version,
+                io_config=io_config,
+            ),
+        )
+        if isinstance(step, PartitionTaskBuilder)
+        else step
+        for step in child_plan
+    )
+
+
 def pipeline_instruction(
     child_plan: InProgressPhysicalPlan[PartitionT],
     pipeable_instruction: Instruction,
diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py
index 2511bffcb0..39a89d1e5b 100644
--- a/daft/execution/rust_physical_plan_shim.py
+++ b/daft/execution/rust_physical_plan_shim.py
@@ -328,3 +328,19 @@ def write_iceberg(
         spec_id=spec_id,
         io_config=io_config,
     )
+
+
+def write_deltalake(
+    input: physical_plan.InProgressPhysicalPlan[PartitionT],
+    path: str,
+    large_dtypes: bool,
+    current_version: int,
+    io_config: IOConfig | None,
+) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
+    return physical_plan.deltalake_write(
+        input,
+        path,
+        large_dtypes,
+        current_version,
+        io_config,
+    )
diff --git a/daft/logical/builder.py b/daft/logical/builder.py
index 4a757a05f1..cf81075393 100644
--- a/daft/logical/builder.py
+++ b/daft/logical/builder.py
@@ -244,3 +244,22 @@ def write_iceberg(self, table: IcebergTable) -> LogicalPlanBuilder:
         io_config = _convert_iceberg_file_io_properties_to_io_config(table.io.properties)
         builder = self._builder.iceberg_write(name, location, spec_id, schema, props, columns, io_config)
         return LogicalPlanBuilder(builder)
+
+    def write_delta(
+        self,
+        path: str | pathlib.Path,
+        mode: str,
+        current_version: int,
+        large_dtypes: bool,
+        io_config: IOConfig,
+    ) -> LogicalPlanBuilder:
+        columns_name = self.schema().column_names()
+        builder = self._builder.delta_write(
+            str(path),
+            columns_name,
+            mode,
+            current_version,
+            large_dtypes,
+            io_config,
+        )
+        return LogicalPlanBuilder(builder)
diff --git a/daft/table/table_io.py b/daft/table/table_io.py
index 8412e62843..7af6c54c01 100644
--- a/daft/table/table_io.py
+++ b/daft/table/table_io.py
@@ -648,6 +648,95 @@ def file_visitor(written_file, protocol=protocol):
     return MicroPartition.from_pydict({"data_file": Series.from_pylist(data_files, name="data_file", pyobj="force")})
 
+
+def write_deltalake(
+    mp: MicroPartition,
+    large_dtypes: bool,
+    base_path: str,
+    current_version: int,
+    io_config: IOConfig | None = None,
+): + from deltalake.schema import convert_pyarrow_table + from deltalake.writer import ( + AddAction, + DeltaJSONEncoder, + DeltaStorageHandler, + get_file_stats_from_metadata, + get_partitions_from_path, + try_get_table_and_table_uri, + ) + from pyarrow.fs import PyFileSystem + + from daft.delta_lake.delta_lake_storage_function import ( + _storage_config_to_storage_options, + ) + + data_files: list[AddAction] = [] + + def file_visitor(written_file: Any) -> None: + path, partition_values = get_partitions_from_path(written_file.path) + stats = get_file_stats_from_metadata(written_file.metadata) + + import json + from datetime import datetime + + from daft.utils import ARROW_VERSION + + # PyArrow added support for written_file.size in 9.0.0 + if ARROW_VERSION >= (9, 0, 0): + size = written_file.size + elif filesystem is not None: + size = filesystem.get_file_info([path])[0].size + else: + size = 0 + + data_files.append( + AddAction( + path, + size, + partition_values, + int(datetime.now().timestamp() * 1000), + True, + json.dumps(stats, cls=DeltaJSONEncoder), + ) + ) + + io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config + storage_config = StorageConfig.native(NativeStorageConfig(False, io_config)) + storage_options = _storage_config_to_storage_options(storage_config, base_path) + table, table_uri = try_get_table_and_table_uri(base_path, storage_options) + filesystem = PyFileSystem(DeltaStorageHandler(table_uri, storage_options)) + + arrow_table = mp.to_arrow() + arrow_batch = convert_pyarrow_table(arrow_table, large_dtypes) + + execution_config = get_context().daft_execution_config + MAX_OPEN_FILE = execution_config.parquet_max_open_files + MAX_ROWS_PER_FILE = execution_config.parquet_max_rows_per_file + MIN_ROWS_PER_GROUP = execution_config.parquet_min_rows_per_group + MAX_ROWS_PER_GROUP = execution_config.parquet_max_rows_per_group + + file_options = pads.ParquetFileFormat().make_write_options(use_compliant_nested_type=False) + + pads.write_dataset( + arrow_batch, + base_dir="/", + basename_template=f"{current_version + 1}-{uuid4()}-{{i}}.parquet", + format="parquet", + partitioning=None, + schema=None, + file_visitor=file_visitor, + existing_data_behavior="overwrite_or_ignore", + file_options=file_options, + max_open_files=MAX_OPEN_FILE, + max_rows_per_file=MAX_ROWS_PER_FILE, + min_rows_per_group=MIN_ROWS_PER_GROUP, + max_rows_per_group=MAX_ROWS_PER_GROUP, + filesystem=filesystem, + ) + + return MicroPartition.from_pydict({"data_file": Series.from_pylist(data_files, name="data_file", pyobj="force")}) + + def _write_tabular_arrow_table( arrow_table: pa.Table, schema: pa.Schema | None, diff --git a/requirements-dev.txt b/requirements-dev.txt index 4df6a1d95c..4ac4367d9b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -49,7 +49,6 @@ tenacity==8.2.3; python_version >= '3.8' # Delta Lake deltalake==0.5.8; platform_system == "Windows" deltalake==0.15.3; platform_system != "Windows" and python_version >= '3.8' -deltalake==0.13.0; platform_system != "Windows" and python_version < '3.8' # Databricks databricks-sdk==0.12.0 diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index 2f95d58c71..7c60a75597 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -32,6 +32,10 @@ pub struct DaftExecutionConfig { pub num_preview_rows: usize, pub parquet_target_filesize: usize, pub parquet_target_row_group_size: usize, + pub parquet_max_open_files: usize, + pub 
parquet_max_rows_per_file: usize, + pub parquet_min_rows_per_group: usize, + pub parquet_max_rows_per_group: usize, pub parquet_inflation_factor: f64, pub csv_target_filesize: usize, pub csv_inflation_factor: f64, @@ -52,6 +56,10 @@ impl Default for DaftExecutionConfig { num_preview_rows: 8, parquet_target_filesize: 512 * 1024 * 1024, // 512MB parquet_target_row_group_size: 128 * 1024 * 1024, // 128MB + parquet_max_open_files: 1024, + parquet_max_rows_per_file: 10 * 1024 * 1024, + parquet_min_rows_per_group: 64 * 1024, + parquet_max_rows_per_group: 128 * 1024, parquet_inflation_factor: 3.0, csv_target_filesize: 512 * 1024 * 1024, // 512MB csv_inflation_factor: 0.5, diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index 22a374e660..fa9f644535 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -94,6 +94,10 @@ impl PyDaftExecutionConfig { num_preview_rows: Option, parquet_target_filesize: Option, parquet_target_row_group_size: Option, + parquet_max_open_files: Option, + parquet_max_rows_per_file: Option, + parquet_min_rows_per_group: Option, + parquet_max_rows_per_group: Option, parquet_inflation_factor: Option, csv_target_filesize: Option, csv_inflation_factor: Option, @@ -133,6 +137,18 @@ impl PyDaftExecutionConfig { if let Some(parquet_target_row_group_size) = parquet_target_row_group_size { config.parquet_target_row_group_size = parquet_target_row_group_size; } + if let Some(parquet_max_open_files) = parquet_max_open_files { + config.parquet_max_open_files = parquet_max_open_files + } + if let Some(parquet_max_rows_per_file) = parquet_max_rows_per_file { + config.parquet_max_rows_per_file = parquet_max_rows_per_file; + } + if let Some(parquet_min_rows_per_group) = parquet_min_rows_per_group { + config.parquet_min_rows_per_group = parquet_min_rows_per_group; + } + if let Some(parquet_max_rows_per_group) = parquet_max_rows_per_group { + config.parquet_max_rows_per_group = parquet_max_rows_per_group; + } if let Some(parquet_inflation_factor) = parquet_inflation_factor { config.parquet_inflation_factor = parquet_inflation_factor; } @@ -199,6 +215,26 @@ impl PyDaftExecutionConfig { Ok(self.config.parquet_target_row_group_size) } + #[getter] + fn get_parquet_max_open_files(&self) -> PyResult { + Ok(self.config.parquet_max_open_files) + } + + #[getter] + fn get_parquet_max_rows_per_file(&self) -> PyResult { + Ok(self.config.parquet_max_rows_per_file) + } + + #[getter] + fn get_parquet_min_rows_per_group(&self) -> PyResult { + Ok(self.config.parquet_min_rows_per_group) + } + + #[getter] + fn get_parquet_max_rows_per_group(&self) -> PyResult { + Ok(self.config.parquet_max_rows_per_group) + } + #[getter] fn get_parquet_inflation_factor(&self) -> PyResult { Ok(self.config.parquet_inflation_factor) diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index d548e32189..0565b88446 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -503,6 +503,34 @@ impl LogicalPlanBuilder { Ok(logical_plan.into()) } + #[cfg(feature = "python")] + #[allow(clippy::too_many_arguments)] + pub fn delta_write( + &self, + path: String, + columns_name: Vec, + mode: String, + current_version: i32, + large_dtypes: bool, + io_config: Option, + ) -> DaftResult { + use crate::sink_info::DeltaLakeCatalogInfo; + let sink_info = SinkInfo::CatalogInfo(CatalogInfo { + catalog: crate::sink_info::CatalogType::DeltaLake(DeltaLakeCatalogInfo { + path, + mode, + current_version, + large_dtypes, + io_config, + 
}), + catalog_columns: columns_name, + }); + + let logical_plan: LogicalPlan = + logical_ops::Sink::try_new(self.plan.clone(), sink_info.into())?.into(); + Ok(logical_plan.into()) + } + pub fn build(&self) -> Arc { self.plan.clone() } @@ -770,6 +798,29 @@ impl PyLogicalPlanBuilder { .into()) } + #[allow(clippy::too_many_arguments)] + pub fn delta_write( + &self, + path: String, + columns_name: Vec, + mode: String, + current_version: i32, + large_dtypes: bool, + io_config: Option, + ) -> PyResult { + Ok(self + .builder + .delta_write( + path, + columns_name, + mode, + current_version, + large_dtypes, + io_config.map(|cfg| cfg.config), + )? + .into()) + } + pub fn schema(&self) -> PyResult { Ok(self.builder.schema().into()) } diff --git a/src/daft-plan/src/logical_ops/sink.rs b/src/daft-plan/src/logical_ops/sink.rs index af38c690d4..b9e347bfa5 100644 --- a/src/daft-plan/src/logical_ops/sink.rs +++ b/src/daft-plan/src/logical_ops/sink.rs @@ -61,6 +61,10 @@ impl Sink { res.push(format!("Sink: Iceberg({})", iceberg_info.table_name)); res.extend(iceberg_info.multiline_display()); } + crate::sink_info::CatalogType::DeltaLake(deltalake_info) => { + res.push(format!("Sink: DeltaLake({})", deltalake_info.path)); + res.extend(deltalake_info.multiline_display()); + } }, } res.push(format!("Output schema = {}", self.schema.short_string())); diff --git a/src/daft-plan/src/physical_ops/deltalake_write.rs b/src/daft-plan/src/physical_ops/deltalake_write.rs new file mode 100644 index 0000000000..916f562ee4 --- /dev/null +++ b/src/daft-plan/src/physical_ops/deltalake_write.rs @@ -0,0 +1,33 @@ +use daft_core::schema::SchemaRef; + +use crate::{physical_plan::PhysicalPlanRef, sink_info::DeltaLakeCatalogInfo}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct DeltaLakeWrite { + pub schema: SchemaRef, + pub delta_lake_info: DeltaLakeCatalogInfo, + // Upstream node. + pub input: PhysicalPlanRef, +} + +impl DeltaLakeWrite { + pub(crate) fn new( + schema: SchemaRef, + delta_lake_info: DeltaLakeCatalogInfo, + input: PhysicalPlanRef, + ) -> Self { + Self { + schema, + delta_lake_info, + input, + } + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![]; + res.push("DeltaLakeWrite:".to_string()); + res.extend(self.delta_lake_info.multiline_display()); + res + } +} diff --git a/src/daft-plan/src/physical_ops/mod.rs b/src/daft-plan/src/physical_ops/mod.rs index beafaf94b6..abffcaacbf 100644 --- a/src/daft-plan/src/physical_ops/mod.rs +++ b/src/daft-plan/src/physical_ops/mod.rs @@ -3,6 +3,8 @@ mod broadcast_join; mod coalesce; mod concat; mod csv; +#[cfg(feature = "python")] +mod deltalake_write; mod empty_scan; mod explode; mod fanout; @@ -31,6 +33,8 @@ pub use broadcast_join::BroadcastJoin; pub use coalesce::Coalesce; pub use concat::Concat; pub use csv::TabularWriteCsv; +#[cfg(feature = "python")] +pub use deltalake_write::DeltaLakeWrite; pub use empty_scan::EmptyScan; pub use explode::Explode; pub use fanout::{FanoutByHash, FanoutByRange, FanoutRandom}; diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 9c597b4203..acc074556b 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -32,6 +32,9 @@ use crate::{ #[cfg(feature = "python")] use crate::sink_info::IcebergCatalogInfo; +#[cfg(feature = "python")] +use crate::sink_info::DeltaLakeCatalogInfo; + pub(crate) type PhysicalPlanRef = Arc; /// Physical plan for a Daft query. 
@@ -67,6 +70,8 @@ pub enum PhysicalPlan { TabularWriteCsv(TabularWriteCsv), #[cfg(feature = "python")] IcebergWrite(IcebergWrite), + #[cfg(feature = "python")] + DeltaLakeWrite(DeltaLakeWrite), } pub struct ApproxStats { @@ -258,6 +263,10 @@ impl PhysicalPlan { Self::IcebergWrite(..) => { ClusteringSpec::Unknown(UnknownClusteringConfig::new(1)).into() } + #[cfg(feature = "python")] + Self::DeltaLakeWrite(DeltaLakeWrite { .. }) => { + ClusteringSpec::Unknown(UnknownClusteringConfig::new(1)).into() + } } } @@ -423,6 +432,8 @@ impl PhysicalPlan { } #[cfg(feature = "python")] Self::IcebergWrite(_) => ApproxStats::empty(), + #[cfg(feature = "python")] + Self::DeltaLakeWrite(_) => ApproxStats::empty(), } } @@ -451,6 +462,8 @@ impl PhysicalPlan { Self::TabularWriteJson(TabularWriteJson { input, .. }) => vec![input.clone()], #[cfg(feature = "python")] Self::IcebergWrite(IcebergWrite { input, .. }) => vec![input.clone()], + #[cfg(feature = "python")] + Self::DeltaLakeWrite(DeltaLakeWrite { input, .. }) => vec![input.clone()], Self::HashJoin(HashJoin { left, right, .. }) => vec![left.clone(), right.clone()], Self::BroadcastJoin(BroadcastJoin { broadcaster, @@ -496,6 +509,8 @@ impl PhysicalPlan { Self::TabularWriteJson(TabularWriteJson { schema, file_info, .. }) => Self::TabularWriteJson(TabularWriteJson::new(schema.clone(), file_info.clone(), input.clone())), #[cfg(feature = "python")] Self::IcebergWrite(IcebergWrite { schema, iceberg_info, .. }) => Self::IcebergWrite(IcebergWrite::new(schema.clone(), iceberg_info.clone(), input.clone())), + #[cfg(feature = "python")] + Self::DeltaLakeWrite(DeltaLakeWrite {schema, delta_lake_info, .. }) => Self::DeltaLakeWrite(DeltaLakeWrite::new(schema.clone(), delta_lake_info.clone(), input.clone())), _ => panic!("Physical op {:?} has two inputs, but got one", self), }, [input1, input2] => match self { @@ -550,6 +565,8 @@ impl PhysicalPlan { Self::MonotonicallyIncreasingId(..) => "MonotonicallyIncreasingId", #[cfg(feature = "python")] Self::IcebergWrite(..) => "IcebergWrite", + #[cfg(feature = "python")] + Self::DeltaLakeWrite(..) => "DeltaLakeWrite", }; name.to_string() } @@ -589,6 +606,8 @@ impl PhysicalPlan { } #[cfg(feature = "python")] Self::IcebergWrite(iceberg_info) => iceberg_info.multiline_display(), + #[cfg(feature = "python")] + Self::DeltaLakeWrite(delta_lake_info) => delta_lake_info.multiline_display(), } } @@ -718,6 +737,31 @@ fn iceberg_write( Ok(py_iter.into()) } +#[allow(clippy::too_many_arguments)] +#[cfg(feature = "python")] +fn deltalake_write( + py: Python<'_>, + upstream_iter: PyObject, + delta_lake_info: &DeltaLakeCatalogInfo, +) -> PyResult { + let py_iter = py + .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? + .getattr(pyo3::intern!(py, "write_deltalake"))? 
+ .call1(( + upstream_iter, + &delta_lake_info.path, + delta_lake_info.large_dtypes, + delta_lake_info.current_version, + delta_lake_info + .io_config + .as_ref() + .map(|cfg| common_io_config::python::IOConfig { + config: cfg.clone(), + }), + ))?; + Ok(py_iter.into()) +} + #[cfg(feature = "python")] impl PhysicalPlan { pub fn to_partition_tasks( @@ -1212,6 +1256,12 @@ impl PhysicalPlan { iceberg_info, input, }) => iceberg_write(py, input.to_partition_tasks(py, psets)?, iceberg_info), + #[cfg(feature = "python")] + PhysicalPlan::DeltaLakeWrite(DeltaLakeWrite { + schema: _, + delta_lake_info, + input, + }) => deltalake_write(py, input.to_partition_tasks(py, psets)?, delta_lake_info), } } } diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index ae612bd3a5..12a4c5fb1a 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -704,6 +704,14 @@ pub(super) fn translate_single_logical_node( )) .arced()) } + crate::sink_info::CatalogType::DeltaLake(deltalake_info) => { + Ok(PhysicalPlan::DeltaLakeWrite(DeltaLakeWrite::new( + schema.clone(), + deltalake_info.clone(), + input_physical, + )) + .arced()) + } }, } } diff --git a/src/daft-plan/src/sink_info.rs b/src/daft-plan/src/sink_info.rs index 02bcb2207f..e3db25ed7f 100644 --- a/src/daft-plan/src/sink_info.rs +++ b/src/daft-plan/src/sink_info.rs @@ -41,6 +41,7 @@ pub struct CatalogInfo { #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum CatalogType { Iceberg(IcebergCatalogInfo), + DeltaLake(DeltaLakeCatalogInfo), } #[cfg(feature = "python")] @@ -100,6 +101,57 @@ impl IcebergCatalogInfo { } } +#[cfg(feature = "python")] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeltaLakeCatalogInfo { + pub path: String, + pub mode: String, + pub current_version: i32, + pub large_dtypes: bool, + pub io_config: Option, +} + +#[cfg(feature = "python")] +impl PartialEq for DeltaLakeCatalogInfo { + fn eq(&self, other: &Self) -> bool { + self.path == other.path + && self.mode == other.mode + && self.current_version == other.current_version + && self.large_dtypes == other.large_dtypes + && self.io_config == other.io_config + } +} + +#[cfg(feature = "python")] +impl Eq for DeltaLakeCatalogInfo {} + +#[cfg(feature = "python")] +impl Hash for DeltaLakeCatalogInfo { + fn hash(&self, state: &mut H) { + self.path.hash(state); + self.mode.hash(state); + self.current_version.hash(state); + self.large_dtypes.hash(state); + self.io_config.hash(state); + } +} + +#[cfg(feature = "python")] +impl DeltaLakeCatalogInfo { + pub fn multiline_display(&self) -> Vec { + let mut res = vec![]; + res.push(format!("Table Name = {}", self.path)); + res.push(format!("Mode = {}", self.mode)); + res.push(format!("Current Version = {}", self.current_version)); + res.push(format!("Large Dtypes = {}", self.large_dtypes)); + match &self.io_config { + None => res.push("IOConfig = None".to_string()), + Some(io_config) => res.push(format!("IOConfig = {}", io_config)), + }; + res + } +} + impl OutputFileInfo { pub fn new( root_dir: String, diff --git a/tests/io/delta_lake/conftest.py b/tests/io/delta_lake/conftest.py index a137d99bb4..bd2dd735a1 100644 --- a/tests/io/delta_lake/conftest.py +++ b/tests/io/delta_lake/conftest.py @@ -22,8 +22,6 @@ from daft.io.object_store_options import io_config_to_storage_options from tests.io.delta_lake.mock_aws_server import start_service, stop_process -deltalake = pytest.importorskip("deltalake") - 
@pytest.fixture(params=[1, 2, 8]) def num_partitions(request) -> int: @@ -418,11 +416,16 @@ def local_path(tmp_path: pathlib.Path, data_dir: str) -> tuple[str, None, None]: pytest.param(lazy_fixture("az_path"), marks=(pytest.mark.az, pytest.mark.integration)), ], ) +def cloud_paths(request) -> tuple[str, daft.io.IOConfig | None, DataCatalogTable | None]: + return request.param + + +@pytest.fixture(scope="function") def deltalake_table( - request, base_table: pa.Table, num_partitions: int, partition_generator: callable + cloud_paths, base_table: pa.Table, num_partitions: int, partition_generator: callable ) -> tuple[str, daft.io.IOConfig | None, dict[str, str], list[pa.Table]]: partition_generator, _ = partition_generator - path, io_config, catalog_table = request.param + path, io_config, catalog_table = cloud_paths storage_options = io_config_to_storage_options(io_config, path) if io_config is not None else None parts = [] for i in range(num_partitions): @@ -431,6 +434,7 @@ def deltalake_table( part = base_table.append_column("part_idx", pa.array([part_value if part_value is not None else i] * 3)) parts.append(part) table = pa.concat_tables(parts) + deltalake = pytest.importorskip("deltalake") deltalake.write_deltalake( path, table, diff --git a/tests/io/delta_lake/test_table_read.py b/tests/io/delta_lake/test_table_read.py index edc6560b3f..4797d1376f 100644 --- a/tests/io/delta_lake/test_table_read.py +++ b/tests/io/delta_lake/test_table_read.py @@ -1,16 +1,13 @@ from __future__ import annotations import contextlib - -import pytest - -from daft.io.object_store_options import io_config_to_storage_options - -deltalake = pytest.importorskip("deltalake") +import sys import pyarrow as pa +import pytest import daft +from daft.io.object_store_options import io_config_to_storage_options from daft.logical.schema import Schema from tests.utils import assert_pyarrow_tables_equal @@ -28,10 +25,14 @@ def split_small_pq_files(): PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0) -pytestmark = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="deltalake only supported if pyarrow >= 8.0.0") +PYTHON_LT_3_8 = sys.version_info[:2] < (3, 8) +pytestmark = pytest.mark.skipif( + PYARROW_LE_8_0_0 or PYTHON_LT_3_8, reason="deltalake only supported if pyarrow >= 8.0.0 and python >= 3.8" +) def test_deltalake_read_basic(tmp_path, base_table): + deltalake = pytest.importorskip("deltalake") path = tmp_path / "some_table" deltalake.write_deltalake(path, base_table) df = daft.read_delta_lake(str(path)) @@ -41,6 +42,7 @@ def test_deltalake_read_basic(tmp_path, base_table): def test_deltalake_read_full(deltalake_table): + deltalake = pytest.importorskip("deltalake") path, catalog_table, io_config, parts = deltalake_table df = daft.read_delta_lake(str(path) if catalog_table is None else catalog_table, io_config=io_config) delta_schema = deltalake.DeltaTable(path, storage_options=io_config_to_storage_options(io_config, path)).schema() @@ -56,6 +58,7 @@ def test_deltalake_read_show(deltalake_table): def test_deltalake_read_row_group_splits(tmp_path, base_table): + deltalake = pytest.importorskip("deltalake") path = tmp_path / "some_table" # Force 2 rowgroups @@ -69,6 +72,7 @@ def test_deltalake_read_row_group_splits(tmp_path, base_table): def test_deltalake_read_row_group_splits_with_filter(tmp_path, base_table): + deltalake = pytest.importorskip("deltalake") path = tmp_path / "some_table" # Force 2 rowgroups @@ -83,6 +87,7 @@ def 
test_deltalake_read_row_group_splits_with_filter(tmp_path, base_table): def test_deltalake_read_row_group_splits_with_limit(tmp_path, base_table): + deltalake = pytest.importorskip("deltalake") path = tmp_path / "some_table" # Force 2 rowgroups diff --git a/tests/io/delta_lake/test_table_read_pushdowns.py b/tests/io/delta_lake/test_table_read_pushdowns.py index 81d543eadd..dbfdb8e7a4 100644 --- a/tests/io/delta_lake/test_table_read_pushdowns.py +++ b/tests/io/delta_lake/test_table_read_pushdowns.py @@ -7,19 +7,25 @@ from daft.io.object_store_options import io_config_to_storage_options deltalake = pytest.importorskip("deltalake") +import sys import pyarrow as pa import pyarrow.compute as pc +import pytest import daft from daft.logical.schema import Schema from tests.utils import assert_pyarrow_tables_equal PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0) -pytestmark = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="deltalake only supported if pyarrow >= 8.0.0") +PYTHON_LT_3_8 = sys.version_info[:2] < (3, 8) +pytestmark = pytest.mark.skipif( + PYARROW_LE_8_0_0 or PYTHON_LT_3_8, reason="deltalake only supported if pyarrow >= 8.0.0 and python >= 3.8" +) def test_read_predicate_pushdown_on_data(deltalake_table): + deltalake = pytest.importorskip("deltalake") path, catalog_table, io_config, tables = deltalake_table df = daft.read_delta_lake(str(path) if catalog_table is None else catalog_table, io_config=io_config) df = df.where(df["a"] == 2) @@ -32,6 +38,7 @@ def test_read_predicate_pushdown_on_data(deltalake_table): def test_read_predicate_pushdown_on_part(deltalake_table, partition_generator): + deltalake = pytest.importorskip("deltalake") path, catalog_table, io_config, tables = deltalake_table df = daft.read_delta_lake(str(path) if catalog_table is None else catalog_table, io_config=io_config) part_idx = 2 @@ -50,6 +57,7 @@ def test_read_predicate_pushdown_on_part(deltalake_table, partition_generator): def test_read_predicate_pushdown_on_part_non_eq(deltalake_table, partition_generator): + deltalake = pytest.importorskip("deltalake") path, catalog_table, io_config, tables = deltalake_table df = daft.read_delta_lake(str(path) if catalog_table is None else catalog_table, io_config=io_config) part_idx = 3 @@ -68,6 +76,7 @@ def test_read_predicate_pushdown_on_part_non_eq(deltalake_table, partition_gener def test_read_predicate_pushdown_on_part_and_data(deltalake_table, partition_generator): + deltalake = pytest.importorskip("deltalake") path, catalog_table, io_config, tables = deltalake_table df = daft.read_delta_lake(str(path) if catalog_table is None else catalog_table, io_config=io_config) part_idx = 2 @@ -91,6 +100,7 @@ def test_read_predicate_pushdown_on_part_and_data(deltalake_table, partition_gen def test_read_predicate_pushdown_on_part_and_data_same_clause(deltalake_table, partition_generator): + deltalake = pytest.importorskip("deltalake") path, catalog_table, io_config, tables = deltalake_table df = daft.read_delta_lake(str(path) if catalog_table is None else catalog_table, io_config=io_config) partition_generator, col = partition_generator @@ -105,6 +115,7 @@ def test_read_predicate_pushdown_on_part_and_data_same_clause(deltalake_table, p def test_read_predicate_pushdown_on_part_empty(deltalake_table, partition_generator, num_partitions): + deltalake = pytest.importorskip("deltalake") partition_generator, _ = partition_generator path, catalog_table, io_config, tables = deltalake_table df = daft.read_delta_lake(str(path) if catalog_table is 
None else catalog_table, io_config=io_config) diff --git a/tests/io/delta_lake/test_table_write.py b/tests/io/delta_lake/test_table_write.py new file mode 100644 index 0000000000..3133469745 --- /dev/null +++ b/tests/io/delta_lake/test_table_write.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import contextlib +import sys + +import pyarrow as pa +import pytest + +import daft +from daft.io.object_store_options import io_config_to_storage_options +from daft.logical.schema import Schema + + +@contextlib.contextmanager +def split_small_pq_files(): + old_config = daft.context.get_context().daft_execution_config + daft.set_execution_config( + # Splits any parquet files >100 bytes in size + scan_tasks_min_size_bytes=1, + scan_tasks_max_size_bytes=100, + ) + yield + daft.set_execution_config(config=old_config) + + +PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0) +PYTHON_LT_3_8 = sys.version_info[:2] < (3, 8) +pytestmark = pytest.mark.skipif( + PYARROW_LE_8_0_0 or PYTHON_LT_3_8, reason="deltalake only supported if pyarrow >= 8.0.0 and python >= 3.8" +) + + +def test_deltalake_write_basic(tmp_path, base_table): + deltalake = pytest.importorskip("deltalake") + path = tmp_path / "some_table" + df = daft.from_arrow(base_table) + df.write_delta(str(path)) + read_delta = deltalake.DeltaTable(str(path)) + expected_schema = Schema.from_pyarrow_schema(read_delta.schema().to_pyarrow()) + assert df.schema() == expected_schema + assert read_delta.to_pyarrow_table() == base_table + + +def test_deltalake_multi_write_basic(tmp_path, base_table): + deltalake = pytest.importorskip("deltalake") + path = tmp_path / "some_table" + df = daft.from_arrow(base_table) + df.write_delta(str(path)) + df.write_delta(str(path)) + read_delta = deltalake.DeltaTable(str(path)) + expected_schema = Schema.from_pyarrow_schema(read_delta.schema().to_pyarrow()) + assert df.schema() == expected_schema + assert read_delta.version() == 1 + assert read_delta.to_pyarrow_table() == pa.concat_tables([base_table, base_table]) + + +def test_deltalake_write_cloud(base_table, cloud_paths): + deltalake = pytest.importorskip("deltalake") + path, io_config, catalog_table = cloud_paths + df = daft.from_arrow(base_table) + df.write_delta(str(path), io_config=io_config) + storage_options = io_config_to_storage_options(io_config, path) if io_config is not None else None + read_delta = deltalake.DeltaTable(str(path), storage_options=storage_options) + expected_schema = Schema.from_pyarrow_schema(read_delta.schema().to_pyarrow()) + assert df.schema() == expected_schema + assert read_delta.to_pyarrow_table() == base_table
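Usage sketch: a minimal end-to-end example of the non-public DataFrame.write_delta API added by this patch, assuming deltalake>=0.14.0 is installed (per the version check in write_delta); the local path and sample data below are illustrative and not taken from the diff.

import daft
from deltalake import DeltaTable

# Hypothetical local table location; any supported path plus a matching IOConfig also works.
table_path = "/tmp/some_delta_table"

df = daft.from_pydict({"a": [1, 2, 3], "part_idx": [1, 1, 2]})

# The first write creates the table (current_version starts at -1 internally);
# the second appends and bumps the Delta table version to 1. Only "append" mode is supported.
df.write_delta(table_path, mode="append")
df.write_delta(table_path, mode="append")

read_back = DeltaTable(table_path)
assert read_back.version() == 1
assert read_back.to_pyarrow_table().num_rows == 6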