Commit 574ad0a

Merge branch 'main' into lift-stats-to-logical-plan

desmondcheongzx authored Nov 21, 2024
2 parents 0792408 + 3394a66
Showing 36 changed files with 1,197 additions and 418 deletions.
6 changes: 4 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions Cargo.toml
@@ -194,11 +194,14 @@ chrono-tz = "0.8.4"
comfy-table = "7.1.1"
common-daft-config = {path = "src/common/daft-config"}
common-error = {path = "src/common/error", default-features = false}
daft-core = {path = "src/daft-core"}
daft-dsl = {path = "src/daft-dsl"}
daft-hash = {path = "src/daft-hash"}
daft-local-execution = {path = "src/daft-local-execution"}
daft-local-plan = {path = "src/daft-local-plan"}
daft-logical-plan = {path = "src/daft-logical-plan"}
daft-scan = {path = "src/daft-scan"}
daft-schema = {path = "src/daft-schema"}
daft-table = {path = "src/daft-table"}
derivative = "2.2.0"
derive_builder = "0.20.2"
3 changes: 2 additions & 1 deletion daft/daft/__init__.pyi
@@ -1208,10 +1208,11 @@ def sql_expr(sql: str) -> PyExpr: ...
def list_sql_functions() -> list[SQLFunctionStub]: ...
def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ...
def to_struct(inputs: list[PyExpr]) -> PyExpr: ...
def connect_start(addr: str) -> ConnectionHandle: ...
def connect_start(addr: str = "sc://0.0.0.0:0") -> ConnectionHandle: ...

class ConnectionHandle:
def shutdown(self) -> None: ...
def port(self) -> int: ...

# expr numeric ops
def abs(expr: PyExpr) -> PyExpr: ...
19 changes: 13 additions & 6 deletions src/common/py-serde/src/python.rs
@@ -49,12 +49,19 @@ impl<'de> Visitor<'de> for PyObjectVisitor {
where
E: DeError,
{
Python::with_gil(|py| {
py.import_bound(pyo3::intern!(py, "daft.pickle"))
.and_then(|m| m.getattr(pyo3::intern!(py, "loads")))
.and_then(|f| Ok(f.call1((v,))?.into()))
.map_err(|e| DeError::custom(e.to_string()))
})
self.visit_bytes(&v)
}

fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut v: Vec<u8> = Vec::with_capacity(seq.size_hint().unwrap_or_default());
while let Some(elem) = seq.next_element()? {
v.push(elem);
}

self.visit_bytes(&v)
}
}
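The new `visit_seq` arm matters because human-readable serde formats have no native bytes type: serde_json, for instance, hands a byte payload to the visitor as a sequence of integers rather than calling `visit_bytes`, so a visitor that only handles the bytes path fails on such round-trips. Below is a minimal, standalone sketch of the same pattern, with Daft's pickled-PyObject handling swapped for a plain `Vec<u8>` so it runs on its own:

```rust
use serde::de::{self, SeqAccess, Visitor};
use std::fmt;

struct BytesVisitor;

impl<'de> Visitor<'de> for BytesVisitor {
    type Value = Vec<u8>;

    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("bytes or a sequence of u8")
    }

    // Binary formats take this path directly.
    fn visit_bytes<E: de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
        Ok(v.to_vec())
    }

    // Human-readable formats deliver the payload element by element,
    // so collect into a buffer and delegate to the bytes path.
    fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
        let mut v = Vec::with_capacity(seq.size_hint().unwrap_or_default());
        while let Some(elem) = seq.next_element::<u8>()? {
            v.push(elem);
        }
        self.visit_bytes(&v)
    }
}

fn main() -> Result<(), serde_json::Error> {
    // JSON has no bytes type, so this arrives via visit_seq.
    let mut de = serde_json::Deserializer::from_str("[104, 105]");
    let bytes = serde::Deserializer::deserialize_any(&mut de, BytesVisitor)?;
    assert_eq!(bytes, b"hi");
    Ok(())
}
```

Collecting into a buffer and delegating keeps the two entry points behaviorally identical, which is what the change above does for pickled Python objects.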

8 changes: 5 additions & 3 deletions src/daft-connect/Cargo.toml
@@ -1,24 +1,26 @@
[dependencies]
arrow2 = {workspace = true}
async-stream = "0.3.6"
common-daft-config = {workspace = true}
daft-core = {workspace = true}
daft-dsl = {workspace = true}
daft-local-execution = {workspace = true}
daft-local-plan = {workspace = true}
daft-logical-plan = {workspace = true}
daft-scan = {workspace = true}
daft-schema = {workspace = true}
daft-table = {workspace = true}
dashmap = "6.1.0"
eyre = "0.6.12"
futures = "0.3.31"
pyo3 = {workspace = true, optional = true}
spark-connect = {workspace = true}
tokio = {version = "1.40.0", features = ["full"]}
tokio-util = {workspace = true}
tonic = "0.12.3"
tracing = {workspace = true}
uuid = {version = "1.10.0", features = ["v4"]}

[features]
python = ["dep:pyo3", "common-daft-config/python", "daft-local-execution/python", "daft-local-plan/python", "daft-logical-plan/python", "daft-scan/python", "daft-table/python"]
python = ["dep:pyo3", "common-daft-config/python", "daft-local-execution/python", "daft-logical-plan/python", "daft-scan/python", "daft-table/python", "daft-dsl/python", "daft-schema/python", "daft-core/python"]

[lints]
workspace = true
63 changes: 43 additions & 20 deletions src/daft-connect/src/lib.rs
@@ -22,7 +22,7 @@ use spark_connect::{
ReleaseExecuteResponse, ReleaseSessionRequest, ReleaseSessionResponse,
};
use tonic::{transport::Server, Request, Response, Status};
use tracing::info;
use tracing::{debug, info};
use uuid::Uuid;

use crate::session::Session;
@@ -37,6 +37,7 @@ pub mod util;
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct ConnectionHandle {
shutdown_signal: Option<tokio::sync::oneshot::Sender<()>>,
port: u16,
}

#[cfg_attr(feature = "python", pyo3::pymethods)]
@@ -47,12 +47,19 @@ impl ConnectionHandle {
};
shutdown_signal.send(()).unwrap();
}

pub fn port(&self) -> u16 {
self.port
}
}

pub fn start(addr: &str) -> eyre::Result<ConnectionHandle> {
info!("Daft-Connect server listening on {addr}");
let addr = util::parse_spark_connect_address(addr)?;

let listener = std::net::TcpListener::bind(addr)?;
let port = listener.local_addr()?.port();

let service = DaftSparkConnectService::default();

info!("Daft-Connect server listening on {addr}");
@@ -61,25 +69,40 @@ pub fn start(addr: &str) -> eyre::Result<ConnectionHandle> {

let handle = ConnectionHandle {
shutdown_signal: Some(shutdown_signal),
port,
};

std::thread::spawn(move || {
let runtime = tokio::runtime::Runtime::new().unwrap();
let result = runtime
.block_on(async {
tokio::select! {
result = Server::builder()
.add_service(SparkConnectServiceServer::new(service))
.serve(addr) => {
result
}
_ = shutdown_receiver => {
info!("Received shutdown signal");
Ok(())
let result = runtime.block_on(async {
let incoming = {
let listener = tokio::net::TcpListener::from_std(listener)
.wrap_err("Failed to create TcpListener from std::net::TcpListener")?;

async_stream::stream! {
loop {
match listener.accept().await {
Ok((stream, _)) => yield Ok(stream),
Err(e) => yield Err(e),
}
}
}
})
.wrap_err_with(|| format!("Failed to start server on {addr}"));
};

let result = tokio::select! {
result = Server::builder()
.add_service(SparkConnectServiceServer::new(service))
.serve_with_incoming(incoming) => {
result
}
_ = shutdown_receiver => {
info!("Received shutdown signal");
Ok(())
}
};

result.wrap_err_with(|| format!("Failed to start server on {addr}"))
});

if let Err(e) = result {
eprintln!("Daft-Connect server error: {e:?}");
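Binding now happens eagerly on a `std::net::TcpListener` inside `start`, so an address with port 0 lets the OS assign a free port that the handle can report before the tokio runtime is even running; the listener is then handed to tonic via `serve_with_incoming`. A minimal sketch of that pattern, independent of tonic and Daft (note that `tokio::net::TcpListener::from_std` requires the socket to already be in non-blocking mode):

```rust
use std::net::TcpListener as StdListener;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Port 0 asks the OS for any free port; the real port is known
    // immediately, before any async accept loop starts.
    let std_listener = StdListener::bind("0.0.0.0:0")?;
    let port = std_listener.local_addr()?.port();
    println!("bound to port {port}");

    // from_std requires the socket to be non-blocking.
    std_listener.set_nonblocking(true)?;
    let listener = tokio::net::TcpListener::from_std(std_listener)?;

    // Accept a single connection, then exit; a real server would loop
    // and spawn a task per connection.
    let (_stream, peer) = listener.accept().await?;
    println!("connection from {peer}");
    Ok(())
}
```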
@@ -286,22 +309,22 @@ impl SparkConnectService for DaftSparkConnectService {
Ok(schema) => schema,
Err(e) => {
return invalid_argument_err!(
"Failed to translate relation to schema: {e}"
"Failed to translate relation to schema: {e:?}"
);
}
};

let schema = analyze_plan_response::DdlParse {
parsed: Some(result),
let schema = analyze_plan_response::Schema {
schema: Some(result),
};

let response = AnalyzePlanResponse {
session_id,
server_side_session_id: String::new(),
result: Some(analyze_plan_response::Result::DdlParse(schema)),
result: Some(analyze_plan_response::Result::Schema(schema)),
};

println!("response: {response:#?}");
debug!("response: {response:#?}");

Ok(Response::new(response))
}
@@ -363,7 +386,7 @@ impl SparkConnectService for DaftSparkConnectService {

#[cfg(feature = "python")]
#[pyo3::pyfunction]
#[pyo3(name = "connect_start")]
#[pyo3(name = "connect_start", signature = (addr = "sc://0.0.0.0:0"))]
pub fn py_connect_start(addr: &str) -> pyo3::PyResult<ConnectionHandle> {
start(addr).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:?}")))
}
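Together, the `sc://0.0.0.0:0` default and the new `port()` accessor let a caller start the server on an ephemeral port and discover where it landed. A hypothetical Rust caller might look like this, assuming `daft-connect` as a dependency and that `shutdown` takes `&mut self`, as the `Option::take` on the shutdown signal suggests:

```rust
fn main() -> eyre::Result<()> {
    // Port 0 in the Spark Connect address means "pick a free port".
    let mut handle = daft_connect::start("sc://0.0.0.0:0")?;
    println!("Spark Connect server on port {}", handle.port());

    // ... point a Spark Connect client at sc://localhost:<port> ...

    handle.shutdown();
    Ok(())
}
```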
37 changes: 16 additions & 21 deletions src/daft-connect/src/op/execute/root.rs
@@ -1,9 +1,9 @@
use std::{collections::HashMap, future::ready};

use common_daft_config::DaftExecutionConfig;
use daft_local_execution::NativeExecutor;
use futures::stream;
use spark_connect::{ExecutePlanResponse, Relation};
use tokio_util::sync::CancellationToken;
use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status};

use crate::{
@@ -28,37 +28,32 @@ impl Session {

let finished = context.finished();

let (tx, rx) = tokio::sync::mpsc::channel::<eyre::Result<ExecutePlanResponse>>(16);
std::thread::spawn(move || {
let result = (|| -> eyre::Result<()> {
let (tx, rx) = tokio::sync::mpsc::channel::<eyre::Result<ExecutePlanResponse>>(1);
tokio::spawn(async move {
let execution_fut = async {
let plan = translation::to_logical_plan(command)?;
let logical_plan = plan.build();
// TODO(desmond): It looks like we don't currently do optimizer passes here before translation.
let physical_plan = daft_local_plan::translate(&logical_plan)?;

let optimized_plan = plan.optimize()?;
let cfg = DaftExecutionConfig::default();
let results = daft_local_execution::run_local(
&physical_plan,
HashMap::new(),
cfg.into(),
None,
CancellationToken::new(), // todo: maybe implement cancelling
)?;
let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?;
let mut result_stream = native_executor
.run(HashMap::new(), cfg.into(), None)?
.into_stream();

for result in results {
while let Some(result) = result_stream.next().await {
let result = result?;
let tables = result.get_tables()?;

for table in tables.as_slice() {
let response = context.gen_response(table)?;
tx.blocking_send(Ok(response)).unwrap();
if tx.send(Ok(response)).await.is_err() {
return Ok(());
}
}
}
Ok(())
})();
};

if let Err(e) = result {
tx.blocking_send(Err(e)).unwrap();
if let Err(e) = execution_fut.await {
let _ = tx.send(Err(e)).await;
}
});
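Two details of this rewrite are easy to miss: the channel capacity drops from 16 to 1, so the producer now awaits until the gRPC side has consumed each response (backpressure instead of buffering), and a failed send is treated as the receiver hanging up rather than a reason to panic. A runnable sketch of that shape, with the Daft executor replaced by a plain counter:

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    // Capacity 1: the producer awaits until the consumer takes each item.
    let (tx, mut rx) = mpsc::channel::<Result<u64, String>>(1);

    tokio::spawn(async move {
        let work = async {
            for i in 0..3u64 {
                if tx.send(Ok(i)).await.is_err() {
                    return Ok(()); // receiver dropped; exit quietly
                }
            }
            Err::<(), String>("simulated executor failure".into())
        };
        // Forward a failure to the consumer; ignore send errors, since a
        // dropped receiver just means nobody is listening anymore.
        if let Err(e) = work.await {
            let _ = tx.send(Err(e)).await;
        }
    });

    while let Some(item) = rx.recv().await {
        println!("{item:?}");
    }
}
```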

6 changes: 6 additions & 0 deletions src/daft-connect/src/translation.rs
@@ -1,7 +1,13 @@
//! Translation between Spark Connect and Daft
mod datatype;
mod expr;
mod literal;
mod logical_plan;
mod schema;

pub use datatype::to_spark_datatype;
pub use expr::to_daft_expr;
pub use literal::to_daft_literal;
pub use logical_plan::to_logical_plan;
pub use schema::relation_to_schema;