Skip to content

Commit

Permalink
[FEAT] connect: support basic column operations
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 4, 2024
1 parent 83470e0 commit a0165a6
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/daft-connect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#![feature(iter_from_coroutine)]
#![feature(stmt_expr_attributes)]
#![feature(try_trait_v2_residual)]
#![deny(clippy::print_stdout)]

use dashmap::DashMap;
use eyre::Context;
Expand Down
2 changes: 1 addition & 1 deletion src/daft-connect/src/translation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod literal;
mod logical_plan;
mod schema;

pub use datatype::to_spark_datatype;
pub use datatype::{to_daft_datatype, to_spark_datatype};
pub use expr::to_daft_expr;
pub use literal::to_daft_literal;
pub use logical_plan::to_logical_plan;
Expand Down
151 changes: 150 additions & 1 deletion src/daft-connect/src/translation/datatype.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use daft_schema::dtype::DataType;
use daft_schema::{dtype::DataType, field::Field, time_unit::TimeUnit};
use eyre::{bail, ensure, WrapErr};
use spark_connect::data_type::Kind;
use tracing::warn;

Expand Down Expand Up @@ -112,3 +113,151 @@ pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType {
_ => unimplemented!("Unsupported datatype: {datatype:?}"),
}
}

// TODO(test): add tests for this conversion, especially end-to-end tests driven from Python.
pub fn to_daft_datatype(datatype: &spark_connect::DataType) -> eyre::Result<DataType> {
let Some(kind) = &datatype.kind else {
bail!("Datatype is required");

Check warning on line 120 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L120

Added line #L120 was not covered by tests
};

let type_variation_err = "Custom type variation reference not supported";

match kind {
Kind::Null(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Null)

Check warning on line 128 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L126-L128

Added lines #L126 - L128 were not covered by tests
}
Kind::Binary(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Binary)

Check warning on line 132 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L130-L132

Added lines #L130 - L132 were not covered by tests
}
Kind::Boolean(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Boolean)

Check warning on line 136 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L134-L136

Added lines #L134 - L136 were not covered by tests
}
Kind::Byte(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Int8)

Check warning on line 140 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L138-L140

Added lines #L138 - L140 were not covered by tests
}
Kind::Short(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Int16)

Check warning on line 144 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L142-L144

Added lines #L142 - L144 were not covered by tests
}
Kind::Integer(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Int32)

Check warning on line 148 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L146-L148

Added lines #L146 - L148 were not covered by tests
}
Kind::Long(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Int64)

Check warning on line 152 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L150-L152

Added lines #L150 - L152 were not covered by tests
}
Kind::Float(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Float32)

Check warning on line 156 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L154-L156

Added lines #L154 - L156 were not covered by tests
}
Kind::Double(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Float64)

Check warning on line 160 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L158-L160

Added lines #L158 - L160 were not covered by tests
}
Kind::Decimal(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);

Check warning on line 163 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L162-L163

Added lines #L162 - L163 were not covered by tests

let Some(precision) = value.precision else {
bail!("Decimal precision is required");

Check warning on line 166 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L165-L166

Added lines #L165 - L166 were not covered by tests
};

let Some(scale) = value.scale else {
bail!("Decimal scale is required");

Check warning on line 170 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L169-L170

Added lines #L169 - L170 were not covered by tests
};

let precision = usize::try_from(precision)
.wrap_err("Decimal precision must be a non-negative integer")?;

Check warning on line 174 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L173-L174

Added lines #L173 - L174 were not covered by tests

let scale =
usize::try_from(scale).wrap_err("Decimal scale must be a non-negative integer")?;

Check warning on line 177 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L176-L177

Added lines #L176 - L177 were not covered by tests

Ok(DataType::Decimal128(precision, scale))

Check warning on line 179 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L179

Added line #L179 was not covered by tests
}
Kind::String(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Utf8)
}
Kind::Char(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Utf8)

Check warning on line 187 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L185-L187

