diff --git a/benchmarking/parquet/conftest.py b/benchmarking/parquet/conftest.py
index 54d16ef52c..24b3a92017 100644
--- a/benchmarking/parquet/conftest.py
+++ b/benchmarking/parquet/conftest.py
@@ -80,6 +80,10 @@ def daft_bulk_read(paths: list[str], columns: list[str] | None = None) -> list[p
     return [t.to_arrow() for t in tables]
 
 
+def daft_into_pyarrow_bulk_read(paths: list[str], columns: list[str] | None = None) -> list[pa.Table]:
+    return daft.table.read_parquet_into_pyarrow_bulk(paths, columns=columns)
+
+
 def pyarrow_bulk_read(paths: list[str], columns: list[str] | None = None) -> list[pa.Table]:
     return [pyarrow_read(f, columns=columns) for f in paths]
 
@@ -91,11 +95,13 @@ def boto_bulk_read(paths: list[str], columns: list[str] | None = None) -> list[p
 @pytest.fixture(
     params=[
         daft_bulk_read,
+        daft_into_pyarrow_bulk_read,
         pyarrow_bulk_read,
         boto_bulk_read,
     ],
     ids=[
         "daft_bulk_read",
+        "daft_into_pyarrow_bulk_read",
         "pyarrow_bulk_read",
         "boto3_bulk_read",
     ],
diff --git a/daft/daft.pyi b/daft/daft.pyi
index e96b2eaa6f..db16d78a06 100644
--- a/daft/daft.pyi
+++ b/daft/daft.pyi
@@ -391,6 +391,16 @@ def read_parquet_into_pyarrow(
     multithreaded_io: bool | None = None,
     coerce_int96_timestamp_unit: PyTimeUnit | None = None,
 ): ...
+def read_parquet_into_pyarrow_bulk(
+    uris: list[str],
+    columns: list[str] | None = None,
+    start_offset: int | None = None,
+    num_rows: int | None = None,
+    row_groups: list[list[int]] | None = None,
+    io_config: IOConfig | None = None,
+    multithreaded_io: bool | None = None,
+    coerce_int96_timestamp_unit: PyTimeUnit | None = None,
+): ...
 def read_parquet_schema(
     uri: str,
     io_config: IOConfig | None = None,
diff --git a/daft/table/__init__.py b/daft/table/__init__.py
index e4a55d9e8d..f2fb68f1d8 100644
--- a/daft/table/__init__.py
+++ b/daft/table/__init__.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
 
-from .table import Table, read_parquet_into_pyarrow
+from .table import Table, read_parquet_into_pyarrow, read_parquet_into_pyarrow_bulk
 
-__all__ = ["Table", "read_parquet_into_pyarrow"]
+__all__ = ["Table", "read_parquet_into_pyarrow", "read_parquet_into_pyarrow_bulk"]
diff --git a/daft/table/table.py b/daft/table/table.py
index 7a0b0831af..5893d63b39 100644
--- a/daft/table/table.py
+++ b/daft/table/table.py
@@ -12,6 +12,7 @@
 from daft.daft import read_parquet as _read_parquet
 from daft.daft import read_parquet_bulk as _read_parquet_bulk
 from daft.daft import read_parquet_into_pyarrow as _read_parquet_into_pyarrow
+from daft.daft import read_parquet_into_pyarrow_bulk as _read_parquet_into_pyarrow_bulk
 from daft.daft import read_parquet_statistics as _read_parquet_statistics
 from daft.datatype import DataType, TimeUnit
 from daft.expressions import Expression, ExpressionsProjection
@@ -476,3 +477,29 @@ def read_parquet_into_pyarrow(
     schema = pa.schema(fields, metadata=metadata)
     columns = [pa.chunked_array(c) for c in columns]  # type: ignore
     return pa.table(columns, schema=schema)
+
+
+def read_parquet_into_pyarrow_bulk(
+    paths: list[str],
+    columns: list[str] | None = None,
+    start_offset: int | None = None,
+    num_rows: int | None = None,
+    row_groups_per_path: list[list[int]] | None = None,
+    io_config: IOConfig | None = None,
+    multithreaded_io: bool | None = None,
+    coerce_int96_timestamp_unit: TimeUnit = TimeUnit.ns(),
+) -> list[pa.Table]:
+    bulk_result = _read_parquet_into_pyarrow_bulk(
+        uris=paths,
+        columns=columns,
+        start_offset=start_offset,
+        num_rows=num_rows,
+        row_groups=row_groups_per_path,
+        io_config=io_config,
+        multithreaded_io=multithreaded_io,
+        coerce_int96_timestamp_unit=coerce_int96_timestamp_unit._timeunit,
+    )
+    return [
+        pa.table([pa.chunked_array(c) for c in columns], schema=pa.schema(fields, metadata=metadata))  # type: ignore
+        for fields, metadata, columns in bulk_result
+    ]
diff --git a/src/daft-parquet/src/python.rs b/src/daft-parquet/src/python.rs
index a30bf80373..f0491b872a 100644
--- a/src/daft-parquet/src/python.rs
+++ b/src/daft-parquet/src/python.rs
@@ -7,10 +7,10 @@ pub mod pylib {
     };
     use daft_io::{get_io_client, python::IOConfig};
     use daft_table::python::PyTable;
-    use pyo3::{pyfunction, PyResult, Python};
+    use pyo3::{pyfunction, types::PyModule, PyResult, Python};
     use std::{collections::BTreeMap, sync::Arc};
-    use crate::read::ParquetSchemaInferenceOptions;
+    use crate::read::{ArrowChunk, ParquetSchemaInferenceOptions};
     use daft_core::ffi::to_py_array;
 
     #[allow(clippy::too_many_arguments)]
     #[pyfunction]
@@ -48,6 +48,29 @@ pub mod pylib {
     }
     type PyArrowChunks = Vec<Vec<pyo3::PyObject>>;
     type PyArrowFields = Vec<pyo3::PyObject>;
+    type PyArrowParquetType = (PyArrowFields, BTreeMap<String, String>, PyArrowChunks);
+    fn convert_pyarrow_parquet_read_result_into_py(
+        py: Python,
+        schema: arrow2::datatypes::SchemaRef,
+        all_arrays: Vec<ArrowChunk>,
+        pyarrow: &PyModule,
+    ) -> PyResult<PyArrowParquetType> {
+        let converted_arrays = all_arrays
+            .into_iter()
+            .map(|v| {
+                v.into_iter()
+                    .map(|a| to_py_array(a, py, pyarrow))
+                    .collect::<PyResult<Vec<_>>>()
+            })
+            .collect::<PyResult<Vec<_>>>()?;
+        let fields = schema
+            .fields
+            .iter()
+            .map(|f| field_to_py(f, py, pyarrow))
+            .collect::<Result<Vec<_>, _>>()?;
+        let metadata = &schema.metadata;
+        Ok((fields, metadata.clone(), converted_arrays))
+    }
 
     #[allow(clippy::too_many_arguments)]
     #[pyfunction]
@@ -61,7 +84,7 @@ pub mod pylib {
         io_config: Option<IOConfig>,
         multithreaded_io: Option<bool>,
         coerce_int96_timestamp_unit: Option<PyTimeUnit>,
-    ) -> PyResult<(PyArrowFields, BTreeMap<String, String>, PyArrowChunks)> {
+    ) -> PyResult<PyArrowParquetType> {
         let read_parquet_result = py.allow_threads(|| {
             let io_client = get_io_client(
                 multithreaded_io.unwrap_or(true),
@@ -83,24 +106,8 @@ pub mod pylib {
         })?;
         let (schema, all_arrays) = read_parquet_result;
         let pyarrow = py.import("pyarrow")?;
-        let converted_arrays = all_arrays
-            .into_iter()
-            .map(|v| {
-                v.into_iter()
-                    .map(|a| to_py_array(a, py, pyarrow))
-                    .collect::<PyResult<Vec<_>>>()
-            })
-            .collect::<PyResult<Vec<_>>>()?;
-        let fields = schema
-            .fields
-            .iter()
-            .map(|f| field_to_py(f, py, pyarrow))
-            .collect::<Result<Vec<_>, _>>()?;
-        let metadata = &schema.metadata;
-
-        Ok((fields, metadata.clone(), converted_arrays))
+        convert_pyarrow_parquet_read_result_into_py(py, schema, all_arrays, pyarrow)
     }
-
     #[allow(clippy::too_many_arguments)]
     #[pyfunction]
     pub fn read_parquet_bulk(
@@ -139,6 +146,48 @@ pub mod pylib {
         })
     }
 
+    #[allow(clippy::too_many_arguments)]
+    #[pyfunction]
+    pub fn read_parquet_into_pyarrow_bulk(
+        py: Python,
+        uris: Vec<&str>,
+        columns: Option<Vec<&str>>,
+        start_offset: Option<usize>,
+        num_rows: Option<usize>,
+        row_groups: Option<Vec<Vec<i64>>>,
+        io_config: Option<IOConfig>,
+        multithreaded_io: Option<bool>,
+        coerce_int96_timestamp_unit: Option<PyTimeUnit>,
+    ) -> PyResult<Vec<PyArrowParquetType>> {
+        let parquet_read_results = py.allow_threads(|| {
+            let io_client = get_io_client(
+                multithreaded_io.unwrap_or(true),
+                io_config.unwrap_or_default().config.into(),
+            )?;
+            let schema_infer_options = ParquetSchemaInferenceOptions::new(
+                coerce_int96_timestamp_unit.map(|tu| tu.timeunit),
+            );
+
+            crate::read::read_parquet_into_pyarrow_bulk(
+                uris.as_ref(),
+                columns.as_deref(),
+                start_offset,
+                num_rows,
+                row_groups,
+                io_client,
+                multithreaded_io.unwrap_or(true),
+                &schema_infer_options,
+            )
+        })?;
+        let pyarrow = py.import("pyarrow")?;
+        parquet_read_results
+            .into_iter()
+            .map(|(s, all_arrays)| {
+                convert_pyarrow_parquet_read_result_into_py(py, s, all_arrays, pyarrow)
+            })
+            .collect::<PyResult<Vec<_>>>()
+    }
+
     #[pyfunction]
     pub fn read_parquet_schema(
         py: Python,
@@ -183,7 +232,7 @@ pub mod pylib {
 pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
     parent.add_wrapped(wrap_pyfunction!(pylib::read_parquet))?;
     parent.add_wrapped(wrap_pyfunction!(pylib::read_parquet_into_pyarrow))?;
-
+    parent.add_wrapped(wrap_pyfunction!(pylib::read_parquet_into_pyarrow_bulk))?;
     parent.add_wrapped(wrap_pyfunction!(pylib::read_parquet_bulk))?;
     parent.add_wrapped(wrap_pyfunction!(pylib::read_parquet_schema))?;
     parent.add_wrapped(wrap_pyfunction!(pylib::read_parquet_statistics))?;
diff --git a/src/daft-parquet/src/read.rs b/src/daft-parquet/src/read.rs
index fee1172771..249abaff36 100644
--- a/src/daft-parquet/src/read.rs
+++ b/src/daft-parquet/src/read.rs
@@ -285,7 +285,7 @@ pub fn read_parquet(
     })
 }
 pub type ArrowChunk = Vec<Box<dyn arrow2::array::Array>>;
-
+pub type ParquetPyarrowChunk = (arrow2::datatypes::SchemaRef, Vec<ArrowChunk>);
 #[allow(clippy::too_many_arguments)]
 pub fn read_parquet_into_pyarrow(
     uri: &str,
@@ -296,7 +296,7 @@ pub fn read_parquet_into_pyarrow(
     io_client: Arc<IOClient>,
     multithreaded_io: bool,
     schema_infer_options: &ParquetSchemaInferenceOptions,
-) -> DaftResult<(arrow2::datatypes::SchemaRef, Vec<ArrowChunk>)> {
+) -> DaftResult<ParquetPyarrowChunk> {
     let runtime_handle = get_runtime(multithreaded_io)?;
     let _rt_guard = runtime_handle.enter();
     runtime_handle.block_on(async {
@@ -370,6 +370,63 @@ pub fn read_parquet_bulk(
     tables.into_iter().collect::<DaftResult<Vec<_>>>()
 }
 
+#[allow(clippy::too_many_arguments)]
+pub fn read_parquet_into_pyarrow_bulk(
+    uris: &[&str],
+    columns: Option<&[&str]>,
+    start_offset: Option<usize>,
+    num_rows: Option<usize>,
+    row_groups: Option<Vec<Vec<i64>>>,
+    io_client: Arc<IOClient>,
+    multithreaded_io: bool,
+    schema_infer_options: &ParquetSchemaInferenceOptions,
+) -> DaftResult<Vec<ParquetPyarrowChunk>> {
+    let runtime_handle = get_runtime(multithreaded_io)?;
+    let _rt_guard = runtime_handle.enter();
+    let owned_columns = columns.map(|s| s.iter().map(|v| String::from(*v)).collect::<Vec<_>>());
+    if let Some(ref row_groups) = row_groups {
+        if row_groups.len() != uris.len() {
+            return Err(common_error::DaftError::ValueError(format!(
+                "Mismatch of length of `uris` and `row_groups`. {} vs {}",
+                uris.len(),
+                row_groups.len()
+            )));
+        }
+    }
+    let tables = runtime_handle
+        .block_on(async move {
+            try_join_all(uris.iter().enumerate().map(|(i, uri)| {
+                let uri = uri.to_string();
+                let owned_columns = owned_columns.clone();
+                let owned_row_group = match &row_groups {
+                    None => None,
+                    Some(v) => v.get(i).cloned(),
+                };
+
+                let io_client = io_client.clone();
+                let schema_infer_options = schema_infer_options.clone();
+                tokio::task::spawn(async move {
+                    let columns = owned_columns
+                        .as_ref()
+                        .map(|s| s.iter().map(AsRef::as_ref).collect::<Vec<_>>());
+                    read_parquet_single_into_arrow(
+                        &uri,
+                        columns.as_deref(),
+                        start_offset,
+                        num_rows,
+                        owned_row_group.as_deref(),
+                        io_client,
+                        &schema_infer_options,
+                    )
+                    .await
+                })
+            }))
+            .await
+        })
+        .context(JoinSnafu { path: "UNKNOWN" })?;
+    tables.into_iter().collect::<DaftResult<Vec<_>>>()
+}
+
 pub fn read_parquet_schema(
     uri: &str,
     io_client: Arc<IOClient>,
diff --git a/tests/integration/io/parquet/test_reads_public_data.py b/tests/integration/io/parquet/test_reads_public_data.py
index ee6ba8698a..e47a83b2bd 100644
--- a/tests/integration/io/parquet/test_reads_public_data.py
+++ b/tests/integration/io/parquet/test_reads_public_data.py
@@ -250,6 +250,26 @@ def test_parquet_read_table_bulk(parquet_file, public_storage_io_config, multith
     pd.testing.assert_frame_equal(daft_native_read.to_pandas(), pa_read.to_pandas())
 
 
+@pytest.mark.integration()
+@pytest.mark.skipif(
+    daft.context.get_context().use_rust_planner, reason="Custom fsspec filesystems not supported in new query planner"
+)
+@pytest.mark.parametrize(
+    "multithreaded_io",
+    [False, True],
+)
+def test_parquet_into_pyarrow_bulk(parquet_file, public_storage_io_config, multithreaded_io):
+    _, url = parquet_file
+    daft_native_reads = daft.table.read_parquet_into_pyarrow_bulk(
+        [url] * 2, io_config=public_storage_io_config, multithreaded_io=multithreaded_io
+    )
+    pa_read = read_parquet_with_pyarrow(url)
+
+    for daft_native_read in daft_native_reads:
+        assert daft_native_read.schema == pa_read.schema
+        pd.testing.assert_frame_equal(daft_native_read.to_pandas(), pa_read.to_pandas())
+
+
 @pytest.mark.integration()
 def test_parquet_read_df(parquet_file, public_storage_io_config):
     _, url = parquet_file
@@ -321,3 +341,22 @@ def test_row_groups_selection_bulk(public_storage_io_config, multithreaded_io):
     for i, t in enumerate(rest):
         assert len(t) == 10
         assert first.to_arrow()[i * 10 : (i + 1) * 10] == t.to_arrow()
+
+
+@pytest.mark.integration()
+@pytest.mark.parametrize(
+    "multithreaded_io",
+    [False, True],
+)
+def test_row_groups_selection_into_pyarrow_bulk(public_storage_io_config, multithreaded_io):
+    url = ["s3://daft-public-data/test_fixtures/parquet-dev/mvp.parquet"] * 11
+    row_groups = [list(range(10))] + [[i] for i in range(10)]
+    first, *rest = daft.table.read_parquet_into_pyarrow_bulk(
+        url, io_config=public_storage_io_config, multithreaded_io=multithreaded_io, row_groups_per_path=row_groups
+    )
+    assert len(first) == 100
+    assert len(rest) == 10
+
+    for i, t in enumerate(rest):
+        assert len(t) == 10
+        assert first[i * 10 : (i + 1) * 10] == t
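
Usage sketch (illustrative only, not part of the patch): a minimal example of calling the new bulk entrypoint that this diff adds in daft/table/table.py and exports from daft.table. The S3 path is the public fixture already used by the integration tests above; unlike those tests, no io_config is passed here, so treat credential handling and network access as assumptions rather than guarantees.

# Hedged sketch of the new bulk pyarrow reader added by this diff; not part of the patch itself.
import daft.table

# Public test fixture referenced by the integration tests above.
paths = ["s3://daft-public-data/test_fixtures/parquet-dev/mvp.parquet"] * 2

# One pyarrow.Table per input path, fetched in a single bulk call.
# `columns=`, `io_config=`, and `multithreaded_io=` are also accepted,
# mirroring the signature introduced in daft/table/table.py.
tables = daft.table.read_parquet_into_pyarrow_bulk(paths)

# `row_groups_per_path` takes one list of row-group indices per path,
# as exercised by test_row_groups_selection_into_pyarrow_bulk above.
first_row_group_only = daft.table.read_parquet_into_pyarrow_bulk(paths, row_groups_per_path=[[0], [0]])

for t in tables:
    print(t.schema, t.num_rows)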