Skip to content

Commit

Permalink
[FEAT] Support SQL INTERVAL (#3146)
Browse files Browse the repository at this point in the history
## The Rationales

Thanks to the [great
work](#3018) from
@universalmind303, Daft now supports the `INTERVAL` type exposed from
`arrow2`.

Beyond DataFrame support, this PR aims to unlock simple `INTERVAL`
usage in SQL syntax, mainly copied from
[planner.rs](https://github.com/sgl-project/sglang/pull/1790/files#diff-ea02b059cdabc0939616c35c6566dbcf980a5794306dedd241c2823afd9b2db2).

Note: This naive implementation doesn't fully support complex interval scenarios,
such as leap years or relative duration addition and subtraction. We might
need more carefully handled logic in follow-ups.

---------

Signed-off-by: Austin Liu <[email protected]>
  • Loading branch information
austin362667 authored Oct 30, 2024
1 parent e84ed5b commit 701a011
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 5 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/daft-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ daft-plan = {path = "../daft-plan"}
once_cell = {workspace = true}
pyo3 = {workspace = true, optional = true}
sqlparser = {workspace = true}
regex.workspace = true
snafu.workspace = true

[dev-dependencies]
Expand Down
173 changes: 169 additions & 4 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ use daft_functions::numeric::{ceil::ceil, floor::floor};
use daft_plan::{LogicalPlanBuilder, LogicalPlanRef};
use sqlparser::{
ast::{
ArrayElemTypeDef, BinaryOperator, CastKind, Distinct, ExactNumberInfo, ExcludeSelectItem,
GroupByExpr, Ident, Query, SelectItem, Statement, StructField, Subscript, TableAlias,
TableWithJoins, TimezoneInfo, UnaryOperator, Value, WildcardAdditionalOptions,
ArrayElemTypeDef, BinaryOperator, CastKind, DateTimeField, Distinct, ExactNumberInfo,
ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, Statement, StructField,
Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator, Value,
WildcardAdditionalOptions,
},
dialect::GenericDialect,
parser::{Parser, ParserOptions},
Expand Down Expand Up @@ -940,7 +941,171 @@ impl SQLPlanner {
SQLExpr::Map(_) => unsupported_sql_err!("MAP"),
SQLExpr::Subscript { expr, subscript } => self.plan_subscript(expr, subscript.as_ref()),
SQLExpr::Array(_) => unsupported_sql_err!("ARRAY"),
SQLExpr::Interval(_) => unsupported_sql_err!("INTERVAL"),
SQLExpr::Interval(interval) => {
use regex::Regex;

/// One parsed component of a multi-unit interval literal: a signed count
/// paired with the unit it applies to (e.g. `3 months` is stored as
/// `count: 3`, `unit: DateTimeField::Month`).
#[derive(Debug)]
struct IntervalPart {
count: i64,
unit: DateTimeField,
}

/// Parses the body of a multi-unit interval literal, such as
/// `'12 years 3 months 7 days'`, into its individual components.
///
/// Surrounding whitespace and single quotes are stripped first. Each component
/// is `<integer><optional whitespace><unit>`, where the integer may carry a
/// leading `+` or `-` and the unit is matched case-insensitively in singular
/// or plural form. Components may be separated by whitespace and/or commas.
///
/// # Errors
///
/// Returns `PlannerError::invalid_operation` when:
/// - the string contains no component at all,
/// - the string contains text that is not part of a component (for example
///   the `1.` in `1.5 days`), which would otherwise be silently dropped and
///   yield a different interval than the one written,
/// - a count does not fit in an `i64`,
/// - a unit is not one of the supported names.
fn parse_interval_string(expr: &str) -> Result<Vec<IntervalPart>, PlannerError> {
    /// Accepts only whitespace and commas between components; anything else is
    /// input the component pattern could not interpret.
    fn ensure_only_separators(text: &str) -> Result<(), PlannerError> {
        if text.chars().all(|c| c.is_whitespace() || c == ',') {
            Ok(())
        } else {
            Err(PlannerError::invalid_operation(format!(
                "Invalid interval format: unexpected text '{}'",
                text.trim()
            )))
        }
    }

    let expr = expr.trim().trim_matches('\'');

    // The unit is captured as any alphabetic word and validated in the `match`
    // below, so unknown units produce an error instead of being skipped, and
    // upper- or mixed-case spellings reach the lower-casing step.
    let re = Regex::new(r"([+-]?\d+)\s*([A-Za-z]+)")
        .map_err(|e| PlannerError::invalid_operation(format!("Invalid regex pattern: {e}")))?;

    let mut parts = Vec::new();
    // Byte offset just past the previous component; used to inspect the text
    // lying between consecutive components.
    let mut cursor = 0;

    for cap in re.captures_iter(expr) {
        let span = cap
            .get(0)
            .expect("capture group 0 always spans the whole match");
        ensure_only_separators(&expr[cursor..span.start()])?;
        cursor = span.end();

        let count: i64 = cap[1].parse().map_err(|e| {
            PlannerError::invalid_operation(format!("Invalid interval count: {e}"))
        })?;

        let unit = match cap[2].to_lowercase().as_str() {
            "year" | "years" => DateTimeField::Year,
            "month" | "months" => DateTimeField::Month,
            "week" | "weeks" => DateTimeField::Week(None),
            "day" | "days" => DateTimeField::Day,
            "hour" | "hours" => DateTimeField::Hour,
            "minute" | "minutes" => DateTimeField::Minute,
            "second" | "seconds" => DateTimeField::Second,
            "millisecond" | "milliseconds" => DateTimeField::Millisecond,
            "microsecond" | "microseconds" => DateTimeField::Microsecond,
            "nanosecond" | "nanoseconds" => DateTimeField::Nanosecond,
            _ => {
                return Err(PlannerError::invalid_operation(format!(
                    "Invalid interval unit: {}",
                    &cap[2]
                )))
            }
        };

        parts.push(IntervalPart { count, unit });
    }

    if parts.is_empty() {
        return Err(PlannerError::invalid_operation("Invalid interval format."));
    }

    // Text after the final component must also be separators only.
    ensure_only_separators(&expr[cursor..])?;

    Ok(parts)
}

/// Collapses parsed interval components into `(months, days, nanoseconds)`
/// totals.
///
/// Years and months accumulate into the month total, weeks and days into the
/// day total, and hours down to nanoseconds into the nanosecond total. Any
/// other unit contributes nothing. All arithmetic is plain `i64`; overflow is
/// not guarded here.
fn interval_parts_to_values(parts: Vec<IntervalPart>) -> (i64, i64, i64) {
    parts
        .into_iter()
        .fold((0i64, 0i64, 0i64), |(months, days, nanos), part| {
            let n = part.count;
            match part.unit {
                // Calendar units: months bucket.
                DateTimeField::Year => (months + 12 * n, days, nanos),
                DateTimeField::Month => (months + n, days, nanos),
                // Day-based units: days bucket.
                DateTimeField::Week(_) => (months, days + 7 * n, nanos),
                DateTimeField::Day => (months, days + n, nanos),
                // Sub-day units: nanoseconds bucket.
                DateTimeField::Hour => (months, days, nanos + n * 3_600_000_000_000),
                DateTimeField::Minute => (months, days, nanos + n * 60_000_000_000),
                DateTimeField::Second => (months, days, nanos + n * 1_000_000_000),
                DateTimeField::Millisecond | DateTimeField::Milliseconds => {
                    (months, days, nanos + n * 1_000_000)
                }
                DateTimeField::Microsecond | DateTimeField::Microseconds => {
                    (months, days, nanos + n * 1_000)
                }
                DateTimeField::Nanosecond | DateTimeField::Nanoseconds => {
                    (months, days, nanos + n)
                }
                _ => (months, days, nanos),
            }
        })
}

match interval {
// If leading_field is specified, treat it as the old style single-unit interval
// e.g., INTERVAL '12' YEAR
sqlparser::ast::Interval {
value,
leading_field: Some(time_unit),
..
} => {
let expr = self.plan_expr(value)?;

let expr =
expr.as_literal()
.and_then(|lit| lit.as_str())
.ok_or_else(|| {
PlannerError::invalid_operation(
"Interval value must be a string",
)
})?;

let count = expr.parse::<i64>().map_err(|e| {
PlannerError::unsupported_sql(format!("Invalid interval count: {e}"))
})?;

let (months, days, nanoseconds) = match time_unit {
DateTimeField::Year => (12 * count, 0, 0),
DateTimeField::Month => (count, 0, 0),
DateTimeField::Week(_) => (0, 7 * count, 0),
DateTimeField::Day => (0, count, 0),
DateTimeField::Hour => (0, 0, count * 3_600_000_000_000),
DateTimeField::Minute => (0, 0, count * 60_000_000_000),
DateTimeField::Second => (0, 0, count * 1_000_000_000),
DateTimeField::Microsecond | DateTimeField::Microseconds => (0, 0, count * 1_000),
DateTimeField::Millisecond | DateTimeField::Milliseconds => (0, 0, count * 1_000_000),
DateTimeField::Nanosecond | DateTimeField::Nanoseconds => (0, 0, count),
_ => return Err(PlannerError::invalid_operation(format!(
"Invalid interval unit: {time_unit}. Expected one of: year, month, week, day, hour, minute, second, millisecond, microsecond, nanosecond"
))),
};

Ok(Arc::new(Expr::Literal(LiteralValue::Interval(
daft_core::datatypes::IntervalValue::new(
months as i32,
days as i32,
nanoseconds,
),
))))
}

// If no leading_field is specified, treat it as the new style multi-unit interval
// e.g., INTERVAL '12 years 3 months 7 days'
sqlparser::ast::Interval {
value,
leading_field: None,
..
} => {
let expr = self.plan_expr(value)?;

let expr =
expr.as_literal()
.and_then(|lit| lit.as_str())
.ok_or_else(|| {
PlannerError::invalid_operation(
"Interval value must be a string",
)
})?;

let parts = parse_interval_string(expr)
.map_err(|e| PlannerError::invalid_operation(e.to_string()))?;

let (months, days, nanoseconds) = interval_parts_to_values(parts);

Ok(Arc::new(Expr::Literal(LiteralValue::Interval(
daft_core::datatypes::IntervalValue::new(
months as i32,
days as i32,
nanoseconds,
),
))))
}
}
}
SQLExpr::MatchAgainst { .. } => unsupported_sql_err!("MATCH AGAINST"),
SQLExpr::Wildcard => unsupported_sql_err!("WILDCARD"),
SQLExpr::QualifiedWildcard(_) => unsupported_sql_err!("QUALIFIED WILDCARD"),
Expand Down
80 changes: 79 additions & 1 deletion tests/sql/test_exprs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import datetime

import pytest

import daft
from daft import col
from daft import col, interval
from daft.sql.sql import SQLCatalog


def test_nested():
Expand Down Expand Up @@ -135,3 +138,78 @@ def test_is_in_edge_cases():
# Test with mixed types in the IN list
with pytest.raises(Exception, match="All literals must have the same data type"):
daft.sql("SELECT * FROM df WHERE nums IN (1, '2', 3.0)").collect().to_pydict()


@pytest.mark.parametrize(
    "date_values, ts_values, expected_intervals",
    [
        (
            ["2022-01-01", "2020-02-29", "2029-05-15"],
            ["2022-01-01 10:00:00", "2020-02-29 23:59:59", "2029-05-15 12:34:56"],
            {
                "date_add_day": [
                    datetime.date(2022, 1, 2),
                    datetime.date(2020, 3, 1),
                    datetime.date(2029, 5, 16),
                ],
                "date_sub_month": [
                    datetime.date(2021, 12, 1),
                    datetime.date(2020, 1, 31),
                    datetime.date(2029, 4, 14),
                ],
                "ts_sub_year": [
                    datetime.datetime(2021, 1, 1, 10),
                    datetime.datetime(2019, 2, 28, 23, 59, 59),
                    datetime.datetime(2028, 5, 15, 12, 34, 56),
                ],
                "ts_add_hour": [
                    datetime.datetime(2022, 1, 1, 11, 0, 0),
                    datetime.datetime(2020, 3, 1, 0, 59, 59),
                    datetime.datetime(2029, 5, 15, 13, 34, 56),
                ],
                "ts_sub_minute": [
                    datetime.datetime(2022, 1, 1, 9, 57, 21),
                    datetime.datetime(2020, 2, 29, 23, 57, 20),
                    datetime.datetime(2029, 5, 15, 12, 32, 17),
                ],
            },
        ),
    ],
)
def test_interval_comparison(date_values, ts_values, expected_intervals):
    """SQL INTERVAL arithmetic must agree with the DataFrame `interval()` API.

    The same five date/timestamp offsets are computed once through DataFrame
    expressions and once through SQL (covering both the `INTERVAL '1' day` and
    the `INTERVAL '1 minutes 99 second'` spellings); both results must equal
    the parametrized expectations.
    """
    # Source table: a date column and a timestamp column parsed from strings.
    date_column = col("date").cast(daft.DataType.date())
    ts_column = col("ts").str.to_datetime("%Y-%m-%d %H:%M:%S")
    df = daft.from_pydict({"date": date_values, "ts": ts_values}).select(date_column, ts_column)
    catalog = SQLCatalog({"test": df})

    # Reference result via the DataFrame expression API.
    dataframe_exprs = [
        (col("date") + interval(days=1)).alias("date_add_day"),
        (col("date") - interval(months=1)).alias("date_sub_month"),
        (col("ts") - interval(years=1, days=0)).alias("ts_sub_year"),
        (col("ts") + interval(hours=1)).alias("ts_add_hour"),
        (col("ts") - interval(minutes=1, seconds=99)).alias("ts_sub_minute"),
    ]
    expected_df = df.select(*dataframe_exprs).collect().to_pydict()

    # Same computation expressed in SQL.
    query = """
SELECT
date + INTERVAL '1' day AS date_add_day,
date - INTERVAL '1 months' AS date_sub_month,
ts - INTERVAL '1 year 0 days' AS ts_sub_year,
ts + INTERVAL '1' hour AS ts_add_hour,
ts - INTERVAL '1 minutes 99 second' AS ts_sub_minute
FROM test
"""
    actual_sql = daft.sql(query, catalog=catalog).collect().to_pydict()

    assert expected_df == actual_sql == expected_intervals

0 comments on commit 701a011

Please sign in to comment.