
Commit

add more tests
andrewgazelka committed Dec 4, 2024
1 parent dd29d9f commit 3ee3eda
Showing 2 changed files with 99 additions and 6 deletions.
86 changes: 82 additions & 4 deletions src/daft-connect/src/translation/datatype/codec.rs
@@ -1,6 +1,7 @@
-use eyre::{bail, ensure};
+use color_eyre::Help;
+use eyre::{bail, ensure, eyre};
 use serde_json::Value;
-use spark_connect::data_type::{Kind, Long};
+use spark_connect::data_type::Kind;
 use tracing::warn;
 
 #[derive(Debug)]
@@ -89,11 +90,88 @@ fn deser_helper(
     let kind = remove_type(&mut input)?;
 
     let result = match kind {
+        TypeTag::Null => Ok(Kind::Null(spark_connect::data_type::Null {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Binary => Ok(Kind::Binary(spark_connect::data_type::Binary {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Boolean => Ok(Kind::Boolean(spark_connect::data_type::Boolean {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Byte => Ok(Kind::Byte(spark_connect::data_type::Byte {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Short => Ok(Kind::Short(spark_connect::data_type::Short {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Integer => Ok(Kind::Integer(spark_connect::data_type::Integer {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Long => Ok(Kind::Long(spark_connect::data_type::Long {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Float => Ok(Kind::Float(spark_connect::data_type::Float {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Double => Ok(Kind::Double(spark_connect::data_type::Double {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Decimal => Ok(Kind::Decimal(spark_connect::data_type::Decimal {
+            scale: None,
+            precision: None,
+            type_variation_reference: 0,
+        })),
+        TypeTag::String => Ok(Kind::String(spark_connect::data_type::String {
+            type_variation_reference: 0,
+            collation: String::new(),
+        })),
+        TypeTag::Char => Ok(Kind::Char(spark_connect::data_type::Char {
+            type_variation_reference: 0,
+            length: 1,
+        })),
+        TypeTag::VarChar => Ok(Kind::VarChar(spark_connect::data_type::VarChar {
+            type_variation_reference: 0,
+            length: 0,
+        })),
+        TypeTag::Date => Ok(Kind::Date(spark_connect::data_type::Date {
+            type_variation_reference: 0,
+        })),
+        TypeTag::Timestamp => Ok(Kind::Timestamp(spark_connect::data_type::Timestamp {
+            type_variation_reference: 0,
+        })),
+        TypeTag::TimestampNtz => Ok(Kind::TimestampNtz(spark_connect::data_type::TimestampNtz {
+            type_variation_reference: 0,
+        })),
+        TypeTag::CalendarInterval => Ok(Kind::CalendarInterval(
+            spark_connect::data_type::CalendarInterval {
+                type_variation_reference: 0,
+            },
+        )),
+        TypeTag::YearMonthInterval => Ok(Kind::YearMonthInterval(
+            spark_connect::data_type::YearMonthInterval {
+                type_variation_reference: 0,
+                start_field: None,
+                end_field: None,
+            },
+        )),
+        TypeTag::DayTimeInterval => Ok(Kind::DayTimeInterval(
+            spark_connect::data_type::DayTimeInterval {
+                type_variation_reference: 0,
+                start_field: None,
+                end_field: None,
+            },
+        )),
+        TypeTag::Array => Err(eyre!("Array type not supported"))
+            .suggestion("Wait until we support arrays in Spark Connect"),
         TypeTag::Struct => deser_struct(input),
-        TypeTag::Long => Ok(Kind::Long(Long {
+        TypeTag::Map => Err(eyre!("Map type not supported"))
+            .suggestion("Wait until we support maps in Spark Connect"),
+        TypeTag::Variant => Ok(Kind::Variant(spark_connect::data_type::Variant {
             type_variation_reference: 0,
         })),
-        _ => bail!("unsupported type tag {:?}", kind),
+        TypeTag::Udt => bail!("UDT type not supported"),
+        TypeTag::Unparsed => bail!("Unparsed type not supported"),
     }?;
 
     let result = spark_connect::DataType { kind: Some(result) };
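A note for context: the new match arms above implement a decode-by-tag step, where a tag extracted from a JSON type descriptor selects which protobuf Kind variant to build, and unsupported tags become errors. The sketch below illustrates that pattern in isolation, assuming a JSON object whose "type" field carries the tag; SimpleKind and decode_kind are hypothetical stand-ins for illustration, not the actual daft-connect or spark_connect types, and the real remove_type/deser_helper signatures are not shown in this diff.

use serde_json::Value;

// Illustrative stand-in for the protobuf Kind enum; not the real spark_connect type.
#[derive(Debug, PartialEq)]
enum SimpleKind {
    Null,
    Boolean,
    Long,
    Double,
    String,
}

fn decode_kind(input: &Value) -> Result<SimpleKind, String> {
    // Pull the "type" tag out of the JSON object (a simplified analogue of remove_type).
    let tag = input
        .get("type")
        .and_then(Value::as_str)
        .ok_or_else(|| "missing \"type\" field".to_string())?;

    // Map the tag onto an enum variant; anything unrecognized becomes an error.
    match tag {
        "null" => Ok(SimpleKind::Null),
        "boolean" => Ok(SimpleKind::Boolean),
        "long" => Ok(SimpleKind::Long),
        "double" => Ok(SimpleKind::Double),
        "string" => Ok(SimpleKind::String),
        other => Err(format!("unsupported type tag {other:?}")),
    }
}

fn main() {
    let input: Value = serde_json::from_str(r#"{"type": "long"}"#).unwrap();
    assert_eq!(decode_kind(&input), Ok(SimpleKind::Long));
}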
19 changes: 17 additions & 2 deletions tests/connect/test_create_df.py
@@ -2,11 +2,26 @@
 
 
 def test_create_df(spark_session):
-    # Create simple DataFrame
+    # Create simple DataFrame with single column
     data = [(1,), (2,), (3,)]
     df = spark_session.createDataFrame(data, ["id"])
 
-    # Convert to pandas
+    # Convert to pandas and verify
     df_pandas = df.toPandas()
     assert len(df_pandas) == 3, "DataFrame should have 3 rows"
     assert list(df_pandas["id"]) == [1, 2, 3], "DataFrame should contain expected values"
+
+    # Create DataFrame with float column
+    float_data = [(1.1,), (2.2,), (3.3,)]
+    df_float = spark_session.createDataFrame(float_data, ["value"])
+    df_float_pandas = df_float.toPandas()
+    assert len(df_float_pandas) == 3, "Float DataFrame should have 3 rows"
+    assert list(df_float_pandas["value"]) == [1.1, 2.2, 3.3], "Float DataFrame should contain expected values"
+
+    # Create DataFrame with two numeric columns
+    two_col_data = [(1, 10), (2, 20), (3, 30)]
+    df_two = spark_session.createDataFrame(two_col_data, ["num1", "num2"])
+    df_two_pandas = df_two.toPandas()
+    assert len(df_two_pandas) == 3, "Two-column DataFrame should have 3 rows"
+    assert list(df_two_pandas["num1"]) == [1, 2, 3], "First number column should contain expected values"
+    assert list(df_two_pandas["num2"]) == [10, 20, 30], "Second number column should contain expected values"
