Skip to content

Commit

Permalink
Add .dt.date() expression
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Sep 20, 2023
1 parent e1d0658 commit 593fc05
Show file tree
Hide file tree
Showing 11 changed files with 120 additions and 12 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,7 @@ class PyExpr:
def __setstate__(self, state: Any) -> None: ...
def __getstate__(self) -> Any: ...
def is_nan(self) -> PyExpr: ...
def dt_date(self) -> PyExpr: ...
def dt_day(self) -> PyExpr: ...
def dt_month(self) -> PyExpr: ...
def dt_year(self) -> PyExpr: ...
Expand Down Expand Up @@ -608,6 +609,7 @@ class PySeries:
def utf8_contains(self, pattern: PySeries) -> PySeries: ...
def utf8_length(self) -> PySeries: ...
def is_nan(self) -> PySeries: ...
def dt_date(self) -> PySeries: ...
def dt_day(self) -> PySeries: ...
def dt_month(self) -> PySeries: ...
def dt_year(self) -> PySeries: ...
Expand Down
11 changes: 11 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,17 @@ def is_nan(self) -> Expression:


class ExpressionDatetimeNamespace(ExpressionNamespace):
def date(self) -> Expression:
    """Extracts the date from a temporal column

    Example:
        >>> col("x").dt.date()

    Returns:
        Expression: a Date expression
    """
    date_pyexpr = self._expr.dt_date()
    return Expression._from_pyexpr(date_pyexpr)

def day(self) -> Expression:
"""Retrieves the day for a datetime column
Expand Down
3 changes: 3 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,9 @@ def length(self) -> Series:


class SeriesDateNamespace(SeriesNamespace):
def date(self) -> Series:
    """Runs the underlying series' ``dt_date`` operation and wraps the result in a Series."""
    date_pyseries = self._series.dt_date()
    return Series._from_pyseries(date_pyseries)

def day(self) -> Series:
    """Runs the underlying series' ``dt_day`` operation and wraps the result in a Series."""
    day_pyseries = self._series.dt_day()
    return Series._from_pyseries(day_pyseries)

Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ Example: ``e.dt.day()``
daft.expressions.expressions.ExpressionDatetimeNamespace.month
daft.expressions.expressions.ExpressionDatetimeNamespace.year
daft.expressions.expressions.ExpressionDatetimeNamespace.day_of_week
daft.expressions.expressions.ExpressionDatetimeNamespace.date

.. _api-expressions-urls:

Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,10 @@ impl PySeries {
Ok(self.series.is_nan()?.into())
}

pub fn dt_date(&self) -> PyResult<Self> {
    // Delegate to the wrapped series; `?` converts a failure into the Python error type.
    let dates = self.series.dt_date()?;
    Ok(dates.into())
}

pub fn dt_day(&self) -> PyResult<Self> {
    // Delegate to the wrapped series; `?` converts a failure into the Python error type.
    let days = self.series.dt_day()?;
    Ok(days.into())
}
Expand Down
39 changes: 29 additions & 10 deletions src/daft-core/src/series/ops/date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,20 @@ use crate::{
use common_error::{DaftError, DaftResult};

impl Series {
/// Produces the date for every value of a temporal series.
///
/// * `Date` input is returned as a clone of `self`.
/// * `Timestamp` input is downcast to a `TimestampArray` and converted via its `date()`.
///
/// # Errors
///
/// Returns `DaftError::ComputeError` for any other data type, and propagates
/// failures from the downcast or the timestamp-to-date conversion.
pub fn dt_date(&self) -> DaftResult<Self> {
    let dtype = self.data_type();

    // Already a date column: nothing to convert.
    if matches!(dtype, DataType::Date) {
        return Ok(self.clone());
    }

    if matches!(dtype, DataType::Timestamp(..)) {
        let timestamps = self.downcast::<TimestampArray>()?;
        let dates = timestamps.date()?;
        return Ok(dates.into_series());
    }

    Err(DaftError::ComputeError(format!(
        "Can only run date() operation on temporal types, got {}",
        dtype
    )))
}

pub fn dt_day(&self) -> DaftResult<Self> {
match self.data_type() {
DataType::Date => {
Expand All @@ -18,7 +32,7 @@ impl Series {
Ok(ts_array.date()?.day()?.into_series())
}
_ => Err(DaftError::ComputeError(format!(
"Can only run day() operation on DateType, got {}",
"Can only run day() operation on temporal types, got {}",
self.data_type()
))),
}
Expand All @@ -35,7 +49,7 @@ impl Series {
Ok(ts_array.date()?.month()?.into_series())
}
_ => Err(DaftError::ComputeError(format!(
"Can only run month() operation on DateType, got {}",
"Can only run month() operation on temporal types, got {}",
self.data_type()
))),
}
Expand All @@ -52,21 +66,26 @@ impl Series {
Ok(ts_array.date()?.year()?.into_series())
}
_ => Err(DaftError::ComputeError(format!(
"Can only run year() operation on DateType, got {}",
"Can only run year() operation on temporal types, got {}",
self.data_type()
))),
}
}

pub fn dt_day_of_week(&self) -> DaftResult<Self> {
if !matches!(self.data_type(), DataType::Date) {
return Err(DaftError::ComputeError(format!(
"Can only run day_of_week() operation on DateType, got {}",
match self.data_type() {
DataType::Date => {
let downcasted = self.downcast::<DateArray>()?;
Ok(downcasted.day_of_week()?.into_series())
}
DataType::Timestamp(..) => {
let ts_array = self.downcast::<TimestampArray>()?;
Ok(ts_array.date()?.day_of_week()?.into_series())
}
_ => Err(DaftError::ComputeError(format!(
"Can only run dt_day_of_week() operation on temporal types, got {}",
self.data_type()
)));
))),
}

let downcasted = self.downcast::<DateArray>()?;
Ok(downcasted.day_of_week()?.into_series())
}
}
47 changes: 47 additions & 0 deletions src/daft-dsl/src/functions/temporal/date.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use common_error::{DaftError, DaftResult};
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use crate::Expr;

use super::super::FunctionEvaluator;

/// Stateless evaluator for the temporal `"date"` function.
///
/// Carries no fields; its behavior lives entirely in its `FunctionEvaluator` impl.
pub(super) struct DateEvaluator {}

impl FunctionEvaluator for DateEvaluator {
    /// Name under which this function is reported.
    fn fn_name(&self) -> &'static str {
        "date"
    }

    /// Resolves the output field: same name as the single input, with dtype `Date`.
    ///
    /// # Errors
    ///
    /// * `DaftError::SchemaMismatch` when the number of inputs is not exactly one.
    /// * `DaftError::TypeError` when the input's dtype is not temporal.
    /// * Any error raised while resolving the input's own field.
    fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
        if let [input] = inputs {
            let field = input.to_field(schema)?;
            if field.dtype.is_temporal() {
                Ok(Field::new(field.name, DataType::Date))
            } else {
                Err(DaftError::TypeError(format!(
                    "Expected input to date to be temporal, got {}",
                    field.dtype
                )))
            }
        } else {
            Err(DaftError::SchemaMismatch(format!(
                "Expected 1 input arg, got {}",
                inputs.len()
            )))
        }
    }

    /// Runs `dt_date` on the single input series.
    ///
    /// # Errors
    ///
    /// `DaftError::ValueError` when the number of inputs is not exactly one;
    /// otherwise whatever `Series::dt_date` returns.
    fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
        if let [series] = inputs {
            series.dt_date()
        } else {
            Err(DaftError::ValueError(format!(
                "Expected 1 input arg, got {}",
                inputs.len()
            )))
        }
    }
}
13 changes: 12 additions & 1 deletion src/daft-dsl/src/functions/temporal/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod date;
mod day;
mod day_of_week;
mod month;
Expand All @@ -6,7 +7,8 @@ mod year;
use serde::{Deserialize, Serialize};

