From 593fc05a1da35bfd799728b8149e1a85889f94da Mon Sep 17 00:00:00 2001 From: Jay Chia Date: Tue, 19 Sep 2023 19:33:03 -0700 Subject: [PATCH] Add .dt.date() expression --- daft/daft.pyi | 2 + daft/expressions/expressions.py | 11 +++++ daft/series.py | 3 ++ docs/source/api_docs/expressions.rst | 1 + src/daft-core/src/python/series.rs | 4 ++ src/daft-core/src/series/ops/date.rs | 39 ++++++++++++----- src/daft-dsl/src/functions/temporal/date.rs | 47 +++++++++++++++++++++ src/daft-dsl/src/functions/temporal/mod.rs | 13 +++++- src/daft-dsl/src/python.rs | 5 +++ tests/expressions/typing/conftest.py | 6 ++- tests/expressions/typing/test_dt.py | 1 + 11 files changed, 120 insertions(+), 12 deletions(-) create mode 100644 src/daft-dsl/src/functions/temporal/date.rs diff --git a/daft/daft.pyi b/daft/daft.pyi index 1957c4ac3a..e96b2eaa6f 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -537,6 +537,7 @@ class PyExpr: def __setstate__(self, state: Any) -> None: ... def __getstate__(self) -> Any: ... def is_nan(self) -> PyExpr: ... + def dt_date(self) -> PyExpr: ... def dt_day(self) -> PyExpr: ... def dt_month(self) -> PyExpr: ... def dt_year(self) -> PyExpr: ... @@ -608,6 +609,7 @@ class PySeries: def utf8_contains(self, pattern: PySeries) -> PySeries: ... def utf8_length(self) -> PySeries: ... def is_nan(self) -> PySeries: ... + def dt_date(self) -> PySeries: ... def dt_day(self) -> PySeries: ... def dt_month(self) -> PySeries: ... def dt_year(self) -> PySeries: ... 
diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index be70827895..ef667b7db7 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -489,6 +489,17 @@ def is_nan(self) -> Expression: class ExpressionDatetimeNamespace(ExpressionNamespace): + def date(self) -> Expression: + """Retrieves the date for a datetime column + + Example: + >>> col("x").dt.date() + + Returns: + Expression: a Date expression + """ + return Expression._from_pyexpr(self._expr.dt_date()) + def day(self) -> Expression: """Retrieves the day for a datetime column diff --git a/daft/series.py b/daft/series.py index 84b8af6da8..d81196c391 100644 --- a/daft/series.py +++ b/daft/series.py @@ -546,6 +546,9 @@ def length(self) -> Series: class SeriesDateNamespace(SeriesNamespace): + def date(self) -> Series: + return Series._from_pyseries(self._series.dt_date()) + def day(self) -> Series: return Series._from_pyseries(self._series.dt_day()) diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index d00f0c4d13..4a8844386a 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -125,6 +125,7 @@ Example: ``e.dt.day()`` daft.expressions.expressions.ExpressionDatetimeNamespace.month daft.expressions.expressions.ExpressionDatetimeNamespace.year daft.expressions.expressions.ExpressionDatetimeNamespace.day_of_week + daft.expressions.expressions.ExpressionDatetimeNamespace.date .. 
_api-expressions-urls: diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index 3e08961db3..a388585940 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -255,6 +255,10 @@ impl PySeries { Ok(self.series.is_nan()?.into()) } + pub fn dt_date(&self) -> PyResult<Self> { + Ok(self.series.dt_date()?.into()) + } + pub fn dt_day(&self) -> PyResult<Self> { Ok(self.series.dt_day()?.into()) } diff --git a/src/daft-core/src/series/ops/date.rs b/src/daft-core/src/series/ops/date.rs index 8bbe756cfc..68f33bb450 100644 --- a/src/daft-core/src/series/ops/date.rs +++ b/src/daft-core/src/series/ops/date.rs @@ -7,6 +7,20 @@ use crate::{ use common_error::{DaftError, DaftResult}; impl Series { + pub fn dt_date(&self) -> DaftResult<Self> { + match self.data_type() { + DataType::Date => Ok(self.clone()), + DataType::Timestamp(..) => { + let ts_array = self.downcast::<TimestampArray>()?; + Ok(ts_array.date()?.into_series()) + } + _ => Err(DaftError::ComputeError(format!( + "Can only run date() operation on temporal types, got {}", + self.data_type() + ))), + } + } + pub fn dt_day(&self) -> DaftResult<Self> { match self.data_type() { DataType::Date => { @@ -18,7 +32,7 @@ impl Series { Ok(ts_array.date()?.day()?.into_series()) } _ => Err(DaftError::ComputeError(format!( - "Can only run day() operation on DateType, got {}", + "Can only run day() operation on temporal types, got {}", self.data_type() ))), } @@ -35,7 +49,7 @@ impl Series { Ok(ts_array.date()?.month()?.into_series()) } _ => Err(DaftError::ComputeError(format!( - "Can only run month() operation on DateType, got {}", + "Can only run month() operation on temporal types, got {}", self.data_type() ))), } @@ -52,21 +66,26 @@ impl Series { Ok(ts_array.date()?.year()?.into_series()) } _ => Err(DaftError::ComputeError(format!( - "Can only run year() operation on DateType, got {}", + "Can only run year() operation on temporal types, got {}", self.data_type() ))), } } pub fn dt_day_of_week(&self) -> 
DaftResult<Self> { - if !matches!(self.data_type(), DataType::Date) { - return Err(DaftError::ComputeError(format!( - "Can only run day_of_week() operation on DateType, got {}", + match self.data_type() { + DataType::Date => { + let downcasted = self.downcast::<DateArray>()?; + Ok(downcasted.day_of_week()?.into_series()) + } + DataType::Timestamp(..) => { + let ts_array = self.downcast::<TimestampArray>()?; + Ok(ts_array.date()?.day_of_week()?.into_series()) + } + _ => Err(DaftError::ComputeError(format!( + "Can only run day_of_week() operation on temporal types, got {}", self.data_type() - ))); + ))), } - - let downcasted = self.downcast::<DateArray>()?; - Ok(downcasted.day_of_week()?.into_series()) } } diff --git a/src/daft-dsl/src/functions/temporal/date.rs b/src/daft-dsl/src/functions/temporal/date.rs new file mode 100644 index 0000000000..fec944df3c --- /dev/null +++ b/src/daft-dsl/src/functions/temporal/date.rs @@ -0,0 +1,47 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + datatypes::{DataType, Field}, + schema::Schema, + series::Series, +}; + +use crate::Expr; + +use super::super::FunctionEvaluator; + +pub(super) struct DateEvaluator {} + +impl FunctionEvaluator for DateEvaluator { + fn fn_name(&self) -> &'static str { + "date" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> { + match inputs { + [input] => match input.to_field(schema) { + Ok(field) if field.dtype.is_temporal() => { + Ok(Field::new(field.name, DataType::Date)) + } + Ok(field) => Err(DaftError::TypeError(format!( + "Expected input to date to be temporal, got {}", + field.dtype + ))), + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> { + match inputs { + [input] => input.dt_date(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} diff --git 
a/src/daft-dsl/src/functions/temporal/mod.rs b/src/daft-dsl/src/functions/temporal/mod.rs index 0d5929c5e6..53a64d98e4 100644 --- a/src/daft-dsl/src/functions/temporal/mod.rs +++ b/src/daft-dsl/src/functions/temporal/mod.rs @@ -1,3 +1,4 @@ +mod date; mod day; mod day_of_week; mod month; @@ -6,7 +7,8 @@ mod year; use serde::{Deserialize, Serialize}; use crate::functions::temporal::{ - day::DayEvaluator, day_of_week::DayOfWeekEvaluator, month::MonthEvaluator, year::YearEvaluator, + date::DateEvaluator, day::DayEvaluator, day_of_week::DayOfWeekEvaluator, month::MonthEvaluator, + year::YearEvaluator, }; use crate::Expr; @@ -18,6 +20,7 @@ pub enum TemporalExpr { Month, Year, DayOfWeek, + Date, } impl TemporalExpr { @@ -29,10 +32,18 @@ impl TemporalExpr { Month => &MonthEvaluator {}, Year => &YearEvaluator {}, DayOfWeek => &DayOfWeekEvaluator {}, + Date => &DateEvaluator {}, } } } +pub fn date(input: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::Temporal(TemporalExpr::Date), + inputs: vec![input.clone()], + } +} + pub fn day(input: &Expr) -> Expr { Expr::Function { func: super::FunctionExpr::Temporal(TemporalExpr::Day), diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 4ed62ff51e..b89e71f570 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -273,6 +273,11 @@ impl PyExpr { Ok(is_nan(&self.expr).into()) } + pub fn dt_date(&self) -> PyResult<Self> { + use functions::temporal::date; + Ok(date(&self.expr).into()) + } + pub fn dt_day(&self) -> PyResult<Self> { use functions::temporal::day; Ok(day(&self.expr).into()) diff --git a/tests/expressions/typing/conftest.py b/tests/expressions/typing/conftest.py index a8e7a0f51c..8831fea614 100644 --- a/tests/expressions/typing/conftest.py +++ b/tests/expressions/typing/conftest.py @@ -13,7 +13,7 @@ import pyarrow as pa import pytest -from daft.datatype import DataType +from daft.datatype import DataType, TimeUnit from daft.expressions import Expression, ExpressionsProjection 
from daft.series import Series from daft.table import Table @@ -34,6 +34,10 @@ (DataType.null(), pa.array([None, None, None], type=pa.null())), (DataType.binary(), pa.array([b"1", b"2", None], type=pa.binary())), (DataType.date(), pa.array([datetime.date(2021, 1, 1), datetime.date(2021, 1, 2), None], type=pa.date32())), + ( + DataType.timestamp(TimeUnit.ms()), + pa.array([datetime.datetime(2021, 1, 1), datetime.datetime(2021, 1, 2), None], type=pa.timestamp("ms")), + ), ] ALL_DATATYPES_BINARY_PAIRS = list(itertools.product(ALL_DTYPES, repeat=2)) diff --git a/tests/expressions/typing/test_dt.py b/tests/expressions/typing/test_dt.py index 2ba2066bac..5e14a4586e 100644 --- a/tests/expressions/typing/test_dt.py +++ b/tests/expressions/typing/test_dt.py @@ -13,6 +13,7 @@ pytest.param(lambda x: x.dt.month(), id="month"), pytest.param(lambda x: x.dt.year(), id="year"), pytest.param(lambda x: x.dt.day_of_week(), id="day_of_week"), + pytest.param(lambda x: x.dt.date(), id="date"), ], ) def test_dt_extraction_ops(unary_data_fixture, op):