From cbe9d3b1697122688007805531063f6d0be08a48 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Thu, 21 Nov 2024 08:42:56 -0800 Subject: [PATCH] [FEAT] connect: add `df.limit` and `df.first` (#3309) --- src/daft-connect/src/translation/logical_plan.rs | 15 ++++++++++++++- tests/connect/test_limit_simple.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 tests/connect/test_limit_simple.py diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 947e0cd0d3..93c9e9bd4a 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -1,6 +1,6 @@ use daft_logical_plan::LogicalPlanBuilder; use eyre::{bail, Context}; -use spark_connect::{relation::RelType, Relation}; +use spark_connect::{relation::RelType, Limit, Relation}; use tracing::warn; use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range}; @@ -19,6 +19,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result { }; match rel_type { + RelType::Limit(l) => limit(*l).wrap_err("Failed to apply limit to logical plan"), RelType::Range(r) => range(r).wrap_err("Failed to apply range to logical plan"), RelType::Project(p) => project(*p).wrap_err("Failed to apply project to logical plan"), RelType::Aggregate(a) => { @@ -27,3 +28,15 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result { plan => bail!("Unsupported relation type: {plan:?}"), } } + +fn limit(limit: Limit) -> eyre::Result { + let Limit { input, limit } = limit; + + let Some(input) = input else { + bail!("input must be set"); + }; + + let plan = to_logical_plan(*input)?.limit(i64::from(limit), false)?; // todo: eager or no + + Ok(plan) +} diff --git a/tests/connect/test_limit_simple.py b/tests/connect/test_limit_simple.py new file mode 100644 index 0000000000..d5f2c97dae --- /dev/null +++ b/tests/connect/test_limit_simple.py @@ -0,0 +1,14 @@ +from __future__ import annotations + + +def test_range_first(spark_session): + spark_range = spark_session.range(10) + first_row = spark_range.first() + assert first_row["id"] == 0, "First row should have id=0" + + +def test_range_limit(spark_session): + spark_range = spark_session.range(10) + limited_df = spark_range.limit(5).toPandas() + assert len(limited_df) == 5, "Limited DataFrame should have 5 rows" + assert list(limited_df["id"]) == list(range(5)), "Limited DataFrame should contain values 0-4"