Skip to content

Commit

Permalink
[CHORE] Enable lancedb reads for native executor (#2925)
Browse files Browse the repository at this point in the history
Enables Lance reads for swordfish. 

The existing lance read mechanism reads from python iterators. This PR
abstracts out that logic to a separate function.

Both the Python and streaming executors will call this function: the
Python executor will do a `.collect` on the iterator, while the native
runner will wrap the iterator in a stream.

---------

Co-authored-by: Colin Ho <[email protected]>
Co-authored-by: Colin Ho <[email protected]>
  • Loading branch information
3 people authored Oct 17, 2024
1 parent d243cee commit 69fef20
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 114 deletions.
7 changes: 3 additions & 4 deletions src/daft-local-execution/src/sources/scan_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,14 +371,13 @@ async fn stream_scan_task(
.map(|t| t.into())
.context(PyIOSnafu)
})?;
// SQL Scan cannot be streamed at the moment, so we just return the table
Box::pin(futures::stream::once(async { Ok(table) }))
}
#[cfg(feature = "python")]
FileFormatConfig::PythonFunction => {
return Err(common_error::DaftError::TypeError(
"PythonFunction file format not implemented".to_string(),
));
let iter = daft_micropartition::python::read_pyfunc_into_table_iter(&scan_task)?;
let stream = futures::stream::iter(iter.map(|r| r.map_err(|e| e.into())));
Box::pin(stream)
}
};

Expand Down
101 changes: 2 additions & 99 deletions src/daft-micropartition/src/micropartition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,105 +379,8 @@ fn materialize_scan_task(
})?
}
FileFormatConfig::PythonFunction => {
use pyo3::{types::PyAnyMethods, PyObject};

let table_iterators = scan_task.sources.iter().map(|source| {
// Call Python function to create an Iterator (Grabs the GIL and then releases it)
match source {
DataSource::PythonFactoryFunction {
module,
func_name,
func_args,
..
} => {
Python::with_gil(|py| {
let func = py.import_bound(module.as_str())
.unwrap_or_else(|_| panic!("Cannot import factory function from module {module}"))
.getattr(func_name.as_str())
.unwrap_or_else(|_| panic!("Cannot find function {func_name} in module {module}"));
func.call(func_args.to_pytuple(py), None)
.with_context(|_| PyIOSnafu)
.map(Into::<PyObject>::into)
})
}
_ => unreachable!("PythonFunction file format must be paired with PythonFactoryFunction data file sources"),
}
});

let mut tables = Vec::new();
let mut rows_seen_so_far = 0;
for iterator in table_iterators {
let iterator = iterator?;

// Iterate on this iterator to exhaustion, or until the limit is met
while scan_task
.pushdowns
.limit
.map_or(true, |limit| rows_seen_so_far < limit)
{
// Grab the GIL to call next() on the iterator, and then release it once we have the Table
let table = match Python::with_gil(|py| {
iterator
.downcast_bound::<pyo3::types::PyIterator>(py)
.expect("Function must return an iterator of tables")
.clone()
.next()
.map(|result| {
result
.map(|tbl| {
tbl.extract::<daft_table::python::PyTable>()
.expect("Must be a PyTable")
.table
})
.with_context(|_| PyIOSnafu)
})
}) {
Some(table) => table,
None => break,
}?;

// Apply filters
let table = if let Some(filters) = scan_task.pushdowns.filters.as_ref()
{
table
.filter(&[filters.clone()])
.with_context(|_| DaftCoreComputeSnafu)?
} else {
table
};

// Apply limit if necessary, and update `&mut remaining`
let table = if let Some(limit) = scan_task.pushdowns.limit {
let limited_table = if rows_seen_so_far + table.len() > limit {
table
.slice(0, limit - rows_seen_so_far)
.with_context(|_| DaftCoreComputeSnafu)?
} else {
table
};

// Update the rows_seen_so_far
rows_seen_so_far += limited_table.len();

limited_table
} else {
table
};

tables.push(table);
}

// If seen enough rows, early-terminate
if scan_task
.pushdowns
.limit
.is_some_and(|limit| rows_seen_so_far >= limit)
{
break;
}
}

tables
let tables = crate::python::read_pyfunc_into_table_iter(&scan_task)?;
tables.collect::<crate::Result<Vec<_>>>()?
}
}
}
Expand Down
106 changes: 103 additions & 3 deletions src/daft-micropartition/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,18 @@ use daft_dsl::python::PyExpr;
use daft_io::{python::IOConfig, IOStatsContext};
use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions};
use daft_parquet::read::ParquetSchemaInferenceOptions;
use daft_scan::{python::pylib::PyScanTask, storage_config::PyStorageConfig, ScanTask};
use daft_scan::{
python::pylib::PyScanTask, storage_config::PyStorageConfig, DataSource, ScanTask, ScanTaskRef,
};
use daft_stats::{TableMetadata, TableStatistics};
use daft_table::python::PyTable;
use daft_table::{python::PyTable, Table};
use pyo3::{exceptions::PyValueError, prelude::*, types::PyBytes, PyTypeInfo};
use snafu::ResultExt;

use crate::micropartition::{MicroPartition, TableState};
use crate::{
micropartition::{MicroPartition, TableState},
DaftCoreComputeSnafu, PyIOSnafu,
};

#[pyclass(module = "daft.daft", frozen)]
#[derive(Clone)]
Expand Down Expand Up @@ -902,6 +908,100 @@ pub fn read_sql_into_py_table(
.extract()
}

