Skip to content

Commit

Permalink
pr feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 committed Sep 20, 2024
1 parent 5449946 commit 938da5e
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 81 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/daft-functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ daft-dsl = {path = "../daft-dsl", default-features = false}
daft-image = {path = "../daft-image", default-features = false}
daft-io = {path = "../daft-io", default-features = false}
futures = {workspace = true}
paste = "1.0.15"
pyo3 = {workspace = true, optional = true}
tiktoken-rs = {workspace = true}
tokio = {workspace = true}
Expand Down
179 changes: 100 additions & 79 deletions src/daft-functions/src/temporal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,101 +16,92 @@ use serde::{Deserialize, Serialize};

#[cfg(feature = "python")]
pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
parent.add_function(wrap_pyfunction_bound!(py_dt_date::dt_date, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_day::dt_day, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(
py_dt_day_of_week::dt_day_of_week,
parent
)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_hour::dt_hour, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_minute::dt_minute, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_month::dt_month, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_second::dt_second, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_time::dt_time, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_year::dt_year, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_date, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_day, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_day_of_week, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_hour, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_minute, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_month, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_second, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_time, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(py_dt_year, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(truncate::py_dt_truncate, parent)?)?;
Ok(())
}

macro_rules! impl_temporal {
($name_str:expr, $name:ident, $dt:ident, $py_name:ident, $dtype:ident) => {
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct $name {}

#[typetag::serde]
impl ScalarUDF for $name {
fn as_any(&self) -> &dyn std::any::Any {
self
}
// pyo3 macro can't handle any expressions other than a 'literal', so we have to redundantly pass it in via $py_name
($name:ident, $dt:ident, $py_name:literal, $dtype:ident) => {
paste::paste! {
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct $name;

#[typetag::serde]
impl ScalarUDF for $name {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &'static str {
$name_str
}
fn name(&self) -> &'static str {
stringify!([ < $name:snake:lower > ])
}

fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult<Field> {
match inputs {
[input] => match input.to_field(schema) {
Ok(field) if field.dtype.is_temporal() => {
Ok(Field::new(field.name, DataType::$dtype))
}
Ok(field) => Err(DaftError::TypeError(format!(
"Expected input to {} to be temporal, got {}",
self.name(),
field.dtype
fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult<Field> {
match inputs {
[input] => match input.to_field(schema) {
Ok(field) if field.dtype.is_temporal() => {
Ok(Field::new(field.name, DataType::$dtype))
}
Ok(field) => Err(DaftError::TypeError(format!(
"Expected input to {} to be temporal, got {}",
self.name(),
field.dtype
))),
Err(e) => Err(e),
},
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 1 input arg, got {}",
inputs.len()
))),
Err(e) => Err(e),
},
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 1 input arg, got {}",
inputs.len()
))),
}
}
}

fn evaluate(&self, inputs: &[Series]) -> DaftResult<Series> {
match inputs {
[input] => input.$dt(),
_ => Err(DaftError::ValueError(format!(
"Expected 1 input arg, got {}",
inputs.len()
))),
fn evaluate(&self, inputs: &[Series]) -> DaftResult<Series> {
match inputs {
[input] => input.$dt(),
_ => Err(DaftError::ValueError(format!(
"Expected 1 input arg, got {}",
inputs.len()
))),
}
}
}
}

pub fn $dt(input: ExprRef) -> ExprRef {
ScalarFunction::new($name {}, vec![input]).into()
}
pub fn $dt(input: ExprRef) -> ExprRef {
ScalarFunction::new($name {}, vec![input]).into()
}

// We gotta throw these in a module to avoid name conflicts
#[cfg(feature = "python")]
mod $py_name {
use super::*;
#[pyfunction]
pub fn $dt(expr: PyExpr) -> PyResult<PyExpr> {
Ok(super::$dt(expr.into()).into())
#[pyo3(name = $py_name)]
#[cfg(feature = "python")]
pub fn [<py_ $dt>](expr: PyExpr) -> PyResult<PyExpr> {
Ok($dt(expr.into()).into())
}
}
};
}

impl_temporal!("date", Date, dt_date, py_dt_date, Date);
impl_temporal!("day", Day, dt_day, py_dt_day, UInt32);
impl_temporal!(
"day_of_week",
DayOfWeek,
dt_day_of_week,
py_dt_day_of_week,
UInt32
);
impl_temporal!("hour", Hour, dt_hour, py_dt_hour, UInt32);
impl_temporal!("minute", Minute, dt_minute, py_dt_minute, UInt32);
impl_temporal!("month", Month, dt_month, py_dt_month, UInt32);
impl_temporal!("second", Second, dt_second, py_dt_second, UInt32);
impl_temporal!("year", Year, dt_year, py_dt_year, Int32);
impl_temporal!(Date, dt_date, "dt_date", Date);
impl_temporal!(Day, dt_day, "dt_day", UInt32);
impl_temporal!(Hour, dt_hour, "dt_hour", UInt32);
impl_temporal!(DayOfWeek, dt_day_of_week, "dt_day_of_week", UInt32);
impl_temporal!(Minute, dt_minute, "dt_minute", UInt32);
impl_temporal!(Month, dt_month, "dt_month", UInt32);
impl_temporal!(Second, dt_second, "dt_second", UInt32);
impl_temporal!(Year, dt_year, "dt_year", Int32);

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Time {}
pub struct Time;

#[typetag::serde]
impl ScalarUDF for Time {
Expand Down Expand Up @@ -163,12 +154,42 @@ pub fn dt_time(input: ExprRef) -> ExprRef {
ScalarFunction::new(Time {}, vec![input]).into()
}

// We gotta throw these in a module to avoid name conflicts
#[cfg(feature = "python")]
mod py_dt_time {
use super::*;
#[pyfunction]
pub fn dt_time(expr: PyExpr) -> PyResult<PyExpr> {
Ok(super::dt_time(expr.into()).into())
// Python-facing wrapper: exposed to Python under the name `dt_time` (see the
// `#[pyo3(name = ...)]` attribute) and forwards to the Rust `dt_time` builder.
#[pyfunction]
#[pyo3(name = "dt_time")]
pub fn py_dt_time(expr: PyExpr) -> PyResult<PyExpr> {
    // Convert the Python expression handle, build the scalar function
    // expression, then convert the result back for Python.
    let time_expr = dt_time(expr.into());
    Ok(time_expr.into())
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use super::truncate::Truncate;

    /// Each temporal UDF must report its expected short name through
    /// `ScalarUDF::name`: macro-generated structs, the hand-written `Time`,
    /// and `Truncate` alike.
    #[test]
    fn test_fn_name() {
        use super::*;

        // `Truncate` is the only UDF here that carries state; the interval is
        // irrelevant to `name()`, so an empty one is enough.
        let truncate = Truncate {
            interval: String::new(),
        };

        let expectations: Vec<(Arc<dyn ScalarUDF>, &str)> = vec![
            (Arc::new(Date), "date"),
            (Arc::new(Day), "day"),
            (Arc::new(Hour), "hour"),
            (Arc::new(DayOfWeek), "day_of_week"),
            (Arc::new(Minute), "minute"),
            (Arc::new(Month), "month"),
            (Arc::new(Second), "second"),
            (Arc::new(Time), "time"),
            (Arc::new(Year), "year"),
            (Arc::new(truncate), "truncate"),
        ];

        for (udf, expected) in expectations.iter() {
            assert_eq!(udf.name(), *expected);
        }
    }
}
4 changes: 2 additions & 2 deletions src/daft-functions/src/temporal/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Truncate {
interval: String,
pub(super) interval: String,
}

#[typetag::serde]
Expand All @@ -18,7 +18,7 @@ impl ScalarUDF for Truncate {
}

fn name(&self) -> &'static str {
stringify!($fn_name)
"truncate"
}

fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult<Field> {
Expand Down

0 comments on commit 938da5e

Please sign in to comment.