use crate::functions::temporal::{
day::DayEvaluator, day_of_week::DayOfWeekEvaluator, month::MonthEvaluator, year::YearEvaluator,
date::DateEvaluator, day::DayEvaluator, day_of_week::DayOfWeekEvaluator, month::MonthEvaluator,
year::YearEvaluator,
};
use crate::Expr;

Expand All @@ -18,6 +20,7 @@ pub enum TemporalExpr {
Month,
Year,
DayOfWeek,
Date,
}

impl TemporalExpr {
Expand All @@ -29,10 +32,18 @@ impl TemporalExpr {
Month => &MonthEvaluator {},
Year => &YearEvaluator {},
DayOfWeek => &DayOfWeekEvaluator {},
Date => &DateEvaluator {},
}
}
}

/// Builds a function expression that applies `TemporalExpr::Date` to a clone of `input`.
pub fn date(input: &Expr) -> Expr {
    let func = super::FunctionExpr::Temporal(TemporalExpr::Date);
    let inputs = vec![input.clone()];
    Expr::Function { func, inputs }
}

pub fn day(input: &Expr) -> Expr {
Expr::Function {
func: super::FunctionExpr::Temporal(TemporalExpr::Day),
Expand Down
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,11 @@ impl PyExpr {
Ok(is_nan(&self.expr).into())
}

pub fn dt_date(&self) -> PyResult<Self> {
    // Wrap the inner expression in the temporal `date` function expression.
    let date_expr = functions::temporal::date(&self.expr);
    Ok(date_expr.into())
}

pub fn dt_day(&self) -> PyResult<Self> {
use functions::temporal::day;
Ok(day(&self.expr).into())
Expand Down
6 changes: 5 additions & 1 deletion tests/expressions/typing/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pyarrow as pa
import pytest

from daft.datatype import DataType
from daft.datatype import DataType, TimeUnit
from daft.expressions import Expression, ExpressionsProjection
from daft.series import Series
from daft.table import Table
Expand All @@ -34,6 +34,10 @@
(DataType.null(), pa.array([None, None, None], type=pa.null())),
(DataType.binary(), pa.array([b"1", b"2", None], type=pa.binary())),
(DataType.date(), pa.array([datetime.date(2021, 1, 1), datetime.date(2021, 1, 2), None], type=pa.date32())),
(
DataType.timestamp(TimeUnit.ms()),
pa.array([datetime.datetime(2021, 1, 1), datetime.datetime(2021, 1, 2), None], type=pa.timestamp("ms")),
),
]

ALL_DATATYPES_BINARY_PAIRS = list(itertools.product(ALL_DTYPES, repeat=2))
Expand Down
1 change: 1 addition & 0 deletions tests/expressions/typing/test_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
pytest.param(lambda x: x.dt.month(), id="month"),
pytest.param(lambda x: x.dt.year(), id="year"),
pytest.param(lambda x: x.dt.day_of_week(), id="day_of_week"),
pytest.param(lambda x: x.dt.date(), id="date"),
],
)
def test_dt_extraction_ops(unary_data_fixture, op):
Expand Down

0 comments on commit 593fc05

Please sign in to comment.