Skip to content

Commit

Permalink
lance reads
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Ho authored and Colin Ho committed Oct 2, 2024
1 parent f0ce5c4 commit 22eab9b
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 116 deletions.
10 changes: 6 additions & 4 deletions src/daft-local-execution/src/sources/scan_task.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{pin::Pin, sync::Arc};

use common_error::DaftResult;
use common_file_formats::{FileFormatConfig, ParquetSourceConfig};
Expand All @@ -8,6 +8,7 @@ use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions};
use daft_micropartition::MicroPartition;
use daft_parquet::read::ParquetSchemaInferenceOptions;
use daft_scan::{storage_config::StorageConfig, ChunkSpec, ScanTask};
use daft_table::Table;
use futures::{Stream, StreamExt};
use snafu::ResultExt;
use tokio_stream::wrappers::ReceiverStream;
Expand Down Expand Up @@ -293,9 +294,10 @@ async fn stream_scan_task(
}
#[cfg(feature = "python")]
FileFormatConfig::PythonFunction => {
return Err(common_error::DaftError::TypeError(
"PythonFunction file format not implemented".to_string(),
));
let iter =
daft_micropartition::python::read_pyfunc_into_table_iter(&scan_task)?;
let stream = futures::stream::iter(iter.map(|r| r.map_err(|e| e.into())));
Box::pin(stream) as Pin<Box<dyn Stream<Item = DaftResult<Table>> + Send>>
}
};

Expand Down
103 changes: 2 additions & 101 deletions src/daft-micropartition/src/micropartition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,107 +383,8 @@ fn materialize_scan_task(
})?
}
FileFormatConfig::PythonFunction => {
use pyo3::{types::PyAnyMethods, PyObject};

let table_iterators = scan_task.sources.iter().map(|source| {
// Call Python function to create an Iterator (Grabs the GIL and then releases it)
match source {
DataSource::PythonFactoryFunction {
module,
func_name,
func_args,
..
} => {
Python::with_gil(|py| {
let func = py.import_bound(module.as_str())
.unwrap_or_else(|_| panic!("Cannot import factory function from module {module}"))
.getattr(func_name.as_str())
.unwrap_or_else(|_| panic!("Cannot find function {func_name} in module {module}"));
func.call(func_args.to_pytuple(py), None)
.with_context(|_| PyIOSnafu)
.map(Into::<PyObject>::into)
})
},
_ => unreachable!("PythonFunction file format must be paired with PythonFactoryFunction data file sources"),
}
});

let mut tables = Vec::new();
let mut rows_seen_so_far = 0;
for iterator in table_iterators {
let iterator = iterator?;

// Iterate on this iterator to exhaustion, or until the limit is met
while scan_task
.pushdowns
.limit
.map(|limit| rows_seen_so_far < limit)
.unwrap_or(true)
{
// Grab the GIL to call next() on the iterator, and then release it once we have the Table
let table = match Python::with_gil(|py| {
iterator
.downcast_bound::<pyo3::types::PyIterator>(py)
.expect("Function must return an iterator of tables")
.clone()
.next()
.map(|result| {
result
.map(|tbl| {
tbl.extract::<daft_table::python::PyTable>()
.expect("Must be a PyTable")
.table
})
.with_context(|_| PyIOSnafu)
})
}) {
Some(table) => table,
None => break,
}?;

// Apply filters
let table = if let Some(filters) = scan_task.pushdowns.filters.as_ref()
{
table
.filter(&[filters.clone()])
.with_context(|_| DaftCoreComputeSnafu)?
} else {
table
};

// Apply limit if necessary, and update `&mut remaining`
let table = if let Some(limit) = scan_task.pushdowns.limit {
let limited_table = if rows_seen_so_far + table.len() > limit {
table
.slice(0, limit - rows_seen_so_far)
.with_context(|_| DaftCoreComputeSnafu)?
} else {
table
};

// Update the rows_seen_so_far
rows_seen_so_far += limited_table.len();

limited_table
} else {
table
};

tables.push(table);
}

// If seen enough rows, early-terminate
if scan_task
.pushdowns
.limit
.map(|limit| rows_seen_so_far >= limit)
.unwrap_or(false)
{
break;
}
}

tables
let tables = crate::python::read_pyfunc_into_table_iter(&scan_task)?;
tables.collect::<crate::Result<Vec<_>>>()?
}
}
}
Expand Down
106 changes: 103 additions & 3 deletions src/daft-micropartition/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,18 @@ use daft_dsl::python::PyExpr;
use daft_io::{python::IOConfig, IOStatsContext};
use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions};
use daft_parquet::read::ParquetSchemaInferenceOptions;
use daft_scan::{python::pylib::PyScanTask, storage_config::PyStorageConfig, ScanTask};
use daft_scan::{
python::pylib::PyScanTask, storage_config::PyStorageConfig, DataSource, ScanTask, ScanTaskRef,
};
use daft_stats::{TableMetadata, TableStatistics};
use daft_table::python::PyTable;
use daft_table::{python::PyTable, Table};
use pyo3::{exceptions::PyValueError, prelude::*, types::PyBytes, PyTypeInfo};
use snafu::ResultExt;

use crate::micropartition::{MicroPartition, TableState};
use crate::{
micropartition::{MicroPartition, TableState},
DaftCoreComputeSnafu, PyIOSnafu,
};

