feat(connect): df.show (#3560)
universalmind303 authored Dec 18, 2024
1 parent 3a3707a commit ca4d3f7
Showing 4 changed files with 116 additions and 5 deletions.
1 change: 0 additions & 1 deletion src/daft-connect/src/session.rs
Expand Up @@ -28,7 +28,6 @@ impl Session {
pub fn new(id: String) -> Self {
let server_side_session_id = Uuid::new_v4();
let server_side_session_id = server_side_session_id.to_string();

Self {
config_values: Default::default(),
id,
105 changes: 101 additions & 4 deletions src/daft-connect/src/translation/logical_plan.rs
@@ -1,7 +1,21 @@
use std::sync::Arc;

use common_daft_config::DaftExecutionConfig;
use daft_core::prelude::Schema;
use daft_dsl::LiteralValue;
use daft_local_execution::NativeExecutor;
use daft_logical_plan::LogicalPlanBuilder;
use daft_micropartition::partitioning::InMemoryPartitionSetCache;
use daft_micropartition::{
    partitioning::{
        InMemoryPartitionSetCache, MicroPartitionSet, PartitionCacheEntry, PartitionMetadata,
        PartitionSet, PartitionSetCache,
    },
    MicroPartition,
};
use daft_table::Table;
use eyre::{bail, Context};
use spark_connect::{relation::RelType, Limit, Relation};
use futures::TryStreamExt;
use spark_connect::{relation::RelType, Limit, Relation, ShowString};
use tracing::warn;

mod aggregate;
@@ -22,6 +36,35 @@ impl SparkAnalyzer<'_> {
    pub fn new(pset: &InMemoryPartitionSetCache) -> SparkAnalyzer {
        SparkAnalyzer { psets: pset }
    }
    pub fn create_in_memory_scan(
        &self,
        plan_id: usize,
        schema: Arc<Schema>,
        tables: Vec<Table>,
    ) -> eyre::Result<LogicalPlanBuilder> {
        let partition_key = uuid::Uuid::new_v4().to_string();

        let pset = Arc::new(MicroPartitionSet::from_tables(plan_id, tables)?);

        let PartitionMetadata {
            num_rows,
            size_bytes,
        } = pset.metadata();
        let num_partitions = pset.num_partitions();

        self.psets.put_partition_set(&partition_key, &pset);

        let cache_entry = PartitionCacheEntry::new_rust(partition_key.clone(), pset);

        Ok(LogicalPlanBuilder::in_memory_scan(
            &partition_key,
            cache_entry,
            schema,
            num_partitions,
            size_bytes,
            num_rows,
        )?)
    }

    pub async fn to_logical_plan(&self, relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
        let Some(common) = relation.common else {
@@ -78,12 +121,18 @@ impl SparkAnalyzer<'_> {
                .filter(*f)
                .await
                .wrap_err("Failed to apply filter to logical plan"),
            RelType::ShowString(ss) => {
                let Some(plan_id) = common.plan_id else {
                    bail!("Plan ID is required for LocalRelation");
                };
                self.show_string(plan_id, *ss)
                    .await
                    .wrap_err("Failed to show string")
            }
            plan => bail!("Unsupported relation type: {plan:?}"),
        }
    }
}

impl SparkAnalyzer<'_> {
    async fn limit(&self, limit: Limit) -> eyre::Result<LogicalPlanBuilder> {
        let Limit { input, limit } = limit;

@@ -96,4 +145,52 @@ impl SparkAnalyzer<'_> {
        plan.limit(i64::from(limit), false)
            .wrap_err("Failed to apply limit to logical plan")
    }

    /// right now this just naively applies a limit to the logical plan
    /// In the future, we want this to more closely match our daft implementation
    async fn show_string(
        &self,
        plan_id: i64,
        show_string: ShowString,
    ) -> eyre::Result<LogicalPlanBuilder> {
        let ShowString {
            input,
            num_rows,
            truncate: _,
            vertical,
        } = show_string;

        if vertical {
            bail!("Vertical show string is not supported");
        }

        let Some(input) = input else {
            bail!("input must be set");
        };

        let plan = Box::pin(self.to_logical_plan(*input)).await?;
        let plan = plan.limit(num_rows as i64, true)?;

        let optimized_plan = tokio::task::spawn_blocking(move || plan.optimize())
            .await
            .unwrap()?;

        let cfg = Arc::new(DaftExecutionConfig::default());
        let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?;
        let result_stream = native_executor.run(self.psets, cfg, None)?.into_stream();
        let batch = result_stream.try_collect::<Vec<_>>().await?;
        let single_batch = MicroPartition::concat(batch)?;
        let tbls = single_batch.get_tables()?;
        let tbl = Table::concat(&tbls)?;
        let output = tbl.to_comfy_table(None).to_string();

        let s = LiteralValue::Utf8(output)
            .into_single_value_series()?
            .rename("show_string");

        let tbl = Table::from_nonempty_columns(vec![s])?;
        let schema = tbl.schema.clone();

        self.create_in_memory_scan(plan_id as _, schema, vec![tbl])
    }
}
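
For context, a minimal client-side sketch of the path this diff enables (illustrative only: the connection URL and port are assumptions, and it presumes a Daft Connect server is already running). df.show() sends a ShowString relation, which the translation above limits, executes eagerly, renders with to_comfy_table, and returns as a single-row "show_string" column that PySpark prints:

from pyspark.sql import SparkSession

# Assumed endpoint; use wherever the Daft Connect server is actually listening.
spark = SparkSession.builder.remote("sc://localhost:50051").getOrCreate()

df = spark.range(10)  # plan is built lazily on the client
df.show()             # issues ShowString; the server executes the limited plan,
                      # renders the table server-side, and returns the rendered
                      # text in the single "show_string" column printed here
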
6 changes: 6 additions & 0 deletions src/daft-dsl/src/lit.rs
@@ -444,6 +444,12 @@ pub fn null_lit() -> ExprRef {
    Arc::new(Expr::Literal(LiteralValue::Null))
}

impl LiteralValue {
    pub fn into_single_value_series(self) -> DaftResult<Series> {
        literals_to_series(&[self])
    }
}

/// Convert a slice of literals to a series.
/// This function will return an error if the literals are not all the same type
pub fn literals_to_series(values: &[LiteralValue]) -> DaftResult<Series> {
9 changes: 9 additions & 0 deletions tests/connect/test_show.py
@@ -0,0 +1,9 @@
from __future__ import annotations


def test_show(spark_session):
    df = spark_session.range(10)
    try:
        df.show()
    except Exception as e:
        assert False, e
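
A hypothetical, slightly stricter test (not part of this commit) could capture stdout with pytest's capsys fixture and check that the rendered table actually reached the client; the "id" column name is what spark_session.range produces:

def test_show_output(spark_session, capsys):
    # Hypothetical variant: assert the printed output contains the "id" column
    # header from range(), rather than only checking that no exception is raised.
    spark_session.range(10).show()
    out = capsys.readouterr().out
    assert "id" in out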
