From 61b684213687f00913f73461f408ce9d7a4a615a Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Wed, 6 Nov 2024 16:54:45 -0600 Subject: [PATCH] [FEAT]: sql "extract" temporal function (#3188) --- src/daft-sql/src/planner.rs | 21 +++++++++++++++- tests/sql/test_temporal_exprs.py | 41 ++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 42e9460cad..239be4845d 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -1121,7 +1121,26 @@ impl SQLPlanner { SQLExpr::Convert { .. } => unsupported_sql_err!("CONVERT"), SQLExpr::Cast { .. } => unsupported_sql_err!("CAST"), SQLExpr::AtTimeZone { .. } => unsupported_sql_err!("AT TIME ZONE"), - SQLExpr::Extract { .. } => unsupported_sql_err!("EXTRACT"), + SQLExpr::Extract { + field, + syntax: _, + expr, + } => { + use daft_functions::temporal::{self as dt}; + let expr = self.plan_expr(expr)?; + + match field { + DateTimeField::Year => Ok(dt::dt_year(expr)), + DateTimeField::Month => Ok(dt::dt_month(expr)), + DateTimeField::Day => Ok(dt::dt_day(expr)), + DateTimeField::DayOfWeek => Ok(dt::dt_day_of_week(expr)), + DateTimeField::Date => Ok(dt::dt_date(expr)), + DateTimeField::Hour => Ok(dt::dt_hour(expr)), + DateTimeField::Minute => Ok(dt::dt_minute(expr)), + DateTimeField::Second => Ok(dt::dt_second(expr)), + other => unsupported_sql_err!("EXTRACT ({other})"), + } + } SQLExpr::Ceil { expr, .. } => Ok(ceil(self.plan_expr(expr)?)), SQLExpr::Floor { expr, .. } => Ok(floor(self.plan_expr(expr)?)), SQLExpr::Position { .. } => unsupported_sql_err!("POSITION"), diff --git a/tests/sql/test_temporal_exprs.py b/tests/sql/test_temporal_exprs.py index 9f8d6640c0..d475850839 100644 --- a/tests/sql/test_temporal_exprs.py +++ b/tests/sql/test_temporal_exprs.py @@ -47,3 +47,44 @@ def test_temporals(): ).collect() assert actual.to_pydict() == expected.to_pydict() + + +def test_extract(): + df = daft.from_pydict( + { + "datetimes": [ + datetime.datetime(2021, 1, 1, 23, 59, 58), + datetime.datetime(2021, 1, 2, 0, 0, 0), + datetime.datetime(2021, 1, 2, 1, 2, 3), + datetime.datetime(2021, 1, 2, 1, 2, 3), + datetime.datetime(1999, 1, 1, 1, 1, 1), + None, + ] + } + ) + + expected = df.select( + daft.col("datetimes").dt.date().alias("date"), + daft.col("datetimes").dt.day().alias("day"), + daft.col("datetimes").dt.day_of_week().alias("day_of_week"), + daft.col("datetimes").dt.hour().alias("hour"), + daft.col("datetimes").dt.minute().alias("minute"), + daft.col("datetimes").dt.month().alias("month"), + daft.col("datetimes").dt.second().alias("second"), + daft.col("datetimes").dt.year().alias("year"), + ).collect() + + actual = daft.sql(""" + SELECT + extract(date from datetimes) as date, + extract(day from datetimes) as day, + extract(dayofweek from datetimes) as day_of_week, + extract(hour from datetimes) as hour, + extract(minute from datetimes) as minute, + extract(month from datetimes) as month, + extract(second from datetimes) as second, + extract(year from datetimes) as year, + FROM df + """).collect() + + assert actual.to_pydict() == expected.to_pydict()