diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs
index 1e1fac147b..0e89cd4dec 100644
--- a/src/daft-connect/src/op/execute/root.rs
+++ b/src/daft-connect/src/op/execute/root.rs
@@ -1,7 +1,10 @@
 use std::{collections::HashMap, future::ready};
 
 use common_daft_config::DaftExecutionConfig;
+use daft_core::series::Series;
 use daft_local_execution::NativeExecutor;
+use daft_schema::{dtype::DataType, field::Field, schema::Schema};
+use daft_table::Table;
 use futures::stream;
 use spark_connect::{ExecutePlanResponse, Relation};
 use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status};
@@ -38,17 +41,54 @@ impl Session {
             let mut result_stream = native_executor
                 .run(HashMap::new(), cfg.into(), None)?
                 .into_stream();
-
             while let Some(result) = result_stream.next().await {
                 let result = result?;
                 let tables = result.get_tables()?;
+
                 for table in tables.as_slice() {
-                    let response = context.gen_response(table)?;
-                    if tx.send(Ok(response)).await.is_err() {
-                        return Ok(());
+                    // Rebuild each column with arrow-compatible types before responding
+                    let mut arrow_arrays = Vec::with_capacity(table.num_columns());
+                    let mut column_names = Vec::with_capacity(table.num_columns());
+                    let mut field_types = Vec::with_capacity(table.num_columns());
+
+                    for i in 0..table.num_columns() {
+                        let s = table.get_column_by_index(i)?;
+                        let arrow_array = s.to_arrow();
+
+                        let arrow_array =
+                            daft_core::utils::arrow::cast_array_from_daft_if_needed(
+                                arrow_array.to_boxed(),
+                            );
+
+                        // todo(correctness): logical types probably get **DESTROYED** here 💥😭😭
+                        let daft_data_type = DataType::from(arrow_array.data_type());
+
+                        // Store the actual type after potential casting
+                        field_types.push(Field::new(s.name(), daft_data_type));
+                        column_names.push(s.name().to_string());
+                        arrow_arrays.push(arrow_array);
                     }
+
+                    // Create new schema with actual types after casting
+                    let new_schema = Schema::new(field_types)?;
+
+                    println!("new schema: {:?}", new_schema);
+
+                    // Convert arrays back to series
+                    let series = arrow_arrays
+                        .into_iter()
+                        .zip(column_names)
+                        .map(|(array, name)| Series::try_from((name.as_str(), array)))
+                        .try_collect()?;
+
+                    // Create table from series
+                    let new_table = Table::new_with_size(new_schema, series, table.len())?;
+
+                    let response = context.gen_response(&new_table)?;
+                    tx.send(Ok(response)).await.unwrap();
                 }
             }
+
             Ok(())
         };
 
diff --git a/tests/connect/test_count.py b/tests/connect/test_count.py
new file mode 100644
index 0000000000..1cb48dd47b
--- /dev/null
+++ b/tests/connect/test_count.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+
+def test_count(spark_session):
+    # Create a range using Spark
+    # For example, creating a range from 0 to 9
+    spark_range = spark_session.range(10)  # Creates DataFrame with numbers 0 to 9
+
+    # Count the rows in the DataFrame
+    count = spark_range.count()
+
+    # Verify the DataFrame has the expected number of rows
+    assert count == 10, "DataFrame should have 10 rows"