Skip to content

Commit

Permalink
[FEAT] connect: support sample
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 5, 2024
1 parent 3038c0f commit 914b637
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ use tracing::warn;

use crate::translation::logical_plan::{
aggregate::aggregate, local_relation::local_relation, project::project, range::range,
set_op::set_op, to_df::to_df, with_columns::with_columns,
sample::sample, set_op::set_op, to_df::to_df, with_columns::with_columns,
};

mod aggregate;
mod local_relation;
mod project;
mod range;
mod sample;
mod set_op;
mod to_df;
mod with_columns;
Expand Down Expand Up @@ -58,6 +59,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
RelType::LocalRelation(l) => {
local_relation(l).wrap_err("Failed to apply local_relation to logical plan")
}
RelType::Sample(s) => sample(*s).wrap_err("Failed to apply sample to logical plan"),
RelType::SetOp(s) => set_op(*s).wrap_err("Failed to apply set_op to logical plan"),
plan => bail!("Unsupported relation type: {plan:?}"),
}
Expand Down
41 changes: 41 additions & 0 deletions src/daft-connect/src/translation/logical_plan/sample.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use eyre::{bail, WrapErr};
use tracing::warn;

use crate::translation::{to_logical_plan, Plan};

pub fn sample(sample: spark_connect::Sample) -> eyre::Result<Plan> {
let spark_connect::Sample {
input,
lower_bound,
upper_bound,
with_replacement,
seed,
deterministic_order,
} = sample;

let Some(input) = input else {
bail!("Input is required");
};

let mut plan = to_logical_plan(*input)?;

// Calculate fraction from bounds
// todo: is this correct?
let fraction = upper_bound - lower_bound;

let with_replacement = with_replacement.unwrap_or(false);

// we do not care about sign change
let seed = seed.map(|seed| seed as u64);

if deterministic_order {
warn!("Deterministic order is not yet supported");
}

plan.builder = plan
.builder
.sample(fraction, with_replacement, seed)
.wrap_err("Failed to apply sample to logical plan")?;

Ok(plan)
}
18 changes: 18 additions & 0 deletions tests/connect/test_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations


def test_sample(spark_session):
# Create a range DataFrame
df = spark_session.range(100)

# Test sample with fraction
sampled_df = df.sample(fraction=0.1, seed=42)
sampled_rows = sampled_df.collect()

# Verify sample size is roughly 10% of original
sample_size = len(sampled_rows)
assert 5 <= sample_size <= 15, f"Sample size {sample_size} should be roughly 10 rows"

# Verify sampled values are within original range
for row in sampled_rows:
assert 0 <= row["id"] < 100, f"Sampled value {row['id']} outside valid range"

0 comments on commit 914b637

Please sign in to comment.