diff --git a/Cargo.lock b/Cargo.lock index 2d1afcbe3e..3886208fc4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1932,7 +1932,13 @@ name = "daft-connect" version = "0.3.0-dev0" dependencies = [ "arrow2", + "common-daft-config", "daft-core", + "daft-local-execution", + "daft-local-plan", + "daft-logical-plan", + "daft-micropartition", + "daft-scan", "daft-schema", "daft-table", "dashmap", @@ -1941,10 +1947,9 @@ dependencies = [ "pyo3", "spark-connect", "tokio", + "tokio-util", "tonic", "tracing", - "tracing-subscriber", - "tracing-tracy", "uuid 1.10.0", ] @@ -3057,19 +3062,6 @@ dependencies = [ "slab", ] -[[package]] -name = "generator" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbb949699c3e4df3a183b1d2142cb24277057055ed23c68ed58894f76c517223" -dependencies = [ - "cfg-if", - "libc", - "log", - "rustversion", - "windows 0.58.0", -] - [[package]] name = "generic-array" version = "0.14.7" @@ -4040,19 +4032,6 @@ dependencies = [ "futures-sink", ] -[[package]] -name = "loom" -version = "0.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" -dependencies = [ - "cfg-if", - "generator", - "scoped-tls", - "tracing", - "tracing-subscriber", -] - [[package]] name = "lz4" version = "1.28.0" @@ -5620,12 +5599,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "scoped-tls" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" - [[package]] name = "scopeguard" version = "1.2.0" @@ -6162,7 +6135,7 @@ dependencies = [ "memchr", "ntapi", "rayon", - "windows 0.57.0", + "windows", ] [[package]] @@ -6618,38 +6591,6 @@ dependencies = [ "tracing-log", ] -[[package]] -name = "tracing-tracy" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc775fdaf33c3dfd19dc354729e65e87914bc67dcdc390ca1210807b8bee5902" -dependencies = [ - "tracing-core", - "tracing-subscriber", - "tracy-client", -] - -[[package]] -name = "tracy-client" -version = "0.17.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "746b078c6a09ebfd5594609049e07116735c304671eaab06ce749854d23435bc" -dependencies = [ - "loom", - "once_cell", - "tracy-client-sys", -] - -[[package]] -name = "tracy-client-sys" -version = "0.24.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3637e734239e12ab152cd269302500bd063f37624ee210cd04b4936ed671f3b1" -dependencies = [ - "cc", - "windows-targets 0.52.6", -] - [[package]] name = "try-lock" version = "0.2.5" @@ -7103,16 +7044,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "windows" -version = "0.58.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd04d41d93c4992d421894c18c8b43496aa748dd4c081bac0dc93eb0489272b6" -dependencies = [ - "windows-core 0.58.0", - "windows-targets 0.52.6", -] - [[package]] name = "windows-core" version = "0.52.0" @@ -7128,25 +7059,12 @@ version = "0.57.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2ed2439a290666cd67ecce2b0ffaad89c2a56b976b736e6ece670297897832d" dependencies = [ - "windows-implement 0.57.0", - "windows-interface 0.57.0", + "windows-implement", + "windows-interface", "windows-result 0.1.2", "windows-targets 0.52.6", ] -[[package]] -name = "windows-core" -version = "0.58.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ba6d44ec8c2591c134257ce647b7ea6b20335bf6379a27dac5f1641fcf59f99" -dependencies = [ - "windows-implement 0.58.0", - "windows-interface 0.58.0", - "windows-result 0.2.0", - "windows-strings", - "windows-targets 0.52.6", -] - [[package]] name = "windows-implement" version = "0.57.0" @@ -7158,17 +7076,6 @@ dependencies = [ "syn 2.0.87", ] -[[package]] -name = "windows-implement" -version = "0.58.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.87", -] - [[package]] name = "windows-interface" version = "0.57.0" @@ -7180,17 +7087,6 @@ dependencies = [ "syn 2.0.87", ] -[[package]] -name = "windows-interface" -version = "0.58.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.87", -] - [[package]] name = "windows-registry" version = "0.2.0" diff --git a/Cargo.toml b/Cargo.toml index cf16c58ca9..0c645cd5e7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,7 +56,6 @@ python = [ "common-system-info/python", "daft-catalog/python", "daft-catalog-python-catalog/python", - "daft-connect/python", "daft-core/python", "daft-csv/python", "daft-dsl/python", @@ -207,6 +206,7 @@ daft-local-plan = {path = "src/daft-local-plan"} daft-logical-plan = {path = "src/daft-logical-plan"} daft-micropartition = {path = "src/daft-micropartition"} daft-physical-plan = {path = "src/daft-physical-plan"} +daft-scan = {path = "src/daft-scan"} daft-schema = {path = "src/daft-schema"} daft-sql = {path = "src/daft-sql"} daft-table = {path = "src/daft-table"} diff --git a/daft/io/range.py b/daft/io/range.py new file mode 100644 index 0000000000..76e0e8a6b9 --- /dev/null +++ b/daft/io/range.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + from collections.abc import Iterator + +from daft import DataType +from daft.io._generator import GeneratorScanOperator +from daft.logical.schema import Schema +from daft.table.table import Table + + +def _range_generators(start: int, end: int, step: int) -> Iterator[Callable[[], Iterator[Table]]]: + def generator_for_value(value: int) -> Callable[[], Iterator[Table]]: + def generator() -> Iterator[Table]: + yield Table.from_pydict({"id": [value]}) + + return generator + + for value in range(start, end, step): + yield generator_for_value(value) + + +class RangeScanOperator(GeneratorScanOperator): + def __init__(self, start: int, end: int, step: int = 1) -> None: + schema = Schema._from_field_name_and_types([("id", DataType.int64())]) + + super().__init__(schema=schema, generators=_range_generators(start, end, step)) diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml index e67a2da332..6fdf01d14d 100644 --- a/src/daft-connect/Cargo.toml +++ b/src/daft-connect/Cargo.toml @@ -1,22 +1,24 @@ [dependencies] +arrow2 = {workspace = true} +common-daft-config = {workspace = true, features = ["python"]} +daft-core = {workspace = true, features = ["python"]} +daft-local-execution = {workspace = true, features = ["python"]} +daft-local-plan = {workspace = true, features = ["python"]} +daft-logical-plan = {workspace = true, features = ["python"]} +daft-micropartition = {workspace = true, features = ["python"]} +daft-scan = {workspace = true, features = ["python"]} +daft-schema = {workspace = true, features = ["python"]} +daft-table = {workspace = true, features = ["python"]} dashmap = "6.1.0" eyre = "0.6.12" futures = "0.3.31" -pyo3 = {workspace = true, optional = true} +pyo3 = {workspace = true} +spark-connect = {workspace = true} tokio = {version = "1.40.0", features = ["full"]} +tokio-util = {workspace = true} tonic = "0.12.3" -tracing-subscriber = {version = "0.3.18", features = ["env-filter"]} -tracing-tracy = "0.11.3" +tracing = {workspace = true} uuid = {version = "1.10.0", features = ["v4"]} -arrow2.workspace = true -daft-core.workspace = true -daft-schema.workspace = true -daft-table.workspace = true -spark-connect.workspace = true -tracing.workspace = true - -[features] -python = ["dep:pyo3"] [lints] workspace = true diff --git a/src/daft-connect/src/convert.rs b/src/daft-connect/src/convert.rs deleted file mode 100644 index 743ffcf06a..0000000000 --- a/src/daft-connect/src/convert.rs +++ /dev/null @@ -1,6 +0,0 @@ -mod data_conversion; -mod formatting; -mod schema_conversion; - -pub use data_conversion::convert_data; -pub use schema_conversion::connect_schema; diff --git a/src/daft-connect/src/convert/data_conversion.rs b/src/daft-connect/src/convert/data_conversion.rs deleted file mode 100644 index 71032aa4a8..0000000000 --- a/src/daft-connect/src/convert/data_conversion.rs +++ /dev/null @@ -1,61 +0,0 @@ -//! Relation handling for Spark Connect protocol. -//! -//! A Relation represents a structured dataset or transformation in Spark Connect. -//! It can be either a base relation (direct data source) or derived relation -//! (result of operations on other relations). -//! -//! The protocol represents relations as trees of operations where: -//! - Each node is a Relation with metadata and an operation type -//! - Operations can reference other relations, forming a DAG -//! - The tree describes how to derive the final result -//! -//! Example flow for: SELECT age, COUNT(*) FROM employees WHERE dept='Eng' GROUP BY age -//! -//! ```text -//! Aggregate (grouping by age) -//! ↳ Filter (department = 'Engineering') -//! ↳ Read (employees table) -//! ``` -//! -//! Relations abstract away: -//! - Physical storage details -//! - Distributed computation -//! - Query optimization -//! - Data source specifics -//! -//! This allows Spark to optimize and execute queries efficiently across a cluster -//! while providing a consistent API regardless of the underlying data source. -//! ```mermaid -//! -//! ``` - -use eyre::{eyre, Context}; -use futures::Stream; -use spark_connect::{relation::RelType, ExecutePlanResponse, Relation}; -use tracing::trace; - -use crate::convert::formatting::RelTypeExt; - -mod range; -use range::range; - -use crate::command::PlanIds; - -pub fn convert_data( - plan: Relation, - context: &PlanIds, -) -> eyre::Result> + Unpin> { - // First check common fields if needed - if let Some(common) = &plan.common { - // contains metadata shared across all relation types - // Log or handle common fields if necessary - trace!("Processing relation with plan_id: {:?}", common.plan_id); - } - - let rel_type = plan.rel_type.ok_or_else(|| eyre!("rel_type is None"))?; - - match rel_type { - RelType::Range(input) => range(input, context).wrap_err("parsing Range"), - other => Err(eyre!("Unsupported top-level relation: {}", other.name())), - } -} diff --git a/src/daft-connect/src/convert/data_conversion/range.rs b/src/daft-connect/src/convert/data_conversion/range.rs deleted file mode 100644 index f370228188..0000000000 --- a/src/daft-connect/src/convert/data_conversion/range.rs +++ /dev/null @@ -1,48 +0,0 @@ -use std::future::ready; - -use daft_core::prelude::Series; -use daft_schema::prelude::Schema; -use daft_table::Table; -use eyre::{ensure, Context}; -use futures::{stream, Stream}; -use spark_connect::{ExecutePlanResponse, Range}; - -use crate::command::PlanIds; - -pub fn range( - range: Range, - channel: &PlanIds, -) -> eyre::Result> + Unpin> { - let Range { - start, - end, - step, - num_partitions, - } = range; - - let start = start.unwrap_or(0); - - ensure!(num_partitions.is_none(), "num_partitions is not supported"); - - let step = usize::try_from(step).wrap_err("step must be a positive integer")?; - ensure!(step > 0, "step must be greater than 0"); - - let arrow_array: arrow2::array::Int64Array = (start..end).step_by(step).map(Some).collect(); - let len = arrow_array.len(); - - let singleton_series = Series::try_from(( - "range", - Box::new(arrow_array) as Box, - )) - .wrap_err("creating singleton series")?; - - let singleton_table = Table::new_with_size( - Schema::new(vec![singleton_series.field().clone()])?, - vec![singleton_series], - len, - )?; - - let response = channel.gen_response(&singleton_table)?; - - Ok(stream::once(ready(Ok(response)))) -} diff --git a/src/daft-connect/src/convert/formatting.rs b/src/daft-connect/src/convert/formatting.rs deleted file mode 100644 index 3310a918fb..0000000000 --- a/src/daft-connect/src/convert/formatting.rs +++ /dev/null @@ -1,69 +0,0 @@ -use spark_connect::relation::RelType; - -/// Extension trait for RelType to add a `name` method. -pub trait RelTypeExt { - /// Returns the name of the RelType as a string. - fn name(&self) -> &'static str; -} - -impl RelTypeExt for RelType { - fn name(&self) -> &'static str { - match self { - Self::Read(_) => "Read", - Self::Project(_) => "Project", - Self::Filter(_) => "Filter", - Self::Join(_) => "Join", - Self::SetOp(_) => "SetOp", - Self::Sort(_) => "Sort", - Self::Limit(_) => "Limit", - Self::Aggregate(_) => "Aggregate", - Self::Sql(_) => "Sql", - Self::LocalRelation(_) => "LocalRelation", - Self::Sample(_) => "Sample", - Self::Offset(_) => "Offset", - Self::Deduplicate(_) => "Deduplicate", - Self::Range(_) => "Range", - Self::SubqueryAlias(_) => "SubqueryAlias", - Self::Repartition(_) => "Repartition", - Self::ToDf(_) => "ToDf", - Self::WithColumnsRenamed(_) => "WithColumnsRenamed", - Self::ShowString(_) => "ShowString", - Self::Drop(_) => "Drop", - Self::Tail(_) => "Tail", - Self::WithColumns(_) => "WithColumns", - Self::Hint(_) => "Hint", - Self::Unpivot(_) => "Unpivot", - Self::ToSchema(_) => "ToSchema", - Self::RepartitionByExpression(_) => "RepartitionByExpression", - Self::MapPartitions(_) => "MapPartitions", - Self::CollectMetrics(_) => "CollectMetrics", - Self::Parse(_) => "Parse", - Self::GroupMap(_) => "GroupMap", - Self::CoGroupMap(_) => "CoGroupMap", - Self::WithWatermark(_) => "WithWatermark", - Self::ApplyInPandasWithState(_) => "ApplyInPandasWithState", - Self::HtmlString(_) => "HtmlString", - Self::CachedLocalRelation(_) => "CachedLocalRelation", - Self::CachedRemoteRelation(_) => "CachedRemoteRelation", - Self::CommonInlineUserDefinedTableFunction(_) => "CommonInlineUserDefinedTableFunction", - Self::AsOfJoin(_) => "AsOfJoin", - Self::CommonInlineUserDefinedDataSource(_) => "CommonInlineUserDefinedDataSource", - Self::WithRelations(_) => "WithRelations", - Self::Transpose(_) => "Transpose", - Self::FillNa(_) => "FillNa", - Self::DropNa(_) => "DropNa", - Self::Replace(_) => "Replace", - Self::Summary(_) => "Summary", - Self::Crosstab(_) => "Crosstab", - Self::Describe(_) => "Describe", - Self::Cov(_) => "Cov", - Self::Corr(_) => "Corr", - Self::ApproxQuantile(_) => "ApproxQuantile", - Self::FreqItems(_) => "FreqItems", - Self::SampleBy(_) => "SampleBy", - Self::Catalog(_) => "Catalog", - Self::Extension(_) => "Extension", - Self::Unknown(_) => "Unknown", - } - } -} diff --git a/src/daft-connect/src/err.rs b/src/daft-connect/src/err.rs index 0cf065287f..4e0377912c 100644 --- a/src/daft-connect/src/err.rs +++ b/src/daft-connect/src/err.rs @@ -13,3 +13,12 @@ macro_rules! unimplemented_err { Err(::tonic::Status::unimplemented(msg)) }}; } + +// not found +#[macro_export] +macro_rules! not_found_err { + ($arg: tt) => {{ + let msg = format!($arg); + Err(::tonic::Status::not_found(msg)) + }}; +} diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index d9f2c2f6ad..1f9a11f3bf 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -5,11 +5,9 @@ #![feature(iter_from_coroutine)] #![feature(stmt_expr_attributes)] #![feature(try_trait_v2_residual)] -#![deny(unused)] use dashmap::DashMap; use eyre::Context; -#[cfg(feature = "python")] use pyo3::types::PyModuleMethods; use spark_connect::{ analyze_plan_response, @@ -28,19 +26,19 @@ use uuid::Uuid; use crate::session::Session; -mod command; mod config; -mod convert; mod err; +mod op; mod session; +mod translation; pub mod util; -#[cfg_attr(feature = "python", pyo3::pyclass)] +#[pyo3::pyclass] pub struct ConnectionHandle { shutdown_signal: Option>, } -#[cfg_attr(feature = "python", pyo3::pymethods)] +#[pyo3::pymethods] impl ConnectionHandle { pub fn shutdown(&mut self) { let Some(shutdown_signal) = self.shutdown_signal.take() else { @@ -283,7 +281,14 @@ impl SparkConnectService for DaftSparkConnectService { return Err(Status::invalid_argument("op_type is required to be root")); }; - let result = convert::connect_schema(relation)?; + let result = match translation::relation_to_schema(relation) { + Ok(schema) => schema, + Err(e) => { + return invalid_argument_err!( + "Failed to translate relation to schema: {e}" + ); + } + }; let schema = analyze_plan_response::DdlParse { parsed: Some(result), @@ -354,14 +359,13 @@ impl SparkConnectService for DaftSparkConnectService { unimplemented_err!("fetch_error_details operation is not yet implemented") } } -#[cfg(feature = "python")] + #[pyo3::pyfunction] #[pyo3(name = "connect_start")] pub fn py_connect_start(addr: &str) -> pyo3::PyResult { start(addr).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:?}"))) } -#[cfg(feature = "python")] pub fn register_modules(parent: &pyo3::Bound) -> pyo3::PyResult<()> { parent.add_function(pyo3::wrap_pyfunction_bound!(py_connect_start, parent)?)?; parent.add_class::()?; diff --git a/src/daft-connect/src/main.rs b/src/daft-connect/src/main.rs deleted file mode 100644 index 249938896c..0000000000 --- a/src/daft-connect/src/main.rs +++ /dev/null @@ -1,42 +0,0 @@ -use daft_connect::DaftSparkConnectService; -use spark_connect::spark_connect_service_server::SparkConnectServiceServer; -use tonic::transport::Server; -use tracing::info; -use tracing_subscriber::{layer::SubscriberExt, Registry}; -use tracing_tracy::TracyLayer; - -fn setup_tracing() { - tracing::subscriber::set_global_default( - Registry::default().with(TracyLayer::default()).with( - tracing_subscriber::fmt::layer() - .with_target(false) - .with_thread_ids(false) - .with_file(true) - .with_line_number(true), - ), - ) - .expect("setup tracing subscribers"); -} - -#[tokio::main] -async fn main() -> Result<(), Box> { - setup_tracing(); - - let addr = "[::1]:50051".parse()?; - let service = DaftSparkConnectService::default(); - - info!("Daft-Connect server listening on {}", addr); - - tokio::select! { - result = Server::builder() - .add_service(SparkConnectServiceServer::new(service)) - .serve(addr) => { - result?; - } - _ = tokio::signal::ctrl_c() => { - info!("\nReceived Ctrl-C, gracefully shutting down server"); - } - } - - Ok(()) -} diff --git a/src/daft-connect/src/op.rs b/src/daft-connect/src/op.rs new file mode 100644 index 0000000000..2e8bdddf98 --- /dev/null +++ b/src/daft-connect/src/op.rs @@ -0,0 +1 @@ +pub mod execute; diff --git a/src/daft-connect/src/command.rs b/src/daft-connect/src/op/execute.rs similarity index 63% rename from src/daft-connect/src/command.rs rename to src/daft-connect/src/op/execute.rs index 28ddac2365..fba3cc850d 100644 --- a/src/daft-connect/src/command.rs +++ b/src/daft-connect/src/op/execute.rs @@ -1,20 +1,18 @@ -use std::future::ready; - use arrow2::io::ipc::write::StreamWriter; use daft_table::Table; use eyre::Context; -use futures::stream; use spark_connect::{ execute_plan_response::{ArrowBatch, ResponseType, ResultComplete}, spark_connect_service_server::SparkConnectService, - ExecutePlanResponse, Relation, + ExecutePlanResponse, }; -use tonic::Status; use uuid::Uuid; -use crate::{convert::convert_data, DaftSparkConnectService, Session}; +use crate::{DaftSparkConnectService, Session}; + +mod root; -type DaftStream = ::ExecutePlanStream; +pub type ExecuteStream = ::ExecutePlanStream; pub struct PlanIds { session: String, @@ -23,6 +21,32 @@ pub struct PlanIds { } impl PlanIds { + pub fn new( + client_side_session_id: impl Into, + server_side_session_id: impl Into, + ) -> Self { + let client_side_session_id = client_side_session_id.into(); + let server_side_session_id = server_side_session_id.into(); + Self { + session: client_side_session_id, + server_side_session: server_side_session_id, + operation: Uuid::new_v4().to_string(), + } + } + + pub fn finished(&self) -> ExecutePlanResponse { + ExecutePlanResponse { + session_id: self.session.to_string(), + server_side_session_id: self.server_side_session.to_string(), + operation_id: self.operation.to_string(), + response_id: Uuid::new_v4().to_string(), + metrics: None, + observed_metrics: vec![], + schema: None, + response_type: Some(ResponseType::ResultComplete(ResultComplete {})), + } + } + pub fn gen_response(&self, table: &Table) -> eyre::Result { let mut data = Vec::new(); @@ -68,37 +92,4 @@ impl PlanIds { } } -impl Session { - pub async fn handle_root_command( - &self, - command: Relation, - operation_id: String, - ) -> Result { - use futures::{StreamExt, TryStreamExt}; - - let context = PlanIds { - session: self.client_side_session_id().to_string(), - server_side_session: self.server_side_session_id().to_string(), - operation: operation_id.clone(), - }; - - let finished = ExecutePlanResponse { - session_id: self.client_side_session_id().to_string(), - server_side_session_id: self.server_side_session_id().to_string(), - operation_id, - response_id: Uuid::new_v4().to_string(), - metrics: None, - observed_metrics: vec![], - schema: None, - response_type: Some(ResponseType::ResultComplete(ResultComplete {})), - }; - - let stream = convert_data(command, &context) - .map_err(|e| Status::internal(e.to_string()))? - .chain(stream::once(ready(Ok(finished)))); - - Ok(Box::pin( - stream.map_err(|e| Status::internal(e.to_string())), - )) - } -} +impl Session {} diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs new file mode 100644 index 0000000000..6493768ef6 --- /dev/null +++ b/src/daft-connect/src/op/execute/root.rs @@ -0,0 +1,107 @@ +use std::{collections::HashMap, future::ready}; + +use common_daft_config::DaftExecutionConfig; +use futures::stream; +use spark_connect::{ExecutePlanResponse, Relation}; +use tokio_util::sync::CancellationToken; +use tonic::{codegen::tokio_stream::wrappers::UnboundedReceiverStream, Status}; + +use crate::{ + op::execute::{ExecuteStream, PlanIds}, + session::Session, + translation, +}; + +impl Session { + pub async fn handle_root_command( + &self, + command: Relation, + operation_id: String, + ) -> Result { + use futures::{StreamExt, TryStreamExt}; + + let context = PlanIds { + session: self.client_side_session_id().to_string(), + server_side_session: self.server_side_session_id().to_string(), + operation: operation_id, + }; + + let finished = context.finished(); + + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::>(); + + std::thread::spawn(move || { + let plan = match translation::to_logical_plan(command) { + Ok(plan) => plan, + Err(e) => { + tx.send(Err(eyre::eyre!(e))).unwrap(); + return; + } + }; + + let logical_plan = plan.build(); + let physical_plan = match daft_local_plan::translate(&logical_plan) { + Ok(plan) => plan, + Err(e) => { + tx.send(Err(eyre::eyre!(e))).unwrap(); + return; + } + }; + + let cfg = DaftExecutionConfig::default(); + let result = match daft_local_execution::run_local( + &physical_plan, + HashMap::new(), + cfg.into(), + None, + CancellationToken::new(), // todo: maybe implement cancelling + ) { + Ok(result) => result, + Err(e) => { + tx.send(Err(eyre::eyre!(e))).unwrap(); + return; + } + }; + + for result in result { + let result = match result { + Ok(result) => result, + Err(e) => { + tx.send(Err(eyre::eyre!(e))).unwrap(); + return; + } + }; + + let tables = match result.get_tables() { + Ok(tables) => tables, + Err(e) => { + tx.send(Err(eyre::eyre!(e))).unwrap(); + return; + } + }; + + for table in tables.as_slice() { + let response = context.gen_response(table); + + let response = match response { + Ok(response) => response, + Err(e) => { + tx.send(Err(eyre::eyre!(e))).unwrap(); + return; + } + }; + + tx.send(Ok(response)).unwrap(); + } + } + }); + + let stream = UnboundedReceiverStream::new(rx); + + let stream = stream + .map_err(|e| Status::internal(format!("Error in Daft server: {e:?}"))) + .chain(stream::once(ready(Ok(finished)))); + + Ok(Box::pin(stream)) + } +} diff --git a/src/daft-connect/src/translation.rs b/src/daft-connect/src/translation.rs new file mode 100644 index 0000000000..125aa6e884 --- /dev/null +++ b/src/daft-connect/src/translation.rs @@ -0,0 +1,7 @@ +//! Translation between Spark Connect and Daft + +mod logical_plan; +mod schema; + +pub use logical_plan::to_logical_plan; +pub use schema::relation_to_schema; diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs new file mode 100644 index 0000000000..3a4f88af7c --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -0,0 +1,62 @@ +use daft_logical_plan::LogicalPlanBuilder; +use daft_scan::python::pylib::ScanOperatorHandle; +use eyre::{bail, ensure, Context}; +use pyo3::prelude::*; +use spark_connect::{relation::RelType, Range, Relation}; +use tracing::warn; + +pub fn to_logical_plan(relation: Relation) -> eyre::Result { + if let Some(common) = relation.common { + warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); + }; + + let Some(rel_type) = relation.rel_type else { + bail!("Relation type is required"); + }; + + match rel_type { + RelType::Range(r) => range(r).wrap_err("Failed to apply range to logical plan"), + plan => bail!("Unsupported relation type: {plan:?}"), + } +} + +fn range(range: Range) -> eyre::Result { + let Range { + start, + end, + step, + num_partitions, + } = range; + + if let Some(partitions) = num_partitions { + warn!("{partitions} ignored"); + } + + let start = start.unwrap_or(0); + + let step = usize::try_from(step).wrap_err("step must be a positive integer")?; + ensure!(step > 0, "step must be greater than 0"); + + let plan = Python::with_gil(|py| { + let range_module = PyModule::import_bound(py, "daft.io.range") + .wrap_err("Failed to import daft.io.range")?; + + let range = range_module + .getattr(pyo3::intern!(py, "RangeScanOperator")) + .wrap_err("Failed to get range function")?; + + let range = range + .call1((start, end, step)) + .wrap_err("Failed to create range scan operator")? + .to_object(py); + + let scan_operator_handle = ScanOperatorHandle::from_python_scan_operator(range, py)?; + + let plan = LogicalPlanBuilder::table_scan(scan_operator_handle.into(), None)?; + + eyre::Result::<_>::Ok(plan) + }) + .wrap_err("Failed to create range scan")?; + + Ok(plan) +} diff --git a/src/daft-connect/src/convert/schema_conversion.rs b/src/daft-connect/src/translation/schema.rs similarity index 73% rename from src/daft-connect/src/convert/schema_conversion.rs rename to src/daft-connect/src/translation/schema.rs index dcce376b94..de28a587fc 100644 --- a/src/daft-connect/src/convert/schema_conversion.rs +++ b/src/daft-connect/src/translation/schema.rs @@ -1,13 +1,15 @@ +use eyre::bail; use spark_connect::{ data_type::{Kind, Long, Struct, StructField}, relation::RelType, DataType, Relation, }; +use tracing::warn; #[tracing::instrument(skip_all)] -pub fn connect_schema(input: Relation) -> Result { +pub fn relation_to_schema(input: Relation) -> eyre::Result { if input.common.is_some() { - tracing::warn!("We do not currently look at common fields"); + warn!("We do not currently look at common fields"); } let result = match input @@ -16,9 +18,7 @@ pub fn connect_schema(input: Relation) -> Result { { RelType::Range(spark_connect::Range { num_partitions, .. }) => { if num_partitions.is_some() { - return Err(tonic::Status::unimplemented( - "num_partitions is not supported", - )); + warn!("We do not currently support num_partitions"); } let long = Long { @@ -46,9 +46,7 @@ pub fn connect_schema(input: Relation) -> Result { } } other => { - return Err(tonic::Status::unimplemented(format!( - "Unsupported relation type: {other:?}" - ))) + bail!("Unsupported relation type: {other:?}"); } }; diff --git a/src/daft-local-execution/src/sources/range.rs b/src/daft-local-execution/src/sources/range.rs new file mode 100644 index 0000000000..b723c7c492 --- /dev/null +++ b/src/daft-local-execution/src/sources/range.rs @@ -0,0 +1,99 @@ +use std::sync::Arc; + +use arrow2::array::Array; +use async_trait::async_trait; +use common_error::DaftResult; +use daft_core::{ + datatypes::{DataType, Field}, + prelude::{Schema, SchemaRef}, + series::Series, +}; +use daft_io::IOStatsRef; +use daft_micropartition::MicroPartition; +use daft_table::Table; +use futures::stream::{self, StreamExt}; +use tracing::instrument; + +use crate::sources::source::{Source, SourceStream}; + +pub struct RangeSource { + pub start: i64, + pub end: i64, + pub step: usize, + pub num_partitions: usize, + pub schema: SchemaRef, +} + +fn to_micropartition(start: i64, end: i64, step: usize) -> DaftResult { + let values: Vec<_> = (start..end).step_by(step).map(Some).collect(); + let len = values.len(); + + let field = Field::new("value", DataType::Int64); + let schema = Schema::new(vec![field.clone()])?; + + let field = Arc::new(field); + + let int_array = arrow2::array::Int64Array::from_iter(values); + let arrow_array: Box = Box::new(int_array); + let series = Series::from_arrow(field, arrow_array)?; + + let table = Table::new_unchecked(schema.clone(), vec![series], len); + Ok(MicroPartition::new_loaded( + schema.into(), + vec![table].into(), + None, + )) +} + +impl RangeSource { + pub fn new(start: i64, end: i64, step: usize, num_partitions: usize) -> Self { + let field = Field::new("value", DataType::Int64); + let schema = Schema::new(vec![field]).unwrap(); + Self { + start, + end, + step, + num_partitions, + schema: Arc::new(schema), + } + } + + pub fn arced(self) -> Arc { + Arc::new(self) as Arc + } +} + +#[async_trait] +impl Source for RangeSource { + #[instrument(name = "RangeSource::get_data", level = "info", skip_all)] + async fn get_data( + &self, + _maintain_order: bool, + _io_stats: IOStatsRef, + ) -> DaftResult> { + let total_elements = ((self.end - self.start) as usize + self.step - 1) / self.step; + let elements_per_partition = + (total_elements + self.num_partitions - 1) / self.num_partitions; + + let step = self.step; + let start = self.start; + let end = self.end; + + let partitions = (0..self.num_partitions).map(move |i| { + let start = start + (i * elements_per_partition * step) as i64; + let end = end.min(start + (elements_per_partition * step) as i64); + + to_micropartition(start, end, step).map(Arc::new) + }); + + Ok(Box::pin(stream::iter(partitions))) + } + + fn name(&self) -> &'static str { + "Range" + } + + fn schema(&self) -> &SchemaRef { + &self.schema + } +} diff --git a/src/daft-micropartition/Cargo.toml b/src/daft-micropartition/Cargo.toml index 99877353d7..e7ce7f5ac8 100644 --- a/src/daft-micropartition/Cargo.toml +++ b/src/daft-micropartition/Cargo.toml @@ -19,7 +19,7 @@ pyo3 = {workspace = true, optional = true} snafu = {workspace = true} [features] -python = ["dep:pyo3", "common-error/python", "common-file-formats/python", "common-scan-info/python", "daft-core/python", "daft-dsl/python", "daft-table/python", "daft-io/python", "daft-parquet/python", "daft-scan/python", "daft-stats/python"] +python = ["dep:pyo3", "common-error/python", "common-file-formats/python", "common-scan-info/python", "daft-core/python", "daft-dsl/python", "daft-table/python", "daft-io/python", "daft-parquet/python", "daft-scan/python", "daft-stats/python", "daft-csv/python", "daft-json/python"] [lints] workspace = true diff --git a/src/daft-schema/src/schema.rs b/src/daft-schema/src/schema.rs index a36f6abe02..29e083a533 100644 --- a/src/daft-schema/src/schema.rs +++ b/src/daft-schema/src/schema.rs @@ -17,7 +17,7 @@ use crate::field::Field; pub type SchemaRef = Arc; -#[derive(Debug, Display, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Display, PartialEq, Eq, Serialize, Deserialize, Clone)] #[serde(transparent)] #[display("{}\n", make_schema_vertical_table( fields.iter().map(|(name, field)| (name.clone(), field.dtype.to_string())) diff --git a/tests/connect/test_range_simple.py b/tests/connect/test_range_simple.py index 6eab95102f..3fcef9b874 100644 --- a/tests/connect/test_range_simple.py +++ b/tests/connect/test_range_simple.py @@ -1,7 +1,6 @@ from __future__ import annotations -import time - +# import time import pytest from pyspark.sql import SparkSession @@ -22,7 +21,7 @@ def spark_session(): # Cleanup server.shutdown() session.stop() - time.sleep(2) # Allow time for session cleanup + # time.sleep(2) # Allow time for session cleanup def test_range_operation(spark_session): @@ -35,4 +34,4 @@ def test_range_operation(spark_session): # Verify the DataFrame has expected values assert len(pandas_df) == 10, "DataFrame should have 10 rows" - assert list(pandas_df["range"]) == list(range(10)), "DataFrame should contain values 0-9" + assert list(pandas_df["id"]) == list(range(10)), "DataFrame should contain values 0-9"