Added lines #L185 - L187 were not covered by tests
}
Kind::VarChar(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Utf8)

Check warning on line 191 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L189-L191

Added lines #L189 - L191 were not covered by tests
}
Kind::Date(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Date)

Check warning on line 195 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L193-L195

Added lines #L193 - L195 were not covered by tests
}
Kind::Timestamp(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);

Check warning on line 198 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L197-L198

Added lines #L197 - L198 were not covered by tests
// Using microseconds precision with no timezone info matches Spark's behavior.
// Spark handles timezones at the session level rather than in the type itself.
// See: https://www.databricks.com/blog/2020/07/22/a-comprehensive-look-at-dates-and-timestamps-in-apache-spark-3-0.html
Ok(DataType::Timestamp(TimeUnit::Microseconds, None))

Check warning on line 202 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L202

Added line #L202 was not covered by tests
}
Kind::TimestampNtz(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
Ok(DataType::Timestamp(TimeUnit::Microseconds, None))

Check warning on line 206 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L204-L206

Added lines #L204 - L206 were not covered by tests
}
Kind::CalendarInterval(_) => bail!("Calendar interval type not supported"),
Kind::YearMonthInterval(_) => bail!("Year-month interval type not supported"),
Kind::DayTimeInterval(_) => bail!("Day-time interval type not supported"),
Kind::Array(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
let element_type = to_daft_datatype(
value
.element_type
.as_ref()
.ok_or_else(|| eyre::eyre!("Array element type is required"))?,
)?;
Ok(DataType::List(Box::new(element_type)))

Check warning on line 219 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L208-L219

Added lines #L208 - L219 were not covered by tests
}
Kind::Struct(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
let fields = value
.fields
.iter()
.map(|f| {
let field_type = to_daft_datatype(
f.data_type
.as_ref()
.ok_or_else(|| eyre::eyre!("Struct field type is required"))?,
)?;
Ok(Field::new(&f.name, field_type))
})
.collect::<eyre::Result<Vec<_>>>()?;
Ok(DataType::Struct(fields))

Check warning on line 235 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L221-L235

Added lines #L221 - L235 were not covered by tests
}
Kind::Map(value) => {
ensure!(value.type_variation_reference == 0, type_variation_err);
let key_type = to_daft_datatype(
value
.key_type
.as_ref()
.ok_or_else(|| eyre::eyre!("Map key type is required"))?,
)?;
let value_type = to_daft_datatype(
value
.value_type
.as_ref()
.ok_or_else(|| eyre::eyre!("Map value type is required"))?,
)?;

Check warning on line 250 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L237-L250

Added lines #L237 - L250 were not covered by tests

let map = DataType::Map {
key: Box::new(key_type),
value: Box::new(value_type),
};

Ok(map)

Check warning on line 257 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L252-L257

Added lines #L252 - L257 were not covered by tests
}
Kind::Variant(_) => bail!("Variant type not supported"),
Kind::Udt(_) => bail!("User-defined type not supported"),
Kind::Unparsed(_) => bail!("Unparsed type not supported"),

Check warning on line 261 in src/daft-connect/src/translation/datatype.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/datatype.rs#L259-L261

Added lines #L259 - L261 were not covered by tests
}
}
67 changes: 63 additions & 4 deletions src/daft-connect/src/translation/expr.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
use std::sync::Arc;

use eyre::{bail, Context};
use spark_connect::{expression as spark_expr, Expression};
use spark_connect::{
expression as spark_expr,
expression::{
cast::{CastToType, EvalMode},
sort_order::{NullOrdering, SortDirection},
},
Expression,
};
use tracing::warn;
use unresolved_function::unresolved_to_daft_expr;

use crate::translation::to_daft_literal;
use crate::translation::{to_daft_datatype, to_daft_literal};

mod unresolved_function;

Expand Down Expand Up @@ -69,11 +76,63 @@ pub fn to_daft_expr(expression: &Expression) -> eyre::Result<daft_dsl::ExprRef>