/// Invokes each `PythonFactoryFunction` source of `scan_task` and returns a lazy
/// iterator over the [`Table`]s those Python iterators yield, with the scan task's
/// filter and limit pushdowns applied.
///
/// Each factory function is called once (under the GIL) up front to materialize its
/// Python iterator; the iterators are then drained lazily, grabbing and releasing
/// the GIL for each `next()` call. Both the Python runner (which `.collect()`s the
/// result) and the native streaming executor (which wraps it in a stream) call this.
///
/// # Errors
/// Returns an error if a factory function call fails (`PyIOSnafu`), if the Python
/// iterator raises, or if filtering/slicing a yielded table fails
/// (`DaftCoreComputeSnafu`).
///
/// # Panics
/// Panics if the module or function cannot be imported, if a source is not a
/// `PythonFactoryFunction`, or if the factory does not return an iterator of
/// `PyTable`s — these indicate a misconstructed scan task, not a runtime failure.
pub fn read_pyfunc_into_table_iter(
    scan_task: &ScanTaskRef,
) -> crate::Result<impl Iterator<Item = crate::Result<Table>>> {
    // Eagerly call every factory function so construction errors surface
    // immediately, before any table is pulled.
    let table_iterators = scan_task.sources.iter().map(|source| {
        // Call Python function to create an Iterator (grabs the GIL and then releases it)
        match source {
            DataSource::PythonFactoryFunction {
                module,
                func_name,
                func_args,
                ..
            } => {
                Python::with_gil(|py| {
                    let func = py.import_bound(module.as_str())
                        .unwrap_or_else(|_| panic!("Cannot import factory function from module {module}"))
                        .getattr(func_name.as_str())
                        .unwrap_or_else(|_| panic!("Cannot find function {func_name} in module {module}"));
                    func.call(func_args.to_pytuple(py), None)
                        .with_context(|_| PyIOSnafu)
                        .map(Into::<PyObject>::into)
                })
            },
            _ => unreachable!("PythonFunction file format must be paired with PythonFactoryFunction data file sources"),
        }
    }).collect::<crate::Result<Vec<_>>>()?;

    let scan_task_limit = scan_task.pushdowns.limit;
    let scan_task_filters = scan_task.pushdowns.filters.clone();
    let res = table_iterators
        .into_iter()
        // BUGFIX: `filter_map` here pulled at most ONE table per source iterator
        // (its closure runs once per element). `flat_map` + `from_fn` drains each
        // Python iterator to exhaustion, matching the materialized read path this
        // function replaced, while staying lazy for the streaming executor.
        .flat_map(|iter| {
            std::iter::from_fn(move || {
                // Hold the GIL only for the `next()` call, releasing it while
                // downstream operators process the yielded table.
                Python::with_gil(|py| {
                    iter.downcast_bound::<pyo3::types::PyIterator>(py)
                        .expect("Function must return an iterator of tables")
                        .clone()
                        .next()
                        .map(|result| {
                            result
                                .map(|tbl| {
                                    tbl.extract::<daft_table::python::PyTable>()
                                        .expect("Must be a PyTable")
                                        .table
                                })
                                .with_context(|_| PyIOSnafu)
                        })
                })
            })
        })
        .scan(0, move |rows_seen_so_far, table| {
            // Terminate the whole stream once the limit pushdown is satisfied,
            // so we stop calling back into Python early.
            if scan_task_limit
                .map(|limit| *rows_seen_so_far >= limit)
                .unwrap_or(false)
            {
                return None;
            }
            match table {
                Err(e) => Some(Err(e)),
                Ok(table) => {
                    // Apply filter and limit pushdowns to the yielded table.
                    let post_pushdown_table = || -> crate::Result<Table> {
                        let table = if let Some(filters) = scan_task_filters.as_ref() {
                            table
                                .filter(&[filters.clone()])
                                .with_context(|_| DaftCoreComputeSnafu)?
                        } else {
                            table
                        };

                        // Apply limit if necessary, and update `rows_seen_so_far`
                        // so the check above can early-terminate.
                        if let Some(limit) = scan_task_limit {
                            let limited_table = if *rows_seen_so_far + table.len() > limit {
                                table
                                    .slice(0, limit - *rows_seen_so_far)
                                    .with_context(|_| DaftCoreComputeSnafu)?
                            } else {
                                table
                            };

                            // Update the rows_seen_so_far
                            *rows_seen_so_far += limited_table.len();

                            Ok(limited_table)
                        } else {
                            Ok(table)
                        }
                    }();

                    Some(post_pushdown_table)
                }
            }
        });

    Ok(res)
}

impl From<MicroPartition> for PyMicroPartition {
fn from(value: MicroPartition) -> Self {
Arc::new(value).into()
Expand Down
9 changes: 1 addition & 8 deletions tests/io/lancedb/test_lancedb_reads.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@
import pytest

import daft
from daft import context

native_executor_skip = pytest.mark.skipif(
context.get_context().daft_execution_config.enable_native_executor is True,
reason="Native executor fails for these tests",
)

TABLE_NAME = "my_table"
data = {
Expand All @@ -18,8 +12,7 @@
}

PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)
py_arrow_skip = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="lance only supported if pyarrow >= 8.0.0")
pytestmark = [native_executor_skip, py_arrow_skip]
pytestmark = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="lance only supported if pyarrow >= 8.0.0")


@pytest.fixture(scope="function")
Expand Down

0 comments on commit 69fef20

Please sign in to comment.