Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 19, 2024
1 parent 9af579d commit 54ddc5e
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 19 deletions.
12 changes: 2 additions & 10 deletions src/daft-connect/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,27 @@ pub fn to_tree_string(schema: &Schema) -> eyre::Result<String> {
// A helper function to print a field at a given level of indentation.
// level=1 means a single " |-- " prefix, level=2 means
// " | |-- " and so on, mimicking Spark's indentation style.
// A helper function to print a field at a given level of indentation.
fn print_field(
w: &mut String,
field_name: &str,
dtype: &DataType,
nullable: bool,
level: usize,
) -> eyre::Result<()> {
// Construct the prefix for indentation.
// Spark indentation levels:
// level 1: " |-- "
// level 2: " | |-- "
// level n: " |" followed by (4*(n-1)) spaces + "-- "
let indent = if level == 1 {
" |-- ".to_string()
} else {
let spaces = " ".repeat(4 * (level - 1));
format!(" |{}-- ", spaces)
format!(" |{}-- ", " |".repeat(level - 1))
};

// Get a user-friendly string for dtype
let dtype_str = type_to_string(dtype);

writeln!(
w,
"{}{}: {} (nullable = {})",
indent, field_name, dtype_str, nullable
)?;

// If the dtype is a struct, we must print its child fields with increased indentation.
if let DataType::Struct(fields) = dtype {
for field in fields {
print_field(w, &field.name, &field.dtype, true, level + 1)?;
Expand Down
7 changes: 4 additions & 3 deletions src/daft-connect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,13 +376,14 @@ impl SparkConnectService for DaftSparkConnectService {
.unwrap()
.build();

let s = plan.display_as(DisplayLevel::Default);
let schema = plan.schema();
let tree_string = display::to_tree_string(&schema).unwrap();

let response = AnalyzePlanResponse {
session_id,
server_side_session_id: String::new(),
result: Some(spark_connect::analyze_plan_response::Result::TreeString(
analyze_plan_response::TreeString { tree_string: s },
result: Some(analyze_plan_response::Result::TreeString(
analyze_plan_response::TreeString { tree_string },
)),
};

Expand Down
9 changes: 3 additions & 6 deletions tests/connect/test_print_schema.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from __future__ import annotations


def test_print_schema(spark_session: object, capsys: object) -> None:
def test_print_schema(spark_session, capsys) -> None:
df = spark_session.range(10)
df.printSchema()

captured = capsys.readouterr()
expected = (
"root\n"
" |-- id: long (nullable = true)\n"
)
expected = "root\n" " |-- id: integer (nullable = true)\n\n"
assert captured.out == expected

0 comments on commit 54ddc5e

Please sign in to comment.