Ok(child.alias(name))
}
spark_expr::ExprType::Cast(_) => bail!("Cast expressions not yet supported"),
spark_expr::ExprType::Cast(c) => {
// Example of an incoming message (cast of column `id` to a string type):
// Cast { expr: Some(Expression { common: None, expr_type: Some(UnresolvedAttribute(UnresolvedAttribute { unparsed_identifier: "id", plan_id: None, is_metadata_column: None })) }), eval_mode: Unspecified, cast_to_type: Some(Type(DataType { kind: Some(String(String { type_variation_reference: 0, collation: "" })) })) }
// (This message was originally captured next to a panic on a tokio runtime worker thread; the trace itself was truncated.)
let spark_expr::Cast {
expr,
eval_mode,
cast_to_type,
} = &**c;

let Some(expr) = expr else {
bail!("Cast expression is required");

Check warning on line 89 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L89

Added line #L89 was not covered by tests
};

let expr = to_daft_expr(expr)?;

let Some(cast_to_type) = cast_to_type else {
bail!("Cast to type is required");

Check warning on line 95 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L95

Added line #L95 was not covered by tests
};

let data_type = match cast_to_type {
CastToType::Type(kind) => to_daft_datatype(kind).wrap_err_with(|| {
format!("Failed to convert spark datatype to daft datatype: {kind:?}")

Check warning on line 100 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L100

Added line #L100 was not covered by tests
})?,
CastToType::TypeStr(s) => {
bail!("Cast to type string not yet supported; tried to cast to {s}");

Check warning on line 103 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L102-L103

Added lines #L102 - L103 were not covered by tests
}
};

let eval_mode = EvalMode::try_from(*eval_mode)
.wrap_err_with(|| format!("Invalid cast eval mode: {eval_mode}"))?;

warn!("Ignoring cast eval mode: {eval_mode:?}");

Ok(expr.cast(&data_type))
}
spark_expr::ExprType::UnresolvedRegex(_) => {
bail!("Unresolved regex expressions not yet supported")
}
spark_expr::ExprType::SortOrder(_) => bail!("Sort order expressions not yet supported"),
spark_expr::ExprType::SortOrder(s) => {
let spark_expr::SortOrder {
child,
direction,
null_ordering,
} = &**s;

Check warning on line 122 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L117-L122

Added lines #L117 - L122 were not covered by tests

let Some(_child) = child else {
bail!("Sort order child is required");

Check warning on line 125 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L124-L125

Added lines #L124 - L125 were not covered by tests
};

let _sort_direction = SortDirection::try_from(*direction)
.wrap_err_with(|| format!("Invalid sort direction: {direction}"))?;

Check warning on line 129 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L128-L129

Added lines #L128 - L129 were not covered by tests

let _sort_nulls = NullOrdering::try_from(*null_ordering)
.wrap_err_with(|| format!("Invalid sort nulls: {null_ordering}"))?;

Check warning on line 132 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L131-L132

Added lines #L131 - L132 were not covered by tests

bail!("Sort order expressions not yet supported");

Check warning on line 134 in src/daft-connect/src/translation/expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr.rs#L134

Added line #L134 was not covered by tests
}
spark_expr::ExprType::LambdaFunction(_) => {
bail!("Lambda function expressions not yet supported")
}
Expand Down
28 changes: 28 additions & 0 deletions src/daft-connect/src/translation/expr/unresolved_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result<daft_dsl:

match function_name.as_str() {
"count" => handle_count(arguments).wrap_err("Failed to handle count function"),
"isnotnull" => handle_isnotnull(arguments).wrap_err("Failed to handle isnotnull function"),
"isnull" => handle_isnull(arguments).wrap_err("Failed to handle isnull function"),
n => bail!("Unresolved function {n} not yet supported"),
}
}
Expand All @@ -42,3 +44,29 @@ pub fn handle_count(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl:

Ok(count)
}

