From 72aa46cfd37c4ad06b3e0b641761b7b2296b56cd Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 20 Nov 2024 00:26:03 -0800 Subject: [PATCH] [FIX] (WIP) casting of arrays from daft to arrow with unsigned --- src/daft-connect/src/op/execute/root.rs | 46 ++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs index 1e1fac147b..8fa8bf020c 100644 --- a/src/daft-connect/src/op/execute/root.rs +++ b/src/daft-connect/src/op/execute/root.rs @@ -1,7 +1,10 @@ use std::{collections::HashMap, future::ready}; use common_daft_config::DaftExecutionConfig; +use daft_core::series::Series; use daft_local_execution::NativeExecutor; +use daft_schema::{dtype::DataType, field::Field, schema::Schema}; +use daft_table::Table; use futures::stream; use spark_connect::{ExecutePlanResponse, Relation}; use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status}; @@ -38,17 +41,52 @@ impl Session { let mut result_stream = native_executor .run(HashMap::new(), cfg.into(), None)? .into_stream(); - while let Some(result) = result_stream.next().await { let result = result?; let tables = result.get_tables()?; + for table in tables.as_slice() { - let response = context.gen_response(table)?; - if tx.send(Ok(response)).await.is_err() { - return Ok(()); + // Inside the for loop over tables + let mut arrow_arrays = Vec::with_capacity(table.num_columns()); + let mut column_names = Vec::with_capacity(table.num_columns()); + let mut field_types = Vec::with_capacity(table.num_columns()); + + for i in 0..table.num_columns() { + let s = table.get_column_by_index(i)?; + let arrow_array = s.to_arrow(); + + let arrow_array = + daft_core::utils::arrow::cast_array_from_daft_if_needed( + arrow_array.to_boxed(), + ); + + // todo(correctness): logical types probably get **DESTROYED** here 💥😭😭 + let daft_data_type = DataType::from(arrow_array.data_type()); + + // Store the actual type after potential casting + field_types.push(Field::new(s.name(), daft_data_type)); + column_names.push(s.name().to_string()); + arrow_arrays.push(arrow_array); } + + // Create new schema with actual types after casting + let new_schema = Schema::new(field_types)?; + + // Convert arrays back to series + let series = arrow_arrays + .into_iter() + .zip(column_names) + .map(|(array, name)| Series::try_from((name.as_str(), array))) + .try_collect()?; + + // Create table from series + let new_table = Table::new_with_size(new_schema, series, table.len())?; + + let response = context.gen_response(&new_table)?; + tx.blocking_send(Ok(response)).unwrap(); } } + Ok(()) };