-
Notifications
You must be signed in to change notification settings - Fork 174
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEAT] connect: add alias support (#3342)
- Loading branch information
1 parent
2c0f3cd
commit 88edc4a
Showing
16 changed files
with
556 additions
and
100 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,13 @@ | ||
//! Translation between Spark Connect and Daft | ||
mod datatype; | ||
mod expr; | ||
mod literal; | ||
mod logical_plan; | ||
mod schema; | ||
|
||
pub use datatype::to_spark_datatype; | ||
pub use expr::to_daft_expr; | ||
pub use literal::to_daft_literal; | ||
pub use logical_plan::to_logical_plan; | ||
pub use schema::relation_to_schema; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
use daft_schema::dtype::DataType; | ||
use spark_connect::data_type::Kind; | ||
use tracing::warn; | ||
|
||
pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType { | ||
match datatype { | ||
DataType::Null => spark_connect::DataType { | ||
kind: Some(Kind::Null(spark_connect::data_type::Null { | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
DataType::Boolean => spark_connect::DataType { | ||
kind: Some(Kind::Boolean(spark_connect::data_type::Boolean { | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
DataType::Int8 => spark_connect::DataType { | ||
kind: Some(Kind::Byte(spark_connect::data_type::Byte { | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
DataType::Int16 => spark_connect::DataType { | ||
kind: Some(Kind::Short(spark_connect::data_type::Short { | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
DataType::Int32 => spark_connect::DataType { | ||
kind: Some(Kind::Integer(spark_connect::data_type::Integer { | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
DataType::Int64 => spark_connect::DataType { | ||
kind: Some(Kind::Long(spark_connect::data_type::Long { | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
DataType::UInt8 => spark_connect::DataType { | ||
kind: Some(Kind::Byte(spark_connect::data_type::Byte { | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
DataType::UInt16 => spark_connect::DataType { | ||
kind: Some(Kind::Short(spark_connect::data_type::Short { | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
DataType::UInt32 => spark_connect::DataType { | ||
kind: Some(Kind::Integer(spark_connect::data_type::Integer { | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
DataType::UInt64 => spark_connect::DataType { | ||
kind: Some(Kind::Long(spark_connect::data_type::Long { | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
DataType::Float32 => spark_connect::DataType { | ||
kind: Some(Kind::Float(spark_connect::data_type::Float { | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
DataType::Float64 => spark_connect::DataType { | ||
kind: Some(Kind::Double(spark_connect::data_type::Double { | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
DataType::Decimal128(precision, scale) => spark_connect::DataType { | ||
kind: Some(Kind::Decimal(spark_connect::data_type::Decimal { | ||
scale: Some(*scale as i32), | ||
precision: Some(*precision as i32), | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
DataType::Timestamp(unit, _) => { | ||
warn!("Ignoring time unit {unit:?} for timestamp type"); | ||
spark_connect::DataType { | ||
kind: Some(Kind::Timestamp(spark_connect::data_type::Timestamp { | ||
type_variation_reference: 0, | ||
})), | ||
} | ||
} | ||
DataType::Date => spark_connect::DataType { | ||
kind: Some(Kind::Date(spark_connect::data_type::Date { | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
DataType::Binary => spark_connect::DataType { | ||
kind: Some(Kind::Binary(spark_connect::data_type::Binary { | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
DataType::Utf8 => spark_connect::DataType { | ||
kind: Some(Kind::String(spark_connect::data_type::String { | ||
type_variation_reference: 0, | ||
collation: String::new(), // todo(correctness): is this correct? | ||
})), | ||
}, | ||
DataType::Struct(fields) => spark_connect::DataType { | ||
kind: Some(Kind::Struct(spark_connect::data_type::Struct { | ||
fields: fields | ||
.iter() | ||
.map(|f| spark_connect::data_type::StructField { | ||
name: f.name.clone(), | ||
data_type: Some(to_spark_datatype(&f.dtype)), | ||
nullable: true, // todo(correctness): is this correct? | ||
metadata: None, | ||
}) | ||
.collect(), | ||
type_variation_reference: 0, | ||
})), | ||
}, | ||
_ => unimplemented!("Unsupported datatype: {datatype:?}"), | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
use std::sync::Arc; | ||
|
||
use eyre::{bail, Context}; | ||
use spark_connect::{expression as spark_expr, Expression}; | ||
use tracing::warn; | ||
use unresolved_function::unresolved_to_daft_expr; | ||
|
||
use crate::translation::to_daft_literal; | ||
|
||
mod unresolved_function; | ||
|
||
pub fn to_daft_expr(expression: Expression) -> eyre::Result<daft_dsl::ExprRef> { | ||
if let Some(common) = expression.common { | ||
warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); | ||
}; | ||
|
||
let Some(expr) = expression.expr_type else { | ||
bail!("Expression is required"); | ||
}; | ||
|
||
match expr { | ||
spark_expr::ExprType::Literal(l) => to_daft_literal(l), | ||
spark_expr::ExprType::UnresolvedAttribute(attr) => { | ||
let spark_expr::UnresolvedAttribute { | ||
unparsed_identifier, | ||
plan_id, | ||
is_metadata_column, | ||
} = attr; | ||
|
||
if let Some(plan_id) = plan_id { | ||
warn!("Ignoring plan_id {plan_id} for attribute expressions; not yet implemented"); | ||
} | ||
|
||
if let Some(is_metadata_column) = is_metadata_column { | ||
warn!("Ignoring is_metadata_column {is_metadata_column} for attribute expressions; not yet implemented"); | ||
} | ||
|
||
Ok(daft_dsl::col(unparsed_identifier)) | ||
} | ||
spark_expr::ExprType::UnresolvedFunction(f) => { | ||
unresolved_to_daft_expr(f).wrap_err("Failed to handle unresolved function") | ||
} | ||
spark_expr::ExprType::ExpressionString(_) => bail!("Expression string not yet supported"), | ||
spark_expr::ExprType::UnresolvedStar(_) => { | ||
bail!("Unresolved star expressions not yet supported") | ||
} | ||
spark_expr::ExprType::Alias(alias) => { | ||
let spark_expr::Alias { | ||
expr, | ||
name, | ||
metadata, | ||
} = *alias; | ||
|
||
let Some(expr) = expr else { | ||
bail!("Alias expr is required"); | ||
}; | ||
|
||
let [name] = name.as_slice() else { | ||
bail!("Alias name is required and currently only works with a single string; got {name:?}"); | ||
}; | ||
|
||
if let Some(metadata) = metadata { | ||
bail!("Alias metadata is not yet supported; got {metadata:?}"); | ||
} | ||
|
||
let child = to_daft_expr(*expr)?; | ||
|
||
let name = Arc::from(name.as_str()); | ||
|
||
Ok(child.alias(name)) | ||
} | ||
spark_expr::ExprType::Cast(_) => bail!("Cast expressions not yet supported"), | ||
spark_expr::ExprType::UnresolvedRegex(_) => { | ||
bail!("Unresolved regex expressions not yet supported") | ||
} | ||
spark_expr::ExprType::SortOrder(_) => bail!("Sort order expressions not yet supported"), | ||
spark_expr::ExprType::LambdaFunction(_) => { | ||
bail!("Lambda function expressions not yet supported") | ||
} | ||
spark_expr::ExprType::Window(_) => bail!("Window expressions not yet supported"), | ||
spark_expr::ExprType::UnresolvedExtractValue(_) => { | ||
bail!("Unresolved extract value expressions not yet supported") | ||
} | ||
spark_expr::ExprType::UpdateFields(_) => { | ||
bail!("Update fields expressions not yet supported") | ||
} | ||
spark_expr::ExprType::UnresolvedNamedLambdaVariable(_) => { | ||
bail!("Unresolved named lambda variable expressions not yet supported") | ||
} | ||
spark_expr::ExprType::CommonInlineUserDefinedFunction(_) => { | ||
bail!("Common inline user defined function expressions not yet supported") | ||
} | ||
spark_expr::ExprType::CallFunction(_) => { | ||
bail!("Call function expressions not yet supported") | ||
} | ||
spark_expr::ExprType::NamedArgumentExpression(_) => { | ||
bail!("Named argument expressions not yet supported") | ||
} | ||
spark_expr::ExprType::MergeAction(_) => bail!("Merge action expressions not yet supported"), | ||
spark_expr::ExprType::TypedAggregateExpression(_) => { | ||
bail!("Typed aggregate expressions not yet supported") | ||
} | ||
spark_expr::ExprType::Extension(_) => bail!("Extension expressions not yet supported"), | ||
} | ||
} |
44 changes: 44 additions & 0 deletions
44
src/daft-connect/src/translation/expr/unresolved_function.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
use daft_core::count_mode::CountMode; | ||
use eyre::{bail, Context}; | ||
use spark_connect::expression::UnresolvedFunction; | ||
|
||
use crate::translation::to_daft_expr; | ||
|
||
pub fn unresolved_to_daft_expr(f: UnresolvedFunction) -> eyre::Result<daft_dsl::ExprRef> { | ||
let UnresolvedFunction { | ||
function_name, | ||
arguments, | ||
is_distinct, | ||
is_user_defined_function, | ||
} = f; | ||
|
||
let arguments: Vec<_> = arguments.into_iter().map(to_daft_expr).try_collect()?; | ||
|
||
if is_distinct { | ||
bail!("Distinct not yet supported"); | ||
} | ||
|
||
if is_user_defined_function { | ||
bail!("User-defined functions not yet supported"); | ||
} | ||
|
||
match function_name.as_str() { | ||
"count" => handle_count(arguments).wrap_err("Failed to handle count function"), | ||
n => bail!("Unresolved function {n} not yet supported"), | ||
} | ||
} | ||
|
||
pub fn handle_count(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> { | ||
let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { | ||
Ok(arguments) => arguments, | ||
Err(arguments) => { | ||
bail!("requires exactly one argument; got {arguments:?}"); | ||
} | ||
}; | ||
|
||
let [arg] = arguments; | ||
|
||
let count = arg.count(CountMode::All); | ||
|
||
Ok(count) | ||
} |
Oops, something went wrong.