
Commit

fix up
andrewgazelka committed Oct 14, 2024
1 parent 1fed0a4 commit 039483d
Showing 6 changed files with 59 additions and 67 deletions.
src/daft-connect/src/config.rs (2 changes: 1 addition & 1 deletion)
@@ -136,7 +136,7 @@ impl Session {
Ok(response)
}

pub fn is_modifiable(&self, operation: IsModifiable) -> Result<ConfigResponse, Status> {
pub fn is_modifiable(&self, _operation: IsModifiable) -> Result<ConfigResponse, Status> {
let response = self.config_response();

let span = tracing::info_span!("is_modifiable", session_id = %self.id);
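The only change in this file prefixes the unused `operation` parameter with an underscore, which silences the `unused_variables` lint while leaving the signature intact. A minimal sketch of that convention, using placeholder types rather than the crate's real `IsModifiable`/`ConfigResponse`:

// Placeholder types standing in for the crate's real request/response types.
struct IsModifiable;
struct ConfigResponse;

// Prefixing a parameter with `_` keeps it in the signature (e.g. to satisfy an
// RPC interface) without triggering the `unused_variables` warning.
fn is_modifiable(_operation: IsModifiable) -> ConfigResponse {
    ConfigResponse
}

fn main() {
    let _response = is_modifiable(IsModifiable);
}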
src/daft-connect/src/convert.rs (54 changes: 32 additions & 22 deletions)
@@ -1,34 +1,31 @@
use std::collections::HashSet;
use std::sync::Arc;
use crate::spark_connect::relation::RelType;
use crate::spark_connect::{Filter, Read, Relation, ShowString, WithColumns};
use std::{collections::HashSet, sync::Arc};

use anyhow::{bail, ensure, Context};
use daft_plan::{LogicalPlanBuilder, ParquetScanBuilder};
use crate::spark_connect;
use crate::spark_connect::expression::Alias;
use crate::spark_connect::read::{DataSource, ReadType};

mod expr;
use crate::spark_connect::{
expression::Alias,
read::{DataSource, ReadType},
relation::RelType,
Filter, Read, Relation, ShowString, WithColumns,
};

mod expr;

// todo: a way to do something like tracing scopes but with errors?
pub fn to_logical_plan(plan: Relation) -> anyhow::Result<LogicalPlanBuilder> {
let result = match plan.rel_type.context("rel_type is None")? {
RelType::ShowString(show_string) => {
let ShowString {
input,
num_rows,
truncate,
vertical,
input, num_rows, ..
} = *show_string;
// todo: support more truncate options
let input = *input.context("input is None")?;

let builder = to_logical_plan(input)?;

let num_rows = i64::from(num_rows);
builder
.limit(num_rows, false)?
.add_show_string_column()?
builder.limit(num_rows, false)?.add_show_string_column()?
}
RelType::Filter(filter) => {
let Filter { input, condition } = *filter;
@@ -49,10 +46,15 @@ pub fn to_logical_plan(plan: Relation) -> anyhow::Result<LogicalPlanBuilder> {
let input_plan = to_logical_plan(*input)?;

let mut new_exprs = Vec::new();
let mut existing_columns: HashSet<_> = input_plan.schema().names().into_iter().collect();
let mut existing_columns: HashSet<_> =
input_plan.schema().names().into_iter().collect();

for alias in aliases {
let Alias { expr, name, metadata } = alias;
let Alias {
expr,
name,
metadata,
} = alias;

let [name] = name.as_slice() else {
bail!("Alias name must have exactly one element");
@@ -69,7 +71,7 @@ pub fn to_logical_plan(plan: Relation) -> anyhow::Result<LogicalPlanBuilder>

// todo: test
new_exprs.push(expr.alias(name));

if existing_columns.contains(name) {
// Replace existing column
existing_columns.remove(name);
@@ -83,24 +85,32 @@ pub fn to_logical_plan(plan: Relation) -> anyhow::Result<LogicalPlanBuilder>

input_plan.select(new_exprs)?
}
RelType::Read(Read { is_streaming, read_type }) => {
RelType::Read(Read {
is_streaming,
read_type,
}) => {
ensure!(!is_streaming, "Streaming reads are not yet supported");
let read_type = read_type.context("read_type is None")?;

match read_type {
ReadType::NamedTable(_) => {
bail!("Named tables are not yet supported");
}
ReadType::DataSource(DataSource { format, schema, options, paths, predicates }) => {
ReadType::DataSource(DataSource {
format,
schema,
options,
paths,
predicates,
}) => {
let format = format.context("format is None")?;
let schema = schema.context("schema is None")?;

ensure!(options.is_empty(), "Options are not yet supported");

ensure!(predicates.is_empty(), "Predicates are not yet supported");

ParquetScanBuilder::new(paths)
.finish()?
ParquetScanBuilder::new(paths).finish()?
}
}
}
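Beyond formatting reflow, the convert.rs changes bind only the protobuf fields that are actually used (via the `..` rest pattern) and guard unsupported or missing options with `anyhow`'s `ensure!`, `bail!`, and `Context`. A small self-contained sketch of those two patterns, using a simplified stand-in type rather than the real `spark_connect` messages (requires the `anyhow` crate):

use anyhow::{bail, ensure, Context};

// Simplified stand-in for a protobuf message such as `ShowString`; the real
// generated types have more fields and different shapes.
#[allow(dead_code)]
struct ShowString {
    input: Option<String>,
    num_rows: i32,
    truncate: i32,
    vertical: bool,
}

fn handle(show: ShowString) -> anyhow::Result<String> {
    // Bind only the fields we use; `..` ignores the rest instead of binding
    // them and triggering unused-variable warnings.
    let ShowString { input, num_rows, .. } = show;

    // Fail fast with a descriptive error when a required field is missing.
    let input = input.context("input is None")?;

    // Reject configurations that are not supported yet.
    ensure!(num_rows >= 0, "num_rows must be non-negative");
    if num_rows > 1_000_000 {
        bail!("num_rows {num_rows} is larger than supported");
    }

    Ok(format!("show {num_rows} rows of {input}"))
}

fn main() -> anyhow::Result<()> {
    let out = handle(ShowString {
        input: Some("table".to_string()),
        num_rows: 20,
        truncate: 0,
        vertical: false,
    })?;
    println!("{out}");
    Ok(())
}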
src/daft-connect/src/convert/expr.rs (56 changes: 22 additions & 34 deletions)
@@ -12,13 +12,9 @@ use crate::{

pub fn to_daft_expr(expr: spark_connect::Expression) -> anyhow::Result<DaftExpr> {
match expr.expr_type {
Some(expression::ExprType::Literal(lit)) => {
// Convert Spark literal to Daft literal
Ok(DaftExpr::Literal(convert_literal(lit)?))
}
Some(expression::ExprType::Literal(lit)) => Ok(DaftExpr::Literal(convert_literal(lit)?)),

Some(expression::ExprType::UnresolvedAttribute(attr)) => {
// Convert unresolved attribute to Daft column reference
Ok(DaftExpr::Column(attr.unparsed_identifier.into()))
}

@@ -33,6 +29,12 @@ pub fn to_daft_expr(expr: spark_connect::Expression) -> anyhow::Result<DaftExpr>
// Convert alias
let expr = to_daft_expr(expr)?;

if let Some(metadata) = metadata
&& !metadata.is_empty()
{
bail!("Metadata is not yet supported");
}

// ignore metadata for now

let [name] = name.as_slice() else {
@@ -48,6 +50,12 @@ pub fn to_daft_expr(expr: spark_connect::Expression) -> anyhow::Result<DaftExpr>
is_distinct,
is_user_defined_function,
})) => {
ensure!(!is_distinct, "Distinct is not yet supported");
ensure!(
!is_user_defined_function,
"User-defined functions are not yet supported"
);

let op = function_name.as_str();
match op {
">" | "<" | "<=" | ">=" | "+" | "-" | "*" | "/" => {
@@ -81,19 +89,6 @@ pub fn to_daft_expr(expr: spark_connect::Expression) -> anyhow::Result<DaftExpr>
}
}

// Some(expression::ExprType::BinaryComparison(cmp)) => {
// // Convert binary comparison
// let left = to_daft_expr(*cmp.left)?;
// let right = to_daft_expr(*cmp.right)?;
// let op = convert_comparison_op(cmp.comparison_type)?;
//
// Ok(DaftExpr::BinaryOp {
// left: Box::new(left),
// op,
// right: Box::new(right),
// })
// }

// Handle other expression types...
_ => Err(anyhow::anyhow!("Unsupported expression type")),
}
@@ -114,25 +109,18 @@ fn convert_literal(lit: expression::Literal) -> anyhow::Result<daft_dsl::Literal
LiteralType::Long(input) => daft_dsl::LiteralValue::Int64(input),
LiteralType::Float(input) => daft_dsl::LiteralValue::Float64(f64::from(input)),
LiteralType::Double(input) => daft_dsl::LiteralValue::Float64(input),
LiteralType::Decimal(input) => unimplemented!(),
LiteralType::String(input) => daft_dsl::LiteralValue::Utf8(input),
LiteralType::Date(input) => daft_dsl::LiteralValue::Date(input),
LiteralType::Timestamp(input) => unimplemented!(),
LiteralType::TimestampNtz(input) => unimplemented!(),
LiteralType::CalendarInterval(input) => unimplemented!(),
LiteralType::YearMonthInterval(input) => unimplemented!(),
LiteralType::DayTimeInterval(input) => unimplemented!(),
LiteralType::Array(_) | LiteralType::Map(_) | LiteralType::Struct(_) => todo!(),
LiteralType::Decimal(_)
| LiteralType::Timestamp(_)
| LiteralType::TimestampNtz(_)
| LiteralType::CalendarInterval(_)
| LiteralType::YearMonthInterval(_)
| LiteralType::DayTimeInterval(_)
| LiteralType::Array(_)
| LiteralType::Map(_)
| LiteralType::Struct(_) => bail!("unimplemented"),
};

Ok(result)
}

fn convert_function_name(name: &str) -> anyhow::Result<daft_dsl::functions::ScalarFunction> {
// Map Spark function names to Daft equivalents
todo!()
}

// fn convert_comparison_op(op: i32) -> anyhow::Result<BinaryOperator> {
// // Map Spark comparison types to Daft binary operators
// }
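The metadata guard in `to_daft_expr` (`if let Some(metadata) = metadata && !metadata.is_empty()`) is a let-chain, which is why `#![feature(let_chains)]` appears in `lib.rs` further down; at the time of this commit that syntax still required nightly Rust. A hedged sketch of the same guard written for stable Rust, with an `Option<String>` standing in for the protobuf metadata field:

fn check_metadata_stable(metadata: Option<String>) -> Result<(), String> {
    // Stable-Rust equivalent of the let-chain: test presence and non-emptiness
    // in a single expression.
    if metadata.as_deref().is_some_and(|m| !m.is_empty()) {
        return Err("Metadata is not yet supported".to_string());
    }
    Ok(())
}

// With `#![feature(let_chains)]` on nightly, the same guard can be written as
// one `if let ... && ...` chain, as in the diff above:
//
//     if let Some(metadata) = metadata
//         && !metadata.is_empty()
//     {
//         bail!("Metadata is not yet supported");
//     }

fn main() {
    assert!(check_metadata_stable(None).is_ok());
    assert!(check_metadata_stable(Some(String::new())).is_ok());
    assert!(check_metadata_stable(Some("json".to_string())).is_err());
}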
src/daft-connect/src/convert/tests.rs (2 changes: 0 additions & 2 deletions)
@@ -1,8 +1,6 @@
use std::collections::HashMap;
use std::sync::Arc;
use crate::command::execute_plan;
use crate::convert::to_logical_plan;
use crate::spark_connect::plan::OpType::Root;
use crate::spark_connect::{Expression, Filter, Read, Relation, RelationCommon, ShowString, WithColumns};
use crate::spark_connect::expression::{Alias, ExprType, Literal, UnresolvedAttribute, UnresolvedFunction};
use crate::spark_connect::expression::literal::LiteralType;
src/daft-connect/src/lib.rs (10 changes: 3 additions & 7 deletions)
@@ -1,24 +1,20 @@
#![feature(iterator_try_collect)]
#![feature(let_chains)]
#![expect(clippy::derive_partial_eq_without_eq, reason = "prost does not properly derive Eq")]

use tonic::{Request, Response, Status};
use futures::stream;
use spark_connect::spark_connect_service_server::SparkConnectService;
use spark_connect::{
ExecutePlanRequest, ExecutePlanResponse, AnalyzePlanRequest, AnalyzePlanResponse, ConfigRequest,
ConfigResponse, AddArtifactsRequest, AddArtifactsResponse, ArtifactStatusesRequest, ArtifactStatusesResponse,
InterruptRequest, InterruptResponse, ReattachExecuteRequest, ReleaseExecuteRequest, ReleaseExecuteResponse,
ReleaseSessionRequest, ReleaseSessionResponse, FetchErrorDetailsRequest, FetchErrorDetailsResponse,
};
use std::sync::{Arc, Mutex};
use uuid::Uuid;
use std::collections::{BTreeMap, HashMap};
use dashmap::{DashMap, Entry};
use std::collections::BTreeMap;
use dashmap::DashMap;
use tracing::info;
use crate::spark_connect::analyze_plan_response::TreeString;
use crate::spark_connect::command::CommandType;
use crate::spark_connect::config_request::{Operation, Set};
use crate::spark_connect::KeyValue;

pub mod spark_connect {
tonic::include_proto!("spark.connect");
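The crate-level `#![expect(clippy::derive_partial_eq_without_eq, reason = ...)]` visible in this hunk is the stricter sibling of `#![allow(...)]`: it suppresses the lint, but also warns if the lint ever stops firing, so the suppression cannot silently go stale. The attribute (with its `reason` string) was stabilized in Rust 1.81. A minimal standalone sketch of the mechanics, applied to an unrelated placeholder function:

// `#[expect]` works like `#[allow]`, but emits a warning if the expected lint
// does not actually fire on the item it annotates.
#[expect(dead_code, reason = "kept unused on purpose to demonstrate the attribute")]
fn unused_helper() {}

fn main() {
    println!("expect-attribute demo");
}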
src/daft-plan/src/builder.rs (2 changes: 1 addition & 1 deletion)
@@ -274,7 +274,7 @@ impl LogicalPlanBuilder {
/// use daft_dsl::{col, lit};
/// use daft_schema::dtype::DataType;
///
/// let builder = LogicalPlanBuilder::default(); // todo: can we replace this correctly?
/// let builder: LogicalPlanBuilder = todo!();
///
/// // Select existing columns
/// let result = builder.select(vec![col("name"), col("age")]);
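In the doc example, `LogicalPlanBuilder::default()` is replaced with `let builder: LogicalPlanBuilder = todo!();`. Because `todo!()` diverges (its type is `!`), it coerces to whatever type the binding is annotated with, so the example type-checks without needing a real constructor; it would panic if the example were actually executed, so the pattern only makes sense in documentation that is compiled but not run. A tiny sketch of the coercion, with an arbitrary placeholder type:

struct Builder;

#[allow(unreachable_code)]
fn sketch() {
    // `todo!()` never returns, so it can stand in for a value of any annotated
    // type; calling `sketch()` would panic at the `todo!()`.
    let _builder: Builder = todo!();
}

fn main() {
    // Deliberately not calling `sketch()`; it exists only to show that the
    // type annotation plus `todo!()` compiles.
    let _ = sketch as fn();
}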
