Skip to content

Commit

Permalink
[FIX] (WIP) casting of arrays from daft to arrow with unsigned integer types
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Nov 23, 2024
1 parent 5b36bd0 commit 9ac25eb
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 4 deletions.
48 changes: 44 additions & 4 deletions src/daft-connect/src/op/execute/root.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::{collections::HashMap, future::ready};

use common_daft_config::DaftExecutionConfig;
use daft_core::series::Series;
use daft_local_execution::NativeExecutor;
use daft_schema::{dtype::DataType, field::Field, schema::Schema};
use daft_table::Table;
use futures::stream;
use spark_connect::{ExecutePlanResponse, Relation};
use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status};
Expand Down Expand Up @@ -38,17 +41,54 @@ impl Session {
let mut result_stream = native_executor
.run(HashMap::new(), cfg.into(), None)?
.into_stream();

while let Some(result) = result_stream.next().await {
let result = result?;
let tables = result.get_tables()?;

for table in tables.as_slice() {
let response = context.gen_response(table)?;
if tx.send(Ok(response)).await.is_err() {
return Ok(());
// Inside the for loop over tables
let mut arrow_arrays = Vec::with_capacity(table.num_columns());
let mut column_names = Vec::with_capacity(table.num_columns());
let mut field_types = Vec::with_capacity(table.num_columns());

for i in 0..table.num_columns() {
let s = table.get_column_by_index(i)?;
let arrow_array = s.to_arrow();

let arrow_array =
daft_core::utils::arrow::cast_array_from_daft_if_needed(
arrow_array.to_boxed(),
);

// todo(correctness): logical types probably get **DESTROYED** here 💥😭😭
let daft_data_type = DataType::from(arrow_array.data_type());

// Store the actual type after potential casting
field_types.push(Field::new(s.name(), daft_data_type));
column_names.push(s.name().to_string());
arrow_arrays.push(arrow_array);
}

// Create new schema with actual types after casting
let new_schema = Schema::new(field_types)?;

println!("new schema: {:?}", new_schema);

// Convert arrays back to series
let series = arrow_arrays
.into_iter()
.zip(column_names)
.map(|(array, name)| Series::try_from((name.as_str(), array)))
.try_collect()?;

// Create table from series
let new_table = Table::new_with_size(new_schema, series, table.len())?;

let response = context.gen_response(&new_table)?;
tx.send(Ok(response)).await.unwrap();
}
}

Ok(())
};

Expand Down
13 changes: 13 additions & 0 deletions tests/connect/test_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations


def test_count(spark_session):
# Create a range using Spark
# For example, creating a range from 0 to 9
spark_range = spark_session.range(10) # Creates DataFrame with numbers 0 to 9

# Convert to Pandas DataFrame
count = spark_range.count()

# Verify the DataFrame has expected values
assert count == 10, "DataFrame should have 10 rows"

0 comments on commit 9ac25eb

Please sign in to comment.