diff --git a/Cargo.lock b/Cargo.lock index a425b34142..3862a3f266 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2206,6 +2206,7 @@ dependencies = [ "daft-plan", "once_cell", "pyo3", + "regex", "rstest", "snafu", "sqlparser", diff --git a/src/daft-sql/Cargo.toml b/src/daft-sql/Cargo.toml index 81d7d36ff0..41e9e88338 100644 --- a/src/daft-sql/Cargo.toml +++ b/src/daft-sql/Cargo.toml @@ -10,6 +10,7 @@ daft-plan = {path = "../daft-plan"} once_cell = {workspace = true} pyo3 = {workspace = true, optional = true} sqlparser = {workspace = true} +regex.workspace = true snafu.workspace = true [dev-dependencies] diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 2f47b5008b..d81796a4a6 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -11,9 +11,10 @@ use daft_functions::numeric::{ceil::ceil, floor::floor}; use daft_plan::{LogicalPlanBuilder, LogicalPlanRef}; use sqlparser::{ ast::{ - ArrayElemTypeDef, BinaryOperator, CastKind, Distinct, ExactNumberInfo, ExcludeSelectItem, - GroupByExpr, Ident, Query, SelectItem, Statement, StructField, Subscript, TableAlias, - TableWithJoins, TimezoneInfo, UnaryOperator, Value, WildcardAdditionalOptions, + ArrayElemTypeDef, BinaryOperator, CastKind, DateTimeField, Distinct, ExactNumberInfo, + ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, Statement, StructField, + Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator, Value, + WildcardAdditionalOptions, }, dialect::GenericDialect, parser::{Parser, ParserOptions}, @@ -940,7 +941,171 @@ impl SQLPlanner { SQLExpr::Map(_) => unsupported_sql_err!("MAP"), SQLExpr::Subscript { expr, subscript } => self.plan_subscript(expr, subscript.as_ref()), SQLExpr::Array(_) => unsupported_sql_err!("ARRAY"), - SQLExpr::Interval(_) => unsupported_sql_err!("INTERVAL"), + SQLExpr::Interval(interval) => { + use regex::Regex; + + /// A private struct represents a single parsed interval unit and its value + #[derive(Debug)] + struct IntervalPart { + count: i64, + unit: DateTimeField, + } + + // Local function to parse interval string to interval parts + fn parse_interval_string(expr: &str) -> Result, PlannerError> { + let expr = expr.trim().trim_matches('\''); + + let re = Regex::new(r"(-?\d+)\s*(year|years|month|months|day|days|hour|hours|minute|minutes|second|seconds|millisecond|milliseconds|microsecond|microseconds|nanosecond|nanoseconds|week|weeks)") + .map_err(|e|PlannerError::invalid_operation(format!("Invalid regex pattern: {}", e)))?; + + let mut parts = Vec::new(); + + for cap in re.captures_iter(expr) { + let count: i64 = cap[1].parse().map_err(|e| { + PlannerError::invalid_operation(format!("Invalid interval count: {e}")) + })?; + + let unit = match &cap[2].to_lowercase()[..] { + "year" | "years" => DateTimeField::Year, + "month" | "months" => DateTimeField::Month, + "week" | "weeks" => DateTimeField::Week(None), + "day" | "days" => DateTimeField::Day, + "hour" | "hours" => DateTimeField::Hour, + "minute" | "minutes" => DateTimeField::Minute, + "second" | "seconds" => DateTimeField::Second, + "millisecond" | "milliseconds" => DateTimeField::Millisecond, + "microsecond" | "microseconds" => DateTimeField::Microsecond, + "nanosecond" | "nanoseconds" => DateTimeField::Nanosecond, + _ => { + return Err(PlannerError::invalid_operation(format!( + "Invalid interval unit: {}", + &cap[2] + ))) + } + }; + + parts.push(IntervalPart { count, unit }); + } + + if parts.is_empty() { + return Err(PlannerError::invalid_operation("Invalid interval format.")); + } + + Ok(parts) + } + + // Local function to convert parts to interval values + fn interval_parts_to_values(parts: Vec) -> (i64, i64, i64) { + let mut total_months = 0i64; + let mut total_days = 0i64; + let mut total_nanos = 0i64; + + for part in parts { + match part.unit { + DateTimeField::Year => total_months += 12 * part.count, + DateTimeField::Month => total_months += part.count, + DateTimeField::Week(_) => total_days += 7 * part.count, + DateTimeField::Day => total_days += part.count, + DateTimeField::Hour => total_nanos += part.count * 3_600_000_000_000, + DateTimeField::Minute => total_nanos += part.count * 60_000_000_000, + DateTimeField::Second => total_nanos += part.count * 1_000_000_000, + DateTimeField::Millisecond | DateTimeField::Milliseconds => { + total_nanos += part.count * 1_000_000; + } + DateTimeField::Microsecond | DateTimeField::Microseconds => { + total_nanos += part.count * 1_000; + } + DateTimeField::Nanosecond | DateTimeField::Nanoseconds => { + total_nanos += part.count; + } + _ => {} + } + } + + (total_months, total_days, total_nanos) + } + + match interval { + // If leading_field is specified, treat it as the old style single-unit interval + // e.g., INTERVAL '12' YEAR + sqlparser::ast::Interval { + value, + leading_field: Some(time_unit), + .. + } => { + let expr = self.plan_expr(value)?; + + let expr = + expr.as_literal() + .and_then(|lit| lit.as_str()) + .ok_or_else(|| { + PlannerError::invalid_operation( + "Interval value must be a string", + ) + })?; + + let count = expr.parse::().map_err(|e| { + PlannerError::unsupported_sql(format!("Invalid interval count: {e}")) + })?; + + let (months, days, nanoseconds) = match time_unit { + DateTimeField::Year => (12 * count, 0, 0), + DateTimeField::Month => (count, 0, 0), + DateTimeField::Week(_) => (0, 7 * count, 0), + DateTimeField::Day => (0, count, 0), + DateTimeField::Hour => (0, 0, count * 3_600_000_000_000), + DateTimeField::Minute => (0, 0, count * 60_000_000_000), + DateTimeField::Second => (0, 0, count * 1_000_000_000), + DateTimeField::Microsecond | DateTimeField::Microseconds => (0, 0, count * 1_000), + DateTimeField::Millisecond | DateTimeField::Milliseconds => (0, 0, count * 1_000_000), + DateTimeField::Nanosecond | DateTimeField::Nanoseconds => (0, 0, count), + _ => return Err(PlannerError::invalid_operation(format!( + "Invalid interval unit: {time_unit}. Expected one of: year, month, week, day, hour, minute, second, millisecond, microsecond, nanosecond" + ))), + }; + + Ok(Arc::new(Expr::Literal(LiteralValue::Interval( + daft_core::datatypes::IntervalValue::new( + months as i32, + days as i32, + nanoseconds, + ), + )))) + } + + // If no leading_field is specified, treat it as the new style multi-unit interval + // e.g., INTERVAL '12 years 3 months 7 days' + sqlparser::ast::Interval { + value, + leading_field: None, + .. + } => { + let expr = self.plan_expr(value)?; + + let expr = + expr.as_literal() + .and_then(|lit| lit.as_str()) + .ok_or_else(|| { + PlannerError::invalid_operation( + "Interval value must be a string", + ) + })?; + + let parts = parse_interval_string(expr) + .map_err(|e| PlannerError::invalid_operation(e.to_string()))?; + + let (months, days, nanoseconds) = interval_parts_to_values(parts); + + Ok(Arc::new(Expr::Literal(LiteralValue::Interval( + daft_core::datatypes::IntervalValue::new( + months as i32, + days as i32, + nanoseconds, + ), + )))) + } + } + } SQLExpr::MatchAgainst { .. } => unsupported_sql_err!("MATCH AGAINST"), SQLExpr::Wildcard => unsupported_sql_err!("WILDCARD"), SQLExpr::QualifiedWildcard(_) => unsupported_sql_err!("QUALIFIED WILDCARD"), diff --git a/tests/sql/test_exprs.py b/tests/sql/test_exprs.py index 2adfb1db31..595a486a31 100644 --- a/tests/sql/test_exprs.py +++ b/tests/sql/test_exprs.py @@ -1,7 +1,10 @@ +import datetime + import pytest import daft -from daft import col +from daft import col, interval +from daft.sql.sql import SQLCatalog def test_nested(): @@ -135,3 +138,78 @@ def test_is_in_edge_cases(): # Test with mixed types in the IN list with pytest.raises(Exception, match="All literals must have the same data type"): daft.sql("SELECT * FROM df WHERE nums IN (1, '2', 3.0)").collect().to_pydict() + + +@pytest.mark.parametrize( + "date_values, ts_values, expected_intervals", + [ + ( + ["2022-01-01", "2020-02-29", "2029-05-15"], + ["2022-01-01 10:00:00", "2020-02-29 23:59:59", "2029-05-15 12:34:56"], + { + "date_add_day": [ + datetime.date(2022, 1, 2), + datetime.date(2020, 3, 1), + datetime.date(2029, 5, 16), + ], + "date_sub_month": [ + datetime.date(2021, 12, 1), + datetime.date(2020, 1, 31), + datetime.date(2029, 4, 14), + ], + "ts_sub_year": [ + datetime.datetime(2021, 1, 1, 10), + datetime.datetime(2019, 2, 28, 23, 59, 59), + datetime.datetime(2028, 5, 15, 12, 34, 56), + ], + "ts_add_hour": [ + datetime.datetime(2022, 1, 1, 11, 0, 0), + datetime.datetime(2020, 3, 1, 0, 59, 59), + datetime.datetime(2029, 5, 15, 13, 34, 56), + ], + "ts_sub_minute": [ + datetime.datetime(2022, 1, 1, 9, 57, 21), + datetime.datetime(2020, 2, 29, 23, 57, 20), + datetime.datetime(2029, 5, 15, 12, 32, 17), + ], + }, + ), + ], +) +def test_interval_comparison(date_values, ts_values, expected_intervals): + # Create DataFrame with date and timestamp columns + df = daft.from_pydict({"date": date_values, "ts": ts_values}).select( + col("date").cast(daft.DataType.date()), col("ts").str.to_datetime("%Y-%m-%d %H:%M:%S") + ) + catalog = SQLCatalog({"test": df}) + + expected_df = ( + df.select( + (col("date") + interval(days=1)).alias("date_add_day"), + (col("date") - interval(months=1)).alias("date_sub_month"), + (col("ts") - interval(years=1, days=0)).alias("ts_sub_year"), + (col("ts") + interval(hours=1)).alias("ts_add_hour"), + (col("ts") - interval(minutes=1, seconds=99)).alias("ts_sub_minute"), + ) + .collect() + .to_pydict() + ) + + actual_sql = ( + daft.sql( + """ + SELECT + date + INTERVAL '1' day AS date_add_day, + date - INTERVAL '1 months' AS date_sub_month, + ts - INTERVAL '1 year 0 days' AS ts_sub_year, + ts + INTERVAL '1' hour AS ts_add_hour, + ts - INTERVAL '1 minutes 99 second' AS ts_sub_minute + FROM test + """, + catalog=catalog, + ) + .collect() + .to_pydict() + ) + + assert expected_df == actual_sql == expected_intervals