diff --git a/src/daft-connect/src/translation/datatype/codec.rs b/src/daft-connect/src/translation/datatype/codec.rs
index c5cf13aad6..50f2d94a02 100644
--- a/src/daft-connect/src/translation/datatype/codec.rs
+++ b/src/daft-connect/src/translation/datatype/codec.rs
@@ -1,6 +1,7 @@
-use eyre::{bail, ensure};
+use color_eyre::Help;
+use eyre::{bail, ensure, eyre};
 use serde_json::Value;
-use spark_connect::data_type::{Kind, Long};
+use spark_connect::data_type::Kind;
 use tracing::warn;
 
 #[derive(Debug)]
@@ -89,11 +90,88 @@ fn deser_helper(
     let kind = remove_type(&mut input)?;
 
     let result = match kind {
+        TypeTag::Null => Ok(Kind::Null(spark_connect::data_type::Null {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Binary => Ok(Kind::Binary(spark_connect::data_type::Binary {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Boolean => Ok(Kind::Boolean(spark_connect::data_type::Boolean {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Byte => Ok(Kind::Byte(spark_connect::data_type::Byte {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Short => Ok(Kind::Short(spark_connect::data_type::Short {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Integer => Ok(Kind::Integer(spark_connect::data_type::Integer {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Long => Ok(Kind::Long(spark_connect::data_type::Long {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Float => Ok(Kind::Float(spark_connect::data_type::Float {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Double => Ok(Kind::Double(spark_connect::data_type::Double {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Decimal => Ok(Kind::Decimal(spark_connect::data_type::Decimal {
+            scale: None,
+            precision: None,
+            type_variation_reference: 0,
+        })),
+        TypeTag::String => Ok(Kind::String(spark_connect::data_type::String {
+            type_variation_reference: 0,
+            collation: String::new(),
+        })),
+        TypeTag::Char => Ok(Kind::Char(spark_connect::data_type::Char {
+            type_variation_reference: 0,
+            length: 1,
+        })),
+        TypeTag::VarChar => Ok(Kind::VarChar(spark_connect::data_type::VarChar {
+            type_variation_reference: 0,
+            length: 0,
+        })),
+        TypeTag::Date => Ok(Kind::Date(spark_connect::data_type::Date {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Timestamp => Ok(Kind::Timestamp(spark_connect::data_type::Timestamp {
+            type_variation_reference: 0,
+        })),
+        TypeTag::TimestampNtz => Ok(Kind::TimestampNtz(spark_connect::data_type::TimestampNtz {
+            type_variation_reference: 0,
+        })),
+        TypeTag::CalendarInterval => Ok(Kind::CalendarInterval(
+            spark_connect::data_type::CalendarInterval {
+                type_variation_reference: 0,
+            },
+        )),
+        TypeTag::YearMonthInterval => Ok(Kind::YearMonthInterval(
+            spark_connect::data_type::YearMonthInterval {
+                type_variation_reference: 0,
+                start_field: None,
+                end_field: None,
+            },
+        )),
+        TypeTag::DayTimeInterval => Ok(Kind::DayTimeInterval(
+            spark_connect::data_type::DayTimeInterval {
+                type_variation_reference: 0,
+                start_field: None,
+                end_field: None,
+            },
+        )),
+        TypeTag::Array => Err(eyre!("Array type not supported"))
+            .suggestion("Wait until we support arrays in Spark Connect"),
         TypeTag::Struct => deser_struct(input),
-        TypeTag::Long => Ok(Kind::Long(Long {
+        TypeTag::Map => Err(eyre!("Map type not supported"))
+            .suggestion("Wait until we support maps in Spark Connect"),
+        TypeTag::Variant => Ok(Kind::Variant(spark_connect::data_type::Variant {
             type_variation_reference: 0,
         })),
-        _ => bail!("unsupported type tag {:?}", kind),
+        TypeTag::Udt => bail!("UDT type not supported"),
+        TypeTag::Unparsed => bail!("Unparsed type not supported"),
     }?;
 
     let result = spark_connect::DataType { kind: Some(result) };
diff --git a/tests/connect/test_create_df.py b/tests/connect/test_create_df.py
index de2315a84d..e06944e19f 100644
--- a/tests/connect/test_create_df.py
+++ b/tests/connect/test_create_df.py
@@ -2,11 +2,26 @@
 
 
 def test_create_df(spark_session):
-    # Create simple DataFrame
+    # Create simple DataFrame with single column
     data = [(1,), (2,), (3,)]
     df = spark_session.createDataFrame(data, ["id"])
 
-    # Convert to pandas
+    # Convert to pandas and verify
     df_pandas = df.toPandas()
     assert len(df_pandas) == 3, "DataFrame should have 3 rows"
     assert list(df_pandas["id"]) == [1, 2, 3], "DataFrame should contain expected values"
+
+    # Create DataFrame with float column
+    float_data = [(1.1,), (2.2,), (3.3,)]
+    df_float = spark_session.createDataFrame(float_data, ["value"])
+    df_float_pandas = df_float.toPandas()
+    assert len(df_float_pandas) == 3, "Float DataFrame should have 3 rows"
+    assert list(df_float_pandas["value"]) == [1.1, 2.2, 3.3], "Float DataFrame should contain expected values"
+
+    # Create DataFrame with two numeric columns
+    two_col_data = [(1, 10), (2, 20), (3, 30)]
+    df_two = spark_session.createDataFrame(two_col_data, ["num1", "num2"])
+    df_two_pandas = df_two.toPandas()
+    assert len(df_two_pandas) == 3, "Two-column DataFrame should have 3 rows"
+    assert list(df_two_pandas["num1"]) == [1, 2, 3], "First number column should contain expected values"
+    assert list(df_two_pandas["num2"]) == [10, 20, 30], "Second number column should contain expected values"