Skip to content

Commit

Permalink
[FEAT] connect: explain (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 11, 2024
1 parent 998969b commit 4ea8e55
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 3 deletions.
34 changes: 32 additions & 2 deletions src/daft-connect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use spark_connect::{
use tonic::{transport::Server, Request, Response, Status};
use tracing::info;
use uuid::Uuid;

use spark_connect::analyze_plan_request::explain::ExplainMode;
use crate::session::Session;

mod config;
Expand Down Expand Up @@ -282,6 +282,8 @@ impl SparkConnectService for DaftSparkConnectService {
use spark_connect::analyze_plan_request::*;
let request = request.into_inner();

let mut session = self.get_session(&request.session_id)?;

let AnalyzePlanRequest {
session_id,
analyze,
Expand Down Expand Up @@ -323,7 +325,35 @@ impl SparkConnectService for DaftSparkConnectService {

Ok(Response::new(response))
}
_ => unimplemented_err!("Analyze plan operation is not yet implemented"),
Analyze::Explain(explain) => {
let Explain { plan, explain_mode } = explain;

let explain_mode = ExplainMode::try_from(explain_mode)
.map_err(|_| invalid_argument_err!("Invalid Explain Mode"))?;

let Some(plan) = plan else {
return invalid_argument_err!("Plan is required");
};

let Some(plan) = plan.op_type else {
return invalid_argument_err!("Op Type is required");
};

let OpType::Root(relation) = plan else {
return invalid_argument_err!("Plan operation is required");
};

let result = match session.handle_explain_command(relation, explain_mode).await {
Ok(result) => result,
Err(e) => return Err(Status::internal(format!("Error in Daft server: {e:?}"))),
};

Ok(Response::new(result))
}
op => {
println!("{op:#?}");
unimplemented_err!("Analyze plan operation is not yet implemented")
}
}
}

Expand Down
1 change: 1 addition & 0 deletions src/daft-connect/src/op.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod execute;
pub mod analyze;
52 changes: 52 additions & 0 deletions src/daft-connect/src/op/analyze.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use std::pin::Pin;

use spark_connect::{analyze_plan_response, AnalyzePlanResponse};

pub type AnalyzeStream =
Pin<Box<dyn futures::Stream<Item = Result<AnalyzePlanResponse, Status>> + Send + Sync>>;

use spark_connect::{analyze_plan_request::explain::ExplainMode, Relation};
use tonic::Status;

use crate::{session::Session, translation};

pub struct PlanIds {
session: String,
server_side_session: String,
}

impl PlanIds {
pub fn response(&self, result: analyze_plan_response::Result) -> AnalyzePlanResponse {
AnalyzePlanResponse {
session_id: self.session.to_string(),
server_side_session_id: self.server_side_session.to_string(),
result: Some(result),
}
}
}

impl Session {
pub async fn handle_explain_command(
&self,
command: Relation,
_mode: ExplainMode,
) -> eyre::Result<AnalyzePlanResponse> {
let context = PlanIds {
session: self.client_side_session_id().to_string(),
server_side_session: self.server_side_session_id().to_string(),
};

let plan = translation::to_logical_plan(command)?;
let optimized_plan = plan.optimize()?;

let optimized_plan = optimized_plan.build();

// todo: what do we want this to display
let explain_string = format!("{optimized_plan}");

let schema = analyze_plan_response::Explain { explain_string };

let response = context.response(analyze_plan_response::Result::Explain(schema));
Ok(response)
}
}
2 changes: 1 addition & 1 deletion src/daft-connect/src/op/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ mod write;

pub type ExecuteStream = <DaftSparkConnectService as SparkConnectService>::ExecutePlanStream;

pub struct PlanIds {
struct PlanIds {
session: String,
server_side_session: String,
operation: String,
Expand Down
16 changes: 16 additions & 0 deletions tests/connect/test_explain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from __future__ import annotations


def test_explain(spark_session):
# Create ranges using Spark - with overlap
range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6
range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9

# Union the two ranges
unioned = range1.union(range2)

# Get the explain plan
explain_str = unioned.explain(extended=True)

# Verify explain output contains expected elements
print(explain_str)

0 comments on commit 4ea8e55

Please sign in to comment.