/// Builds the Daft expression for Spark's `isnull` function.
///
/// # Errors
///
/// Returns an error (including the debug form of the received arguments) unless
/// exactly one argument is supplied.
pub fn handle_isnull(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
    // Converting the Vec into a one-element array doubles as the arity check;
    // on failure the untouched Vec is handed back for the error message.
    let [operand]: [daft_dsl::ExprRef; 1] =
        arguments
            .try_into()
            .map_err(|arguments: Vec<daft_dsl::ExprRef>| {
                eyre::eyre!("requires exactly one argument; got {arguments:?}")
            })?;

    Ok(operand.is_null())
}

/// Builds the Daft expression for Spark's `isnotnull` function.
///
/// # Errors
///
/// Returns an error (including the debug form of the received arguments) unless
/// exactly one argument is supplied.
pub fn handle_isnotnull(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
    // Converting the Vec into a one-element array doubles as the arity check;
    // on failure the untouched Vec is handed back for the error message.
    let [operand]: [daft_dsl::ExprRef; 1] =
        arguments
            .try_into()
            .map_err(|arguments: Vec<daft_dsl::ExprRef>| {
                eyre::eyre!("requires exactly one argument; got {arguments:?}")
            })?;

    Ok(operand.not_null())
}
46 changes: 46 additions & 0 deletions tests/connect/test_basic_column.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

from pyspark.sql.functions import col
from pyspark.sql.types import StringType


def test_column_alias(spark_session):
    """`Column.alias` renames the selected column and leaves its values untouched."""
    source = spark_session.range(10)
    aliased = source.select(col("id").alias("my_number"))

    assert "my_number" in aliased.columns, "alias should rename column"

    # Collect the aliased frame first, then the source frame, and compare the two series.
    aliased_values = aliased.toPandas()["my_number"]
    source_values = source.toPandas()["id"]
    assert aliased_values.equals(source_values), "data should be unchanged"


def test_column_cast(spark_session):
    """`Column.cast(StringType())` changes both the reported schema and the pandas dtype."""
    numbers = spark_session.range(10)
    as_text = numbers.select(col("id").cast(StringType()))

    first_field = as_text.schema.fields[0]
    assert first_field.dataType == StringType(), "cast should change data type"

    pandas_dtype = as_text.toPandas()["id"].dtype
    assert pandas_dtype == "object", "cast should change pandas dtype to object/string"


def test_column_null_checks(spark_session):
    """`isNotNull` / `isNull` on a non-null column yield True / False respectively."""
    df = spark_session.range(10)
    df_null = df.select(col("id").isNotNull().alias("not_null"), col("id").isNull().alias("is_null"))

    # Collect the result once and reuse it for both assertions instead of
    # calling toPandas() (and collecting the same frame) twice.
    result = df_null.toPandas()

    assert result["not_null"].iloc[0], "isNotNull should be True for non-null values"
    assert not result["is_null"].iloc[0], "isNull should be False for non-null values"


def test_column_name(spark_session):
    """`Column.name` renames the selected column and leaves its values untouched."""
    source = spark_session.range(10)
    renamed = source.select(col("id").name("renamed_id"))

    assert "renamed_id" in renamed.columns, "name should rename column"

    # Collect the renamed frame first, then the source frame, and compare the two series.
    renamed_values = renamed.toPandas()["renamed_id"]
    source_values = source.toPandas()["id"]
    assert renamed_values.equals(source_values), "data should be unchanged"


# TODO: Uncomment when https://github.com/Eventual-Inc/Daft/issues/3433 is fixed
# def test_column_desc(spark_session):
# df = spark_session.range(10)
# df_attr = df.select(col("id").desc())
# assert df_attr.toPandas()["id"].iloc[0] == 9, "desc should sort in descending order"


# TODO: Add test when extract value is implemented
# def test_column_getitem(spark_session):
# df = spark_session.range(10)
# df_item = df.select(col("id")[0])
# assert df_item.toPandas()["id"].iloc[0] == 0, "getitem should return first element"

0 comments on commit a0165a6

Please sign in to comment.