diff --git a/src/daft-connect/src/display.rs b/src/daft-connect/src/display.rs index 1bb4f755e2..1a179705d7 100644 --- a/src/daft-connect/src/display.rs +++ b/src/daft-connect/src/display.rs @@ -16,6 +16,7 @@ pub fn to_tree_string(schema: &Schema) -> eyre::Result { // A helper function to print a field at a given level of indentation. // level=1 means a single " |-- " prefix, level=2 means // " | |-- " and so on, mimicking Spark's indentation style. +// A helper function to print a field at a given level of indentation. fn print_field( w: &mut String, field_name: &str, @@ -23,28 +24,19 @@ fn print_field( nullable: bool, level: usize, ) -> eyre::Result<()> { - // Construct the prefix for indentation. - // Spark indentation levels: - // level 1: " |-- " - // level 2: " | |-- " - // level n: " |" followed by (4*(n-1)) spaces + "-- " let indent = if level == 1 { " |-- ".to_string() } else { - let spaces = " ".repeat(4 * (level - 1)); - format!(" |{}-- ", spaces) + format!(" |{}-- ", " |".repeat(level - 1)) }; - // Get a user-friendly string for dtype let dtype_str = type_to_string(dtype); - writeln!( w, "{}{}: {} (nullable = {})", indent, field_name, dtype_str, nullable )?; - // If the dtype is a struct, we must print its child fields with increased indentation. if let DataType::Struct(fields) = dtype { for field in fields { print_field(w, &field.name, &field.dtype, true, level + 1)?; diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index a4f287a26f..1a25712f47 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -376,13 +376,14 @@ impl SparkConnectService for DaftSparkConnectService { .unwrap() .build(); - let s = plan.display_as(DisplayLevel::Default); + let schema = plan.schema(); + let tree_string = display::to_tree_string(&schema).unwrap(); let response = AnalyzePlanResponse { session_id, server_side_session_id: String::new(), - result: Some(spark_connect::analyze_plan_response::Result::TreeString( - analyze_plan_response::TreeString { tree_string: s }, + result: Some(analyze_plan_response::Result::TreeString( + analyze_plan_response::TreeString { tree_string }, )), }; diff --git a/tests/connect/test_print_schema.py b/tests/connect/test_print_schema.py index 5a58f2c95e..2400090374 100644 --- a/tests/connect/test_print_schema.py +++ b/tests/connect/test_print_schema.py @@ -1,13 +1,10 @@ from __future__ import annotations -def test_print_schema(spark_session: object, capsys: object) -> None: +def test_print_schema(spark_session, capsys) -> None: df = spark_session.range(10) df.printSchema() - + captured = capsys.readouterr() - expected = ( - "root\n" - " |-- id: long (nullable = true)\n" - ) + expected = "root\n" " |-- id: integer (nullable = true)\n\n" assert captured.out == expected