From 4bdd51d5a7eae38d8925c69196a267bf61eba94f Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Wed, 20 Nov 2024 20:56:54 -0800
Subject: [PATCH] [FEAT] connect: support `sample`

---
 .../src/translation/logical_plan.rs           |  6 +++
 .../src/translation/logical_plan/sample.rs    | 53 +++++++++++++++++++
 tests/connect/test_sample.py                  | 18 +++++++
 3 files changed, 77 insertions(+)
 create mode 100644 src/daft-connect/src/translation/logical_plan/sample.rs
 create mode 100644 tests/connect/test_sample.py

diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs
index 92d2bce6c3..d412454c3b 100644
--- a/src/daft-connect/src/translation/logical_plan.rs
+++ b/src/daft-connect/src/translation/logical_plan.rs
@@ -18,6 +18,8 @@ use futures::TryStreamExt;
 use spark_connect::{relation::RelType, Limit, Relation, ShowString};
 use tracing::warn;
 
+use crate::translation::logical_plan::{sample::sample, set_op::set_op};
+
 mod aggregate;
 mod drop;
 mod filter;
@@ -25,6 +27,7 @@ mod local_relation;
 mod project;
 mod range;
 mod read;
+mod sample;
 mod set_op;
 mod to_df;
 mod with_columns;
@@ -133,6 +136,9 @@ impl SparkAnalyzer<'_> {
             RelType::SetOp(s) => set_op(*s)
                 .await
                 .wrap_err("Failed to apply set_op to logical plan"),
+            RelType::Sample(s) => sample(*s)
+                .await
+                .wrap_err("Failed to apply sample to logical plan"),
             plan => bail!("Unsupported relation type: {plan:?}"),
         }
     }
diff --git a/src/daft-connect/src/translation/logical_plan/sample.rs b/src/daft-connect/src/translation/logical_plan/sample.rs
new file mode 100644
index 0000000000..ee55bd77e3
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/sample.rs
@@ -0,0 +1,53 @@
+use eyre::{bail, WrapErr};
+use tracing::warn;
+
+use crate::translation::{to_logical_plan, Plan};
+
+/// Translates a Spark Connect `Sample` relation into a logical plan.
+///
+/// The sampling fraction is `upper_bound - lower_bound`. A missing `with_replacement` is
+/// treated as `false`, and `deterministic_order` is not supported (a warning is logged).
+///
+/// # Errors
+///
+/// Fails if `input` is missing, if the bounds yield a negative or NaN fraction, if the input
+/// relation cannot be translated, or if the sample step cannot be applied to the plan.
+pub async fn sample(sample: spark_connect::Sample) -> eyre::Result<Plan> {
+    let spark_connect::Sample {
+        input,
+        lower_bound,
+        upper_bound,
+        with_replacement,
+        seed,
+        deterministic_order,
+    } = sample;
+
+    let Some(input) = input else {
+        bail!("Input is required");
+    };
+
+    // Only the width of the interval between the bounds is used; where the
+    // interval starts is ignored.
+    let fraction = upper_bound - lower_bound;
+    if fraction.is_nan() || fraction < 0.0 {
+        bail!("Invalid sample bounds: lower_bound={lower_bound}, upper_bound={upper_bound}");
+    }
+
+    let mut plan = Box::pin(to_logical_plan(*input)).await?;
+
+    let with_replacement = with_replacement.unwrap_or(false);
+
+    // A negative seed is deliberately reinterpreted as unsigned; the sign change is accepted.
+    let seed = seed.map(|seed| seed as u64);
+
+    if deterministic_order {
+        warn!("Deterministic order is not yet supported");
+    }
+
+    plan.builder = plan
+        .builder
+        .sample(fraction, with_replacement, seed)
+        .wrap_err("Failed to apply sample to logical plan")?;
+
+    Ok(plan)
+}
diff --git a/tests/connect/test_sample.py b/tests/connect/test_sample.py
new file mode 100644
index 0000000000..c7bd4df86e
--- /dev/null
+++ b/tests/connect/test_sample.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+
+def test_sample(spark_session):
+    # Create a range DataFrame
+    df = spark_session.range(100)
+
+    # Test sample with fraction
+    sampled_df = df.sample(fraction=0.1, seed=42)
+    sampled_rows = sampled_df.collect()
+
+    # Verify sample size is roughly 10% of original
+    sample_size = len(sampled_rows)
+    assert 5 <= sample_size <= 15, f"Sample size {sample_size} should be roughly 10 rows"
+
+    # Verify sampled values are within original range
+    for row in sampled_rows:
+        assert 0 <= row["id"] < 100, f"Sampled value {row['id']} outside valid range"