[FIX] (WIP) casting of arrays from Daft to Arrow with unsigned integer types
andrewgazelka committed Nov 25, 2024
1 parent fa1d9d7 commit ae31759
Showing 6 changed files with 83 additions and 7 deletions.
1 change: 0 additions & 1 deletion src/daft-connect/src/lib.rs
@@ -158,7 +158,6 @@ impl SparkConnectService for DaftSparkConnectService {
request: Request<ExecutePlanRequest>,
) -> Result<Response<Self::ExecutePlanStream>, Status> {
let request = request.into_inner();
-
let session = self.get_session(&request.session_id)?;

let Some(operation) = request.operation_id else {
41 changes: 37 additions & 4 deletions src/daft-connect/src/op/execute/root.rs
@@ -1,7 +1,10 @@
use std::{collections::HashMap, future::ready};

use common_daft_config::DaftExecutionConfig;
use daft_core::series::Series;
use daft_local_execution::NativeExecutor;
use daft_schema::{field::Field, schema::Schema};
use daft_table::Table;
use futures::stream;
use spark_connect::{ExecutePlanResponse, Relation};
use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status};
@@ -10,6 +13,7 @@ use crate::{
op::execute::{ExecuteStream, PlanIds},
session::Session,
translation,
translation::to_spark_compatible_datatype,
};

impl Session {
@@ -38,17 +42,46 @@ impl Session {
let mut result_stream = native_executor
.run(HashMap::new(), cfg.into(), None)?
.into_stream();

while let Some(result) = result_stream.next().await {
let result = result?;
let tables = result.get_tables()?;

for table in tables.as_slice() {
-let response = context.gen_response(table)?;
-if tx.send(Ok(response)).await.is_err() {
-return Ok(());
-}
// Cast every column to a Spark-compatible datatype before converting to Arrow
let mut arrow_arrays = Vec::with_capacity(table.num_columns());
let mut column_names = Vec::with_capacity(table.num_columns());
let mut field_types = Vec::with_capacity(table.num_columns());

for i in 0..table.num_columns() {
let s = table.get_column_by_index(i)?;

let daft_data_type = to_spark_compatible_datatype(s.data_type());
let s = s.cast(&daft_data_type)?;

// Store the actual type after potential casting
field_types.push(Field::new(s.name(), daft_data_type));
column_names.push(s.name().to_string());
arrow_arrays.push(s.to_arrow());
}

// Create a new schema with the actual types after casting
let new_schema = Schema::new(field_types)?;

// Convert the arrays back to series
let series = arrow_arrays
.into_iter()
.zip(column_names)
.map(|(array, name)| Series::try_from((name.as_str(), array)))
.try_collect()?;

// Create a table from the series
let new_table = Table::new_with_size(new_schema, series, table.len())?;

let response = context.gen_response(&new_table)?;
tx.send(Ok(response)).await.unwrap();
}
}

Ok(())
};

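The loop above implements a cast-then-rebuild pass: each column is cast to a Spark-compatible datatype, converted to Arrow, and the table is reassembled around the new schema. Below is a rough illustration of the same idea in Python with pyarrow; the to_spark_compatible helper and the sample column are made up for this sketch and are not part of the commit:

import pyarrow as pa

# Illustrative stand-in for to_spark_compatible_datatype: Spark has no
# unsigned integer types, so map each one to its same-width signed twin.
_UNSIGNED_TO_SIGNED = {
    pa.uint8(): pa.int8(),
    pa.uint16(): pa.int16(),
    pa.uint32(): pa.int32(),
    pa.uint64(): pa.int64(),
}

def to_spark_compatible(dtype):
    return _UNSIGNED_TO_SIGNED.get(dtype, dtype)

# A stand-in for a Daft result table, with one unsigned column.
table = pa.table({"id": pa.array([1, 2, 3], type=pa.uint32())})

# Cast every column, then rebuild the table around the new schema.
columns = [table.column(i) for i in range(table.num_columns)]
casted = [col.cast(to_spark_compatible(col.type)) for col in columns]
new_table = pa.table(dict(zip(table.column_names, casted)))

print(new_table.schema)  # id: int32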
2 changes: 1 addition & 1 deletion src/daft-connect/src/translation.rs
@@ -6,7 +6,7 @@ mod literal;
mod logical_plan;
mod schema;

-pub use datatype::to_spark_datatype;
pub use datatype::{to_spark_compatible_datatype, to_spark_datatype};
pub use expr::to_daft_expr;
pub use literal::to_daft_literal;
pub use logical_plan::to_logical_plan;
22 changes: 21 additions & 1 deletion src/daft-connect/src/translation/datatype.rs
@@ -1,7 +1,27 @@
-use daft_schema::dtype::DataType;
use daft_schema::{dtype::DataType, field::Field};
use spark_connect::data_type::Kind;
use tracing::warn;

// todo: still a WIP; by no means complete
pub fn to_spark_compatible_datatype(datatype: &DataType) -> DataType {
// TL;DR unsigned integers are not supported by Spark
match datatype {
DataType::UInt8 => DataType::Int8,
DataType::UInt16 => DataType::Int16,
DataType::UInt32 => DataType::Int32,
DataType::UInt64 => DataType::Int64,
DataType::Struct(fields) => {
let fields = fields
.iter()
.map(|f| Field::new(f.name.clone(), to_spark_compatible_datatype(&f.dtype)))
.collect();

DataType::Struct(fields)
}
_ => datatype.clone(),
}
}

pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType {
match datatype {
DataType::Null => spark_connect::DataType {
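One caveat with the same-width mapping above (likely part of why the commit is marked WIP): the upper half of each unsigned range does not fit in its signed counterpart; a UInt8 of 200 overflows Int8, whose maximum is 127. A small pyarrow illustration of the failure mode follows; Daft's own cast semantics may differ:

import pyarrow as pa

arr = pa.array([1, 200], type=pa.uint8())

# A checked cast refuses values outside int8's range...
try:
    arr.cast(pa.int8())  # safe=True is the default
except pa.ArrowInvalid as err:
    print(f"checked cast failed: {err}")

# ...while an unchecked cast silently wraps 200 around to -56.
print(arr.cast(pa.int8(), safe=False))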
11 changes: 11 additions & 0 deletions src/daft-connect/src/translation/expr/unresolved_function.rs
@@ -1,6 +1,7 @@
use daft_core::count_mode::CountMode;
use eyre::{bail, Context};
use spark_connect::expression::UnresolvedFunction;
use tracing::debug;

use crate::translation::to_daft_expr;

@@ -38,6 +39,16 @@ pub fn handle_count(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {

let [arg] = arguments;

// special case to be consistent with how spark handles counting literals
// see https://github.com/Eventual-Inc/Daft/issues/3421
let count_special_case = *arg == daft_dsl::Expr::Literal(daft_dsl::LiteralValue::Int32(1));

if count_special_case {
debug!("special case for count");
let result = daft_dsl::col("*").count(CountMode::All);
return Ok(result);
}

let count = arg.count(CountMode::All);

Ok(count)
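For context on the special case: COUNT(1) in SQL, which arrives here as a count over the literal 1, means "count the rows", not "count a literal column", so translating it as a plain count of the literal expression would be wrong. A quick PySpark sketch of the behavior being matched, assuming an existing SparkSession named spark:

from pyspark.sql.functions import count, lit

df = spark.range(10)  # `spark` is assumed to be an existing SparkSession

# COUNT(1) is a row count: both of these queries return 10.
df.select(count(lit(1))).show()
df.select(count("*")).show()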
13 changes: 13 additions & 0 deletions tests/connect/test_count.py
@@ -0,0 +1,13 @@
from __future__ import annotations


def test_count(spark_session):
# Create a range using Spark
# For example, creating a range from 0 to 9
spark_range = spark_session.range(10) # Creates DataFrame with numbers 0 to 9

# Count the rows in the DataFrame
count = spark_range.count()

# Verify the count matches the expected number of rows
assert count == 10, "DataFrame should have 10 rows"
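A natural companion test, hypothetical and not part of this commit, would exercise the unsigned-cast path through an explicit aggregation, since on the Daft side a count is produced as an unsigned integer that must be cast before it can be handed back to Spark:

from pyspark.sql.functions import count


def test_count_expression(spark_session):
    # Hypothetical companion test: the count() result flows back through
    # the Daft -> Arrow conversion, where its unsigned result type must
    # be cast to a signed type for Spark.
    df = spark_session.range(10)
    rows = df.select(count("*").alias("n")).collect()
    assert rows[0]["n"] == 10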