#[pyclass(module = "daft.daft", frozen)]
#[derive(Clone)]
Expand Down Expand Up @@ -882,6 +888,100 @@ pub fn read_sql_into_py_table(
.extract()
}

/// Reads tables from a scan task whose sources are Python factory functions.
///
/// For each `DataSource::PythonFactoryFunction` in the scan task, the factory is
/// imported and invoked (under the GIL) to obtain a Python iterator of tables.
/// The returned Rust iterator lazily pulls tables from those Python iterators in
/// order, applies the scan task's filter and limit pushdowns, and terminates the
/// stream early once the limit has been satisfied.
///
/// # Errors
/// Yields `Err` items when the Python factory call fails, when pulling the next
/// table raises, or when applying the filter/slice pushdowns fails.
///
/// # Panics
/// Panics if a factory module/function cannot be resolved, if a factory returns a
/// non-iterator, or if an iterator yields a non-`PyTable` object — these indicate
/// misuse of the PythonFunction file format rather than recoverable errors.
pub fn read_pyfunc_into_table_iter(
    scan_task: &ScanTaskRef,
) -> crate::Result<impl Iterator<Item = crate::Result<Table>>> {
    let table_iterators = scan_task.sources.iter().map(|source| {
        // Call Python function to create an Iterator (Grabs the GIL and then releases it)
        match source {
            DataSource::PythonFactoryFunction {
                module,
                func_name,
                func_args,
                ..
            } => {
                Python::with_gil(|py| {
                    let func = py.import_bound(module.as_str())
                        .unwrap_or_else(|_| panic!("Cannot import factory function from module {module}"))
                        .getattr(func_name.as_str())
                        .unwrap_or_else(|_| panic!("Cannot find function {func_name} in module {module}"));
                    func.call(func_args.to_pytuple(py), None)
                        .with_context(|_| PyIOSnafu)
                        .map(Into::<PyObject>::into)
                })
            },
            _ => unreachable!("PythonFunction file format must be paired with PythonFactoryFunction data file sources"),
        }
    }).collect::<crate::Result<Vec<_>>>()?;

    let scan_task_limit = scan_task.pushdowns.limit;
    let scan_task_filters = scan_task.pushdowns.filters.clone();
    let res = table_iterators
        .into_iter()
        .flat_map(|iter| {
            // BUGFIX: wrap the per-pull closure in `std::iter::from_fn` so each
            // Python iterator is drained to exhaustion. Without it, the closure's
            // `Option<crate::Result<Table>>` would be flattened directly by
            // `flat_map`, yielding at most ONE table per source iterator and
            // silently dropping the rest. Repeated `.clone().next()` on the
            // stateful `PyIterator` advances the underlying Python iterator on
            // every call, so `from_fn` pulls successive tables until `None`.
            std::iter::from_fn(move || {
                // Grab the GIL only for the duration of pulling one table.
                Python::with_gil(|py| {
                    iter.downcast_bound::<pyo3::types::PyIterator>(py)
                        .expect("Function must return an iterator of tables")
                        .clone()
                        .next()
                        .map(|result| {
                            result
                                .map(|tbl| {
                                    tbl.extract::<daft_table::python::PyTable>()
                                        .expect("Must be a PyTable")
                                        .table
                                })
                                .with_context(|_| PyIOSnafu)
                        })
                })
            })
        })
        .scan(0, move |rows_seen_so_far, table| {
            // Terminate the whole stream once the limit pushdown is satisfied.
            if scan_task_limit
                .map(|limit| *rows_seen_so_far >= limit)
                .unwrap_or(false)
            {
                return None;
            }
            match table {
                Err(e) => Some(Err(e)),
                Ok(table) => {
                    // Apply filter and limit pushdowns to the pulled table.
                    let post_pushdown_table = || -> crate::Result<Table> {
                        let table = if let Some(filters) = scan_task_filters.as_ref() {
                            table
                                .filter(&[filters.clone()])
                                .with_context(|_| DaftCoreComputeSnafu)?
                        } else {
                            table
                        };

                        // Apply limit if necessary, and update `rows_seen_so_far`
                        if let Some(limit) = scan_task_limit {
                            let limited_table = if *rows_seen_so_far + table.len() > limit {
                                table
                                    .slice(0, limit - *rows_seen_so_far)
                                    .with_context(|_| DaftCoreComputeSnafu)?
                            } else {
                                table
                            };

                            // Update the rows_seen_so_far
                            *rows_seen_so_far += limited_table.len();

                            Ok(limited_table)
                        } else {
                            Ok(table)
                        }
                    }();

                    Some(post_pushdown_table)
                }
            }
        });

    Ok(res)
}

impl From<MicroPartition> for PyMicroPartition {
fn from(value: MicroPartition) -> Self {
Arc::new(value).into()
Expand Down
9 changes: 1 addition & 8 deletions tests/io/lancedb/test_lancedb_reads.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@
import pytest

import daft
from daft import context

native_executor_skip = pytest.mark.skipif(
context.get_context().daft_execution_config.enable_native_executor is True,
reason="Native executor fails for these tests",
)

TABLE_NAME = "my_table"
data = {
Expand All @@ -18,8 +12,7 @@
}

PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)
py_arrow_skip = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="lance only supported if pyarrow >= 8.0.0")
pytestmark = [native_executor_skip, py_arrow_skip]
pytestmark = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="lance only supported if pyarrow >= 8.0.0")


@pytest.fixture(scope="function")
Expand Down

0 comments on commit 22eab9b

Please sign in to comment.