From 3a6fb39e0d71ff7d27a58cc5078e2de195ef216f Mon Sep 17 00:00:00 2001
From: Sammy Sidhu
Date: Mon, 11 Sep 2023 23:02:33 -0600
Subject: [PATCH] pass arrow schema to pyarrow

---
 daft/table/table.py            |  6 ++++--
 src/daft-core/src/ffi.rs       | 16 ++++++++++++++++
 src/daft-parquet/src/file.rs   |  8 ++++----
 src/daft-parquet/src/python.rs | 17 ++++++++++++-----
 src/daft-parquet/src/read.rs   | 18 ++++++++----------
 5 files changed, 44 insertions(+), 21 deletions(-)

diff --git a/daft/table/table.py b/daft/table/table.py
index 4ac8b357dd..e3ee84a3eb 100644
--- a/daft/table/table.py
+++ b/daft/table/table.py
@@ -465,7 +465,7 @@ def read_parquet_into_pyarrow(
     coerce_int96_timestamp_unit: TimeUnit = TimeUnit.ns(),
 ) -> pa.Table:
 
-    names, columns = _read_parquet_into_pyarrow(
+    fields, columns = _read_parquet_into_pyarrow(
         uri=path,
         columns=columns,
         start_offset=start_offset,
@@ -475,5 +475,7 @@
         multithreaded_io=multithreaded_io,
         coerce_int96_timestamp_unit=coerce_int96_timestamp_unit._timeunit,
     )
+
+    schema = pa.schema(fields)
     columns = [pa.chunked_array(c) for c in columns]
-    return pa.table(columns, names=names)
+    return pa.table(columns, schema=schema)
diff --git a/src/daft-core/src/ffi.rs b/src/daft-core/src/ffi.rs
index c2e82ca988..b92e17e8c8 100644
--- a/src/daft-core/src/ffi.rs
+++ b/src/daft-core/src/ffi.rs
@@ -51,6 +51,22 @@ pub fn to_py_array(array: ArrayRef, py: Python, pyarrow: &PyModule) -> PyResult<PyObject> {
     Ok(array.to_object(py))
 }
 
+pub fn field_to_py(
+    field: &arrow2::datatypes::Field,
+    py: Python,
+    pyarrow: &PyModule,
+) -> PyResult<PyObject> {
+    let schema = Box::new(ffi::export_field_to_c(field));
+    let schema_ptr: *const ffi::ArrowSchema = &*schema;
+
+    let field = pyarrow.getattr(pyo3::intern!(py, "Field"))?.call_method1(
+        pyo3::intern!(py, "_import_from_c"),
+        (schema_ptr as Py_uintptr_t,),
+    )?;
+
+    Ok(field.to_object(py))
+}
+
 pub fn to_py_schema(
     dtype: &arrow2::datatypes::DataType,
     py: Python,
diff --git a/src/daft-parquet/src/file.rs b/src/daft-parquet/src/file.rs
index 4d6104789b..866e28ffe4 100644
--- a/src/daft-parquet/src/file.rs
+++ b/src/daft-parquet/src/file.rs
@@ -236,7 +236,7 @@ struct RowGroupRange {
 pub(crate) struct ParquetFileReader {
     uri: String,
     metadata: Arc<parquet2::metadata::FileMetaData>,
-    arrow_schema: arrow2::datatypes::Schema,
+    arrow_schema: arrow2::datatypes::SchemaRef,
     row_ranges: Arc<Vec<RowGroupRange>>,
 }
 
@@ -250,12 +250,12 @@ impl ParquetFileReader {
         Ok(ParquetFileReader {
             uri,
             metadata: Arc::new(metadata),
-            arrow_schema,
+            arrow_schema: arrow_schema.into(),
             row_ranges: Arc::new(row_ranges),
         })
     }
 
-    pub fn arrow_schema(&self) -> &arrow2::datatypes::Schema {
+    pub fn arrow_schema(&self) -> &Arc<arrow2::datatypes::Schema> {
         &self.arrow_schema
     }
 
@@ -469,7 +469,7 @@ impl ParquetFileReader {
             })?
            .into_iter()
            .collect::<DaftResult<Vec<_>>>()?;
-        let daft_schema = daft_core::schema::Schema::try_from(&self.arrow_schema)?;
+        let daft_schema = daft_core::schema::Schema::try_from(self.arrow_schema.as_ref())?;
 
         Table::new(daft_schema, all_series)
     }
diff --git a/src/daft-parquet/src/python.rs b/src/daft-parquet/src/python.rs
index 94db5bdf9a..a27149f353 100644
--- a/src/daft-parquet/src/python.rs
+++ b/src/daft-parquet/src/python.rs
@@ -1,7 +1,10 @@
 use pyo3::prelude::*;
 
 pub mod pylib {
-    use daft_core::python::{datatype::PyTimeUnit, schema::PySchema, PySeries};
+    use daft_core::{
+        ffi::field_to_py,
+        python::{datatype::PyTimeUnit, schema::PySchema, PySeries},
+    };
     use daft_io::{get_io_client, python::IOConfig};
     use daft_table::python::PyTable;
     use pyo3::{pyfunction, PyResult, Python};
@@ -56,7 +59,7 @@
         io_config: Option<IOConfig>,
         multithreaded_io: Option<bool>,
         coerce_int96_timestamp_unit: Option<PyTimeUnit>,
-    ) -> PyResult<(Vec<String>, Vec<Vec<PyObject>>)> {
+    ) -> PyResult<(Vec<PyObject>, Vec<Vec<PyObject>>)> {
         let read_parquet_result = py.allow_threads(|| {
             let io_client = get_io_client(
                 multithreaded_io.unwrap_or(true),
@@ -76,7 +79,7 @@
                 &schema_infer_options,
             )
         })?;
-        let (names, all_arrays) = read_parquet_result;
+        let (schema, all_arrays) = read_parquet_result;
         let pyarrow = py.import("pyarrow")?;
         let converted_arrays = all_arrays
             .into_iter()
@@ -86,8 +89,12 @@
                     .collect::<PyResult<Vec<_>>>()
             })
             .collect::<PyResult<Vec<_>>>()?;
-
-        Ok((names, converted_arrays))
+        let fields = schema
+            .fields
+            .iter()
+            .map(|f| field_to_py(f, py, pyarrow))
+            .collect::<Result<Vec<_>, _>>()?;
+        Ok((fields, converted_arrays))
     }
 
     #[allow(clippy::too_many_arguments)]
diff --git a/src/daft-parquet/src/read.rs b/src/daft-parquet/src/read.rs
index 229e6db726..fee1172771 100644
--- a/src/daft-parquet/src/read.rs
+++ b/src/daft-parquet/src/read.rs
@@ -146,7 +146,10 @@
     row_groups: Option<&[i64]>,
     io_client: Arc<IOClient>,
     schema_infer_options: &ParquetSchemaInferenceOptions,
-) -> DaftResult<(Vec<String>, Vec<Vec<Box<dyn arrow2::array::Array>>>)> {
+) -> DaftResult<(
+    arrow2::datatypes::SchemaRef,
+    Vec<Vec<Box<dyn arrow2::array::Array>>>,
+)> {
     let builder = ParquetReaderBuilder::from_uri(uri, io_client.clone()).await?;
     let builder = builder.set_infer_schema_options(schema_infer_options);
 
@@ -180,12 +183,7 @@
 
     let parquet_reader = builder.build()?;
 
-    let schema = parquet_reader.arrow_schema();
-    let names = schema
-        .fields
-        .iter()
-        .map(|f| f.name.to_string())
-        .collect::<Vec<_>>();
+    let schema = parquet_reader.arrow_schema().clone();
     let ranges = parquet_reader.prebuffer_ranges(io_client)?;
     let all_arrays = parquet_reader
         .read_from_ranges_into_arrow_arrays(ranges)
@@ -257,7 +255,7 @@
         .into());
     }
 
-    Ok((names, all_arrays))
+    Ok((schema, all_arrays))
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -298,7 +296,7 @@ pub fn read_parquet_into_pyarrow(
     io_client: Arc<IOClient>,
     multithreaded_io: bool,
     schema_infer_options: &ParquetSchemaInferenceOptions,
-) -> DaftResult<(Vec<String>, Vec<ArrowChunk>)> {
+) -> DaftResult<(arrow2::datatypes::SchemaRef, Vec<ArrowChunk>)> {
     let runtime_handle = get_runtime(multithreaded_io)?;
     let _rt_guard = runtime_handle.enter();
     runtime_handle.block_on(async {
@@ -383,7 +381,7 @@
         .block_on(async { ParquetReaderBuilder::from_uri(uri, io_client.clone()).await })?;
     let builder = builder.set_infer_schema_options(schema_inference_options);
 
-    Schema::try_from(builder.build()?.arrow_schema())
+    Schema::try_from(builder.build()?.arrow_schema().as_ref())
 }
 
 pub fn read_parquet_statistics(uris: &Series, io_client: Arc<IOClient>) -> DaftResult<Table> {
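
For context, a minimal usage sketch of what this patch changes on the Python side. It assumes the helper defined in daft/table/table.py above is importable as daft.table.table.read_parquet_into_pyarrow and that "example.parquet" is a hypothetical local file; neither the import path nor the file is guaranteed by the patch itself.

# A minimal sketch, not part of the patch. Assumes the helper from
# daft/table/table.py is importable as below and that "example.parquet"
# (a hypothetical file) exists.
import pyarrow as pa

from daft.table.table import read_parquet_into_pyarrow

tbl = read_parquet_into_pyarrow("example.parquet")

# Previously only column names crossed the Rust -> Python boundary
# (pa.table(columns, names=names)), so the table's schema was whatever
# pyarrow derived from the converted arrays. With each field exported
# through the Arrow C Data Interface (Field._import_from_c) and the
# schema rebuilt via pa.schema(fields), the result now carries the
# reader's inferred schema, including field-level metadata.
assert isinstance(tbl.schema, pa.Schema)
print(tbl.schema)

Assembling the schema on the Python side with pa.schema(fields) keeps the FFI surface small: each field crosses the boundary once as an ArrowSchema pointer, and pyarrow composes them into the final schema.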