Skip to content

Commit

Permalink
feat: add limit and first
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Nov 21, 2024
1 parent 3394a66 commit f9d442f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use daft_logical_plan::LogicalPlanBuilder;
use eyre::{bail, Context};
use spark_connect::{relation::RelType, Relation};
use spark_connect::{relation::RelType, Limit, Relation};
use tracing::warn;

use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range};
Expand All @@ -19,6 +19,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
};

match rel_type {
RelType::Limit(l) => limit(*l).wrap_err("Failed to apply limit to logical plan"),
RelType::Range(r) => range(r).wrap_err("Failed to apply range to logical plan"),
RelType::Project(p) => project(*p).wrap_err("Failed to apply project to logical plan"),
RelType::Aggregate(a) => {
Expand All @@ -27,3 +28,15 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
plan => bail!("Unsupported relation type: {plan:?}"),
}
}

fn limit(limit: Limit) -> eyre::Result<LogicalPlanBuilder> {
let Limit { input, limit } = limit;

let Some(input) = input else {
bail!("input must be set");
};

let plan = to_logical_plan(*input)?.limit(i64::from(limit), false)?; // todo: eager or no

Ok(plan)
}
13 changes: 13 additions & 0 deletions tests/connect/test_range_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,16 @@ def test_range_operation(spark_session):
# Verify the DataFrame has expected values
assert len(pandas_df) == 10, "DataFrame should have 10 rows"
assert list(pandas_df["id"]) == list(range(10)), "DataFrame should contain values 0-9"


def test_range_first(spark_session):
spark_range = spark_session.range(10)
first_row = spark_range.first()
assert first_row["id"] == 0, "First row should have id=0"


def test_range_limit(spark_session):
spark_range = spark_session.range(10)
limited_df = spark_range.limit(5).toPandas()
assert len(limited_df) == 5, "Limited DataFrame should have 5 rows"
assert list(limited_df["id"]) == list(range(5)), "Limited DataFrame should contain values 0-4"

0 comments on commit f9d442f

Please sign in to comment.