Skip to content

Commit

Permalink
[FEAT] connect: add alias support
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Nov 20, 2024
1 parent 7ba8185 commit 0682b6d
Show file tree
Hide file tree
Showing 15 changed files with 555 additions and 100 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,14 @@ chrono-tz = "0.8.4"
comfy-table = "7.1.1"
common-daft-config = {path = "src/common/daft-config"}
common-error = {path = "src/common/error", default-features = false}
daft-core = {path = "src/daft-core"}
daft-dsl = {path = "src/daft-dsl"}
daft-hash = {path = "src/daft-hash"}
daft-local-execution = {path = "src/daft-local-execution"}
daft-local-plan = {path = "src/daft-local-plan"}
daft-logical-plan = {path = "src/daft-logical-plan"}
daft-scan = {path = "src/daft-scan"}
daft-schema = {path = "src/daft-schema"}
daft-table = {path = "src/daft-table"}
derivative = "2.2.0"
derive_builder = "0.20.2"
Expand Down
5 changes: 4 additions & 1 deletion src/daft-connect/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
arrow2 = {workspace = true}
async-stream = "0.3.6"
common-daft-config = {workspace = true}
daft-core = {workspace = true}
daft-dsl = {workspace = true}
daft-local-execution = {workspace = true}
daft-local-plan = {workspace = true}
daft-logical-plan = {workspace = true}
daft-scan = {workspace = true}
daft-schema = {workspace = true}
daft-table = {workspace = true}
dashmap = "6.1.0"
eyre = "0.6.12"
Expand All @@ -19,7 +22,7 @@ tracing = {workspace = true}
uuid = {version = "1.10.0", features = ["v4"]}

[features]
python = ["dep:pyo3", "common-daft-config/python", "daft-local-execution/python", "daft-local-plan/python", "daft-logical-plan/python", "daft-scan/python", "daft-table/python"]
python = ["dep:pyo3", "common-daft-config/python", "daft-local-execution/python", "daft-local-plan/python", "daft-logical-plan/python", "daft-scan/python", "daft-table/python", "daft-dsl/python", "daft-schema/python", "daft-core/python"]

[lints]
workspace = true
Expand Down
12 changes: 6 additions & 6 deletions src/daft-connect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use spark_connect::{
ReleaseExecuteResponse, ReleaseSessionRequest, ReleaseSessionResponse,
};
use tonic::{transport::Server, Request, Response, Status};
use tracing::info;
use tracing::{debug, info};
use uuid::Uuid;

use crate::session::Session;
Expand Down Expand Up @@ -303,22 +303,22 @@ impl SparkConnectService for DaftSparkConnectService {
Ok(schema) => schema,
Err(e) => {
return invalid_argument_err!(
"Failed to translate relation to schema: {e}"
"Failed to translate relation to schema: {e:?}"
);
}
};

let schema = analyze_plan_response::DdlParse {
parsed: Some(result),
let schema = analyze_plan_response::Schema {
schema: Some(result),
};

let response = AnalyzePlanResponse {
session_id,
server_side_session_id: String::new(),
result: Some(analyze_plan_response::Result::DdlParse(schema)),
result: Some(analyze_plan_response::Result::Schema(schema)),
};

println!("response: {response:#?}");
debug!("response: {response:#?}");

Ok(Response::new(response))
}
Expand Down
6 changes: 6 additions & 0 deletions src/daft-connect/src/translation.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
//! Translation between Spark Connect and Daft
mod datatype;
mod expr;
mod literal;
mod logical_plan;
mod schema;

pub use datatype::to_spark_datatype;
pub use expr::to_daft_expr;
pub use literal::to_daft_literal;
pub use logical_plan::to_logical_plan;
pub use schema::relation_to_schema;
114 changes: 114 additions & 0 deletions src/daft-connect/src/translation/datatype.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use daft_schema::dtype::DataType;
use spark_connect::data_type::Kind;
use tracing::warn;

pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType {
match datatype {
DataType::Null => spark_connect::DataType {
kind: Some(Kind::Null(spark_connect::data_type::Null {
type_variation_reference: 0,
})),
},
DataType::Boolean => spark_connect::DataType {
kind: Some(Kind::Boolean(spark_connect::data_type::Boolean {
type_variation_reference: 0,
})),
},
DataType::Int8 => spark_connect::DataType {
kind: Some(Kind::Byte(spark_connect::data_type::Byte {
type_variation_reference: 0,
})),
},
DataType::Int16 => spark_connect::DataType {
kind: Some(Kind::Short(spark_connect::data_type::Short {
type_variation_reference: 0,
})),
},
DataType::Int32 => spark_connect::DataType {
kind: Some(Kind::Integer(spark_connect::data_type::Integer {
type_variation_reference: 0,
})),
},
DataType::Int64 => spark_connect::DataType {
kind: Some(Kind::Long(spark_connect::data_type::Long {
type_variation_reference: 0,
})),
},
DataType::UInt8 => spark_connect::DataType {
kind: Some(Kind::Byte(spark_connect::data_type::Byte {
type_variation_reference: 0,
})),
},
DataType::UInt16 => spark_connect::DataType {
kind: Some(Kind::Short(spark_connect::data_type::Short {
type_variation_reference: 0,
})),
},
DataType::UInt32 => spark_connect::DataType {
kind: Some(Kind::Integer(spark_connect::data_type::Integer {
type_variation_reference: 0,
})),
},
DataType::UInt64 => spark_connect::DataType {
kind: Some(Kind::Long(spark_connect::data_type::Long {
type_variation_reference: 0,
})),
},
DataType::Float32 => spark_connect::DataType {
kind: Some(Kind::Float(spark_connect::data_type::Float {
type_variation_reference: 0,
})),
},
DataType::Float64 => spark_connect::DataType {
kind: Some(Kind::Double(spark_connect::data_type::Double {
type_variation_reference: 0,
})),
},
DataType::Decimal128(precision, scale) => spark_connect::DataType {
kind: Some(Kind::Decimal(spark_connect::data_type::Decimal {
scale: Some(*scale as i32),
precision: Some(*precision as i32),
type_variation_reference: 0,
})),
},
DataType::Timestamp(unit, _) => {
warn!("Ignoring time unit {unit:?} for timestamp type");
spark_connect::DataType {
kind: Some(Kind::Timestamp(spark_connect::data_type::Timestamp {
type_variation_reference: 0,
})),
}
}
DataType::Date => spark_connect::DataType {
kind: Some(Kind::Date(spark_connect::data_type::Date {
type_variation_reference: 0,
})),
},
DataType::Binary => spark_connect::DataType {
kind: Some(Kind::Binary(spark_connect::data_type::Binary {
type_variation_reference: 0,
})),
},
DataType::Utf8 => spark_connect::DataType {
kind: Some(Kind::String(spark_connect::data_type::String {
type_variation_reference: 0,
collation: String::new(), // todo(correctness): is this correct?
})),
},
DataType::Struct(fields) => spark_connect::DataType {
kind: Some(Kind::Struct(spark_connect::data_type::Struct {
fields: fields
.iter()
.map(|f| spark_connect::data_type::StructField {
name: f.name.clone(),
data_type: Some(to_spark_datatype(&f.dtype)),
nullable: true, // todo(correctness): is this correct?
metadata: None,
})
.collect(),
type_variation_reference: 0,
})),
},
_ => unimplemented!("Unsupported datatype: {datatype:?}"),
}
}
105 changes: 105 additions & 0 deletions src/daft-connect/src/translation/expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use std::sync::Arc;

use eyre::{bail, Context};
use spark_connect::{expression as spark_expr, Expression};
use tracing::warn;
use unresolved_function::unresolved_to_daft_expr;

use crate::translation::to_daft_literal;

mod unresolved_function;

pub fn to_daft_expr(expression: Expression) -> eyre::Result<daft_dsl::ExprRef> {
if let Some(common) = expression.common {
warn!("Ignoring common metadata for relation: {common:?}; not yet implemented");
};

let Some(expr) = expression.expr_type else {
bail!("Expression is required");
};

match expr {
spark_expr::ExprType::Literal(l) => to_daft_literal(l),
spark_expr::ExprType::UnresolvedAttribute(attr) => {
let spark_expr::UnresolvedAttribute {
unparsed_identifier,
plan_id,
is_metadata_column,
} = attr;

if let Some(plan_id) = plan_id {
warn!("Ignoring plan_id {plan_id} for attribute expressions; not yet implemented");
}

if let Some(is_metadata_column) = is_metadata_column {
warn!("Ignoring is_metadata_column {is_metadata_column} for attribute expressions; not yet implemented");
}

Ok(daft_dsl::col(unparsed_identifier))
}
spark_expr::ExprType::UnresolvedFunction(f) => {
unresolved_to_daft_expr(f).wrap_err("Failed to handle unresolved function")
}
spark_expr::ExprType::ExpressionString(_) => bail!("Expression string not yet supported"),
spark_expr::ExprType::UnresolvedStar(_) => {
bail!("Unresolved star expressions not yet supported")
}
spark_expr::ExprType::Alias(alias) => {
let spark_expr::Alias {
expr,
name,
metadata,
} = *alias;

let Some(expr) = expr else {
bail!("Alias expr is required");
};

let [name] = name.as_slice() else {
bail!("Alias name is required and currently only works with a single string; got {name:?}");
};

if let Some(metadata) = metadata {
bail!("Alias metadata is not yet supported; got {metadata:?}");
}

let child = to_daft_expr(*expr)?;

let name = Arc::from(name.as_str());

Ok(child.alias(name))
}
spark_expr::ExprType::Cast(_) => bail!("Cast expressions not yet supported"),
spark_expr::ExprType::UnresolvedRegex(_) => {
bail!("Unresolved regex expressions not yet supported")
}
spark_expr::ExprType::SortOrder(_) => bail!("Sort order expressions not yet supported"),
spark_expr::ExprType::LambdaFunction(_) => {
bail!("Lambda function expressions not yet supported")
}
spark_expr::ExprType::Window(_) => bail!("Window expressions not yet supported"),
spark_expr::ExprType::UnresolvedExtractValue(_) => {
bail!("Unresolved extract value expressions not yet supported")
}
spark_expr::ExprType::UpdateFields(_) => {
bail!("Update fields expressions not yet supported")
}
spark_expr::ExprType::UnresolvedNamedLambdaVariable(_) => {
bail!("Unresolved named lambda variable expressions not yet supported")
}
spark_expr::ExprType::CommonInlineUserDefinedFunction(_) => {
bail!("Common inline user defined function expressions not yet supported")
}
spark_expr::ExprType::CallFunction(_) => {
bail!("Call function expressions not yet supported")
}
spark_expr::ExprType::NamedArgumentExpression(_) => {
bail!("Named argument expressions not yet supported")
}
spark_expr::ExprType::MergeAction(_) => bail!("Merge action expressions not yet supported"),
spark_expr::ExprType::TypedAggregateExpression(_) => {
bail!("Typed aggregate expressions not yet supported")
}
spark_expr::ExprType::Extension(_) => bail!("Extension expressions not yet supported"),
}
}
45 changes: 45 additions & 0 deletions src/daft-connect/src/translation/expr/unresolved_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use daft_core::count_mode::CountMode;
use daft_schema::dtype::DataType;
use eyre::{bail, Context};
use spark_connect::expression::UnresolvedFunction;

use crate::translation::to_daft_expr;

pub fn unresolved_to_daft_expr(f: UnresolvedFunction) -> eyre::Result<daft_dsl::ExprRef> {
let UnresolvedFunction {
function_name,
arguments,
is_distinct,
is_user_defined_function,
} = f;

let arguments: Vec<_> = arguments.into_iter().map(to_daft_expr).try_collect()?;

if is_distinct {
bail!("Distinct not yet supported");
}

if is_user_defined_function {
bail!("User-defined functions not yet supported");
}

match function_name.as_str() {
"count" => handle_count(arguments).wrap_err("Failed to handle count function"),
n => bail!("Unresolved function {n} not yet supported"),
}
}

pub fn handle_count(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() {
Ok(arguments) => arguments,
Err(arguments) => {
bail!("requires exactly one argument; got {arguments:?}");
}
};

let [arg] = arguments;

let count = arg.count(CountMode::All);

Ok(count)
}
Loading

0 comments on commit 0682b6d

Please sign in to comment.