From 209a4e022c7a0fdb6e73faca762d6eacee7e488e Mon Sep 17 00:00:00 2001 From: GuyPozner Date: Fri, 31 May 2024 02:22:41 +0300 Subject: [PATCH] [FEAT] Expression between (#2301) Closes #1903 It is still WIP; I just want to get some feedback to know if I am on the right path. I have a feeling I might have overdone it by defining Between as its own trait. Edit: Can someone point me to a Rust test that covers the DaftError cases? --- daft/daft.pyi | 1 + daft/expressions/expressions.py | 16 +++ daft/utils.py | 4 + docs/source/api_docs/expressions.rst | 1 + src/daft-core/src/array/ops/between.rs | 73 +++++++++++ src/daft-core/src/array/ops/mod.rs | 6 + src/daft-core/src/series/ops/between.rs | 66 ++++++++++ src/daft-core/src/series/ops/mod.rs | 66 ++++++++++ src/daft-dsl/src/expr.rs | 32 +++++ src/daft-dsl/src/optimization.rs | 1 + src/daft-dsl/src/python.rs | 8 ++ src/daft-plan/src/builder.rs | 1 + src/daft-plan/src/logical_ops/project.rs | 20 +++ src/daft-plan/src/partitioning.rs | 6 + src/daft-table/src/lib.rs | 3 + tests/table/test_between.py | 154 +++++++++++++++++++++++ 16 files changed, 458 insertions(+) create mode 100644 src/daft-core/src/array/ops/between.rs create mode 100644 src/daft-core/src/series/ops/between.rs create mode 100644 tests/table/test_between.py diff --git a/daft/daft.pyi b/daft/daft.pyi index 73dcd3b20d..89ae597598 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -968,6 +968,7 @@ class PyExpr: def not_null(self) -> PyExpr: ... def fill_null(self, fill_value: PyExpr) -> PyExpr: ... def is_in(self, other: PyExpr) -> PyExpr: ... + def between(self, lower: PyExpr, upper: PyExpr) -> PyExpr: ... def name(self) -> str: ... def to_field(self, schema: PySchema) -> PyField: ... def to_sql(self) -> str: ...
diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 477a2d6059..de786be747 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -670,6 +670,22 @@ def is_in(self, other: Any) -> Expression: expr = self._expr.is_in(other._expr) return Expression._from_pyexpr(expr) + def between(self, lower: Any, upper: Any) -> Expression: + """Checks if values in the Expression are between lower and upper, inclusive. + + Example: + >>> # [1, 2, 3, 4] -> [True, True, False, False] + >>> col("x").between(1, 2) + + Returns: + Expression: Boolean Expression indicating whether values are between lower and upper, inclusive. + """ + lower = Expression._to_expression(lower) + upper = Expression._to_expression(upper) + + expr = self._expr.between(lower._expr, upper._expr) + return Expression._from_pyexpr(expr) + def name(self) -> builtins.str: return self._expr.name() diff --git a/daft/utils.py b/daft/utils.py index abcc2f911e..d61456409d 100644 --- a/daft/utils.py +++ b/daft/utils.py @@ -99,6 +99,10 @@ def python_list_membership_check( return [elem in right_pylist for elem in left_pylist] +def python_list_between_check(value_pylist: list, lower_pylist: list, upper_pylist: list) -> list: + return [value <= upper and value >= lower for value, lower, upper in zip(value_pylist, lower_pylist, upper_pylist)] + + def map_operator_arrow_semantics( operator: Callable[[Any, Any], Any], left_pylist: list, diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index 011bd56eb8..5b0c30ccfa 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -84,6 +84,7 @@ Logical Expression.__ne__ Expression.__gt__ Expression.__ge__ + Expression.between Expression.is_in .. 
_api=aggregation-expression: diff --git a/src/daft-core/src/array/ops/between.rs b/src/daft-core/src/array/ops/between.rs new file mode 100644 index 0000000000..e9903d8211 --- /dev/null +++ b/src/daft-core/src/array/ops/between.rs @@ -0,0 +1,73 @@ +use super::{DaftBetween, DaftCompare, DaftLogical}; +use crate::{ + array::DataArray, + datatypes::{BooleanArray, DaftNumericType}, +}; +use common_error::{DaftError, DaftResult}; + +impl DaftBetween<&DataArray, &DataArray> for DataArray +where + T: DaftNumericType, +{ + type Output = DaftResult; + + fn between(&self, lower: &DataArray, upper: &DataArray) -> Self::Output { + let are_two_equal_and_single_one = |v_size, l_size, u_size: usize| { + [v_size, l_size, u_size] + .iter() + .filter(|&&size| size != 1) + .collect::>() + .len() + == 1 + }; + match (self.len(), lower.len(), upper.len()) { + (v_size, l_size, u_size) if (v_size == l_size && v_size == u_size) || (l_size == 1 && u_size == 1) || (are_two_equal_and_single_one(v_size, l_size, u_size)) => { + let gte_res = self.gte(lower)?; + let lte_res = self.lte(upper)?; + gte_res.and(<e_res) + }, + (v_size, l_size, u_size) => Err(DaftError::ValueError(format!( + "trying to compare different length arrays: {}: {v_size} vs {}: {l_size} vs {}: {u_size}", + self.name(), + lower.name(), + upper.name() + ))), + } + } +} + +#[cfg(test)] +mod tests { + use crate::{array::ops::DaftBetween, datatypes::Int64Array}; + use common_error::DaftResult; + + #[test] + fn test_between_two_arrays_of_same_size() -> DaftResult<()> { + let value = Int64Array::arange("value", 1, 4, 1)?; + let lower = Int64Array::arange("lower", 0, 6, 2)?; + let upper = Int64Array::arange("upper", -2, 8, 4)?; + let result: Vec<_> = value.between(&lower, &upper)?.into_iter().collect(); + assert_eq!(result[..], [Some(false), Some(true), Some(false)]); + Ok(()) + } + + #[test] + fn test_between_array_with_multiple_items_and_array_with_single_item() -> DaftResult<()> { + let value = Int64Array::arange("value", 1, 4, 
1)?; + let lower = Int64Array::arange("lower", 1, 4, 1)?; + let upper = Int64Array::arange("upper", 1, 2, 1)?; + let result: Vec<_> = value.between(&lower, &upper)?.into_iter().collect(); + assert_eq!(result[..], [Some(true), Some(false), Some(false)]); + Ok(()) + } + + #[test] + fn test_between_two_arrays_with_single_item() -> DaftResult<()> { + let value = Int64Array::arange("value", 1, 4, 1)?; + let lower = Int64Array::arange("lower", 1, 2, 1)?; + let upper = Int64Array::arange("upper", 1, 2, 1)?; + let result: Vec<_> = value.between(&lower, &upper)?.into_iter().collect(); + assert_eq!(result[..], [Some(true), Some(false), Some(false)]); + Ok(()) + } +} diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index 051122d6b3..56637a25f5 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -5,6 +5,7 @@ mod arange; mod arithmetic; pub mod arrow2; pub mod as_arrow; +mod between; pub(crate) mod broadcast; pub(crate) mod cast; mod ceil; @@ -99,6 +100,11 @@ pub trait DaftIsIn { fn is_in(&self, rhs: Rhs) -> Self::Output; } +pub trait DaftBetween { + type Output; + fn between(&self, lower: Lower, upper: Upper) -> Self::Output; +} + pub trait DaftIsNull { type Output; fn is_null(&self) -> Self::Output; diff --git a/src/daft-core/src/series/ops/between.rs b/src/daft-core/src/series/ops/between.rs new file mode 100644 index 0000000000..9722a27f7d --- /dev/null +++ b/src/daft-core/src/series/ops/between.rs @@ -0,0 +1,66 @@ +use common_error::DaftResult; + +use crate::{ + array::ops::DaftBetween, datatypes::BooleanArray, with_match_numeric_daft_types, DataType, + IntoSeries, Series, +}; + +#[cfg(feature = "python")] +use crate::series::ops::py_between_op_utilfn; + +impl Series { + pub fn between(&self, lower: &Series, upper: &Series) -> DaftResult { + let (_output_type, _intermediate, lower_comp_type) = + self.data_type().comparison_op(lower.data_type())?; + let (_output_type, _intermediate, upper_comp_type) 
= + self.data_type().comparison_op(upper.data_type())?; + let (output_type, intermediate, comp_type) = + lower_comp_type.comparison_op(&upper_comp_type)?; + let (it_value, it_lower, it_upper) = if let Some(ref it) = intermediate { + (self.cast(it)?, lower.cast(it)?, upper.cast(it)?) + } else { + (self.clone(), lower.clone(), upper.clone()) + }; + if let DataType::Boolean = output_type { + match comp_type { + #[cfg(feature = "python")] + DataType::Python => Ok(py_between_op_utilfn(self, lower, upper)? + .downcast::<BooleanArray>()? + .clone() + .into_series()), + DataType::Null => Ok(Series::full_null( + self.name(), + &DataType::Boolean, + self.len(), + )), + _ => with_match_numeric_daft_types!(comp_type, |$T| { + let casted_value = it_value.cast(&comp_type)?; + let casted_lower = it_lower.cast(&comp_type)?; + let casted_upper = it_upper.cast(&comp_type)?; + let value = casted_value.downcast::<<$T as DaftDataType>::ArrayType>()?; + let lower = casted_lower.downcast::<<$T as DaftDataType>::ArrayType>()?; + let upper = casted_upper.downcast::<<$T as DaftDataType>::ArrayType>()?; + Ok(value.between(lower, upper)?.into_series()) + }), + } + } else { + unreachable!() + } + } +} + +#[cfg(test)] +mod tests { + use crate::{DataType, Series}; + + use common_error::DaftResult; + + #[test] + fn test_between_all_null() -> DaftResult<()> { + let value = Series::full_null("value", &DataType::Null, 2); + let lower = Series::full_null("lower", &DataType::Int64, 2); + let upper = Series::full_null("upper", &DataType::Int64, 2); + _ = value.between(&lower, &upper)?; + Ok(()) + } +} diff --git a/src/daft-core/src/series/ops/mod.rs b/src/daft-core/src/series/ops/mod.rs index d3d8284b00..96f213249b 100644 --- a/src/daft-core/src/series/ops/mod.rs +++ b/src/daft-core/src/series/ops/mod.rs @@ -6,6 +6,7 @@ use super::Series; pub mod abs; pub mod agg; pub mod arithmetic; +pub mod between; pub mod broadcast; pub mod cast; pub mod ceil; @@ -126,3 +127,68 @@ pub(super) fn py_membership_op_utilfn(lhs:
&Series, rhs: &Series) -> DaftResult< Ok(result_series) } + +#[cfg(feature = "python")] +pub(super) fn py_between_op_utilfn( + value: &Series, + lower: &Series, + upper: &Series, +) -> DaftResult { + use crate::python::PySeries; + use crate::DataType; + use pyo3::prelude::*; + + let value_casted = value.cast(&DataType::Python)?; + let lower_casted = lower.cast(&DataType::Python)?; + let upper_casted = upper.cast(&DataType::Python)?; + + let (value_casted, lower_casted, upper_casted) = + match (value_casted.len(), lower_casted.len(), upper_casted.len()) { + (a, b, c) if a == b && b == c => (value_casted, lower_casted, upper_casted), + (1, a, b) if a == b => (value_casted.broadcast(a)?, lower_casted, upper_casted), + (a, 1, b) if a == b => (value_casted, lower_casted.broadcast(a)?, upper_casted), + (a, b, 1) if a == b => (value_casted, lower_casted, upper_casted.broadcast(a)?), + (a, 1, 1) => ( + value_casted, + lower_casted.broadcast(a)?, + upper_casted.broadcast(a)?, + ), + (1, a, 1) => ( + value_casted.broadcast(a)?, + lower_casted, + upper_casted.broadcast(a)?, + ), + (1, 1, a) => ( + value_casted.broadcast(a)?, + lower_casted.broadcast(a)?, + upper_casted, + ), + (a, b, c) => { + panic!("Cannot apply operation on arrays of different lengths: {a} vs {b} vs {c}") + } + }; + + let value_pylist = PySeries::from(value_casted.clone()).to_pylist()?; + let lower_pylist = PySeries::from(lower_casted.clone()).to_pylist()?; + let upper_pylist = PySeries::from(upper_casted.clone()).to_pylist()?; + + let result_series: Series = Python::with_gil(|py| -> PyResult { + let result_pylist = PyModule::import(py, pyo3::intern!(py, "daft.utils"))? + .getattr(pyo3::intern!(py, "python_list_between_check"))? + .call1((value_pylist, lower_pylist, upper_pylist))?; + + PyModule::import(py, pyo3::intern!(py, "daft.series"))? + .getattr(pyo3::intern!(py, "Series"))? + .getattr(pyo3::intern!(py, "from_pylist"))? 
+ .call1(( + result_pylist, + value_casted.name(), + pyo3::intern!(py, "disallow"), + ))? + .getattr(pyo3::intern!(py, "_series"))? + .extract() + })? + .into(); + + Ok(result_series) +} diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index a1062b0919..f21062747d 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -49,6 +49,7 @@ pub enum Expr { NotNull(ExprRef), FillNull(ExprRef, ExprRef), IsIn(ExprRef, ExprRef), + Between(ExprRef, ExprRef, ExprRef), Literal(lit::LiteralValue), IfElse { if_true: ExprRef, @@ -464,6 +465,10 @@ impl Expr { Expr::IsIn(self, items).into() } + pub fn between(self: ExprRef, lower: ExprRef, upper: ExprRef) -> ExprRef { + Expr::Between(self, lower, upper).into() + } + pub fn eq(self: ExprRef, other: ExprRef) -> ExprRef { binary_op(Operator::Eq, self, other) } @@ -533,6 +538,12 @@ impl Expr { let items_id = items.semantic_id(schema); FieldID::new(format!("{child_id}.is_in({items_id})")) } + Between(expr, lower, upper) => { + let child_id = expr.semantic_id(schema); + let lower_id = lower.semantic_id(schema); + let upper_id = upper.semantic_id(schema); + FieldID::new(format!("{child_id}.between({lower_id},{upper_id})")) + } Function { func, inputs } => function_semantic_id(func, inputs, schema), BinaryOp { op, left, right } => { let left_id = left.semantic_id(schema); @@ -579,6 +590,7 @@ impl Expr { vec![left.clone(), right.clone()] } IsIn(expr, items) => vec![expr.clone(), items.clone()], + Between(expr, lower, upper) => vec![expr.clone(), lower.clone(), upper.clone()], IfElse { if_true, if_false, @@ -617,6 +629,11 @@ impl Expr { children.first().expect("Should have 1 child").clone(), children.get(1).expect("Should have 2 child").clone(), ), + Between(..) => Between( + children.first().expect("Should have 1 child").clone(), + children.get(1).expect("Should have 2 child").clone(), + children.get(2).expect("Should have 3 child").clone(), + ), FillNull(..) 
=> FillNull( children.first().expect("Should have 1 child").clone(), children.get(1).expect("Should have 2 child").clone(), @@ -671,6 +688,18 @@ impl Expr { left_field.dtype.membership_op(&right_field.dtype)?; Ok(Field::new(left_field.name.as_str(), result_type)) } + Between(value, lower, upper) => { + let value_field = value.to_field(schema)?; + let lower_field = lower.to_field(schema)?; + let upper_field = upper.to_field(schema)?; + let (lower_result_type, _intermediate, _comp_type) = + value_field.dtype.membership_op(&lower_field.dtype)?; + let (upper_result_type, _intermediate, _comp_type) = + value_field.dtype.membership_op(&upper_field.dtype)?; + let (result_type, _intermediate, _comp_type) = + lower_result_type.membership_op(&upper_result_type)?; + Ok(Field::new(value_field.name.as_str(), result_type)) + } Literal(value) => Ok(Field::new("literal", value.get_type())), Function { func, inputs } => func.to_field(inputs.as_slice(), schema, func), BinaryOp { op, left, right } => { @@ -763,6 +792,7 @@ impl Expr { NotNull(expr) => expr.name(), FillNull(expr, ..) => expr.name(), IsIn(expr, ..) => expr.name(), + Between(expr, ..) => expr.name(), Literal(..) => "literal", Function { func, inputs } => match func { FunctionExpr::Struct(StructExpr::Get(name)) => name, @@ -853,6 +883,7 @@ impl Expr { Expr::Agg(..) | Expr::Cast(..) | Expr::IsIn(..) + | Expr::Between(..) | Expr::Function { .. } | Expr::FillNull(..) 
=> Err(io::Error::new( io::ErrorKind::Other, @@ -894,6 +925,7 @@ impl Display for Expr { NotNull(expr) => write!(f, "not_null({expr})"), FillNull(expr, fill_value) => write!(f, "fill_null({expr}, {fill_value})"), IsIn(expr, items) => write!(f, "{expr} in {items}"), + Between(expr, lower, upper) => write!(f, "{expr} in [{lower},{upper}]"), Literal(val) => write!(f, "lit({val})"), Function { func, inputs } => function_display(f, func, inputs), IfElse { diff --git a/src/daft-dsl/src/optimization.rs b/src/daft-dsl/src/optimization.rs index e560c2843e..cbe76c167e 100644 --- a/src/daft-dsl/src/optimization.rs +++ b/src/daft-dsl/src/optimization.rs @@ -32,6 +32,7 @@ pub fn requires_computation(e: &Expr) -> bool { | Expr::NotNull(..) | Expr::FillNull(..) | Expr::IsIn { .. } + | Expr::Between { .. } | Expr::IfElse { .. } => true, } } diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 860423bd4f..8777a3ceb8 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -440,6 +440,14 @@ impl PyExpr { Ok(self.expr.clone().is_in(other.expr.clone()).into()) } + pub fn between(&self, lower: &Self, upper: &Self) -> PyResult { + Ok(self + .expr + .clone() + .between(lower.expr.clone(), upper.expr.clone()) + .into()) + } + pub fn name(&self) -> PyResult<&str> { Ok(self.expr.name()) } diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 0565b88446..1ac0156250 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -60,6 +60,7 @@ fn check_for_agg(expr: &ExprRef) -> bool { BinaryOp { left, right, .. } => check_for_agg(left) || check_for_agg(right), Function { inputs, .. 
} => inputs.iter().any(check_for_agg), IsIn(l, r) | FillNull(l, r) => check_for_agg(l) || check_for_agg(r), + Between(v, l, u) => check_for_agg(v) || check_for_agg(l) || check_for_agg(u), IfElse { if_true, if_false, diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 867ce24dbf..86905a41a9 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -275,6 +275,26 @@ fn replace_column_with_semantic_id( ) } } + Expr::Between(child, lower, upper) => { + let child = + replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema); + let lower = + replace_column_with_semantic_id(lower.clone(), subexprs_to_replace, schema); + let upper = + replace_column_with_semantic_id(upper.clone(), subexprs_to_replace, schema); + if child.is_no() && lower.is_no() && upper.is_no() { + Transformed::No(e) + } else { + Transformed::Yes( + Expr::Between( + child.unwrap().clone(), + lower.unwrap().clone(), + upper.unwrap().clone(), + ) + .into(), + ) + } + } Expr::BinaryOp { op, left, right } => { let left = replace_column_with_semantic_id(left.clone(), subexprs_to_replace, schema); diff --git a/src/daft-plan/src/partitioning.rs b/src/daft-plan/src/partitioning.rs index 05a7163dbd..41b9f458f6 100644 --- a/src/daft-plan/src/partitioning.rs +++ b/src/daft-plan/src/partitioning.rs @@ -278,6 +278,12 @@ fn translate_clustering_spec_expr( let newitems = translate_clustering_spec_expr(items, old_colname_to_new_colname)?; Ok(newchild.is_in(newitems)) } + Expr::Between(child, lower, upper) => { + let newchild = translate_clustering_spec_expr(child, old_colname_to_new_colname)?; + let newlower = translate_clustering_spec_expr(lower, old_colname_to_new_colname)?; + let newupper = translate_clustering_spec_expr(upper, old_colname_to_new_colname)?; + Ok(newchild.between(newlower, newupper)) + } Expr::IfElse { if_true, if_false, diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs 
index 97a50625ee..8e012a4be3 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -365,6 +365,9 @@ impl Table { IsIn(child, items) => self .eval_expression(child)? .is_in(&self.eval_expression(items)?), + Between(child, lower, upper) => self + .eval_expression(child)? + .between(&self.eval_expression(lower)?, &self.eval_expression(upper)?), BinaryOp { op, left, right } => { let lhs = self.eval_expression(left)?; let rhs = self.eval_expression(right)?; diff --git a/tests/table/test_between.py b/tests/table/test_between.py new file mode 100644 index 0000000000..8848c0693b --- /dev/null +++ b/tests/table/test_between.py @@ -0,0 +1,154 @@ +import datetime + +import pytest + +from daft import col +from daft.table import MicroPartition + + +@pytest.mark.parametrize( + "value,lower,upper,expected", + [ + pytest.param([1, 2, 3, 4], 1, 2, [True, True, False, False], id="IntIntInt"), + pytest.param([1, 2, 3, 4], 1.0, 2.0, [True, True, False, False], id="IntFloatFloat"), + pytest.param([1, 2, 3, 4], 1, 2.0, [True, True, False, False], id="IntIntFloat"), + pytest.param([1.0, 2.0, 3.0, 4.0], 1.0, 2.0, [True, True, False, False], id="FloatFloatFloat"), + pytest.param([1.0, 2.0, 3.0, 4.0], 1, 2, [True, True, False, False], id="FloatIntInt"), + pytest.param([1.0, 2.0, 3.0, 4.0], 1, 2.0, [True, True, False, False], id="FloatIntFloat"), + pytest.param([1.0, 2.0, 3.0, 4.0], None, 1, [None, None, None, None], id="FloatNullInt"), + pytest.param([1.0, 2.0, 3.0, 4.0], 1, None, [None, None, None, None], id="FloatIntNull"), + pytest.param([1.0, 2.0, 3.0, 4.0], None, None, [None, None, None, None], id="FloatNullNull"), + pytest.param([None, None, None, None], None, None, [None, None, None, None], id="NullNullNull"), + pytest.param([None, None, None, None], 1, 1, [None, None, None, None], id="NullIntInt"), + pytest.param( + [datetime.datetime(2023, 1, 1), datetime.datetime(2022, 1, 1)], + datetime.datetime(2022, 12, 30), + datetime.datetime(2023, 1, 2), + [True, False], 
+ id="Datetime", + ), + ], +) +def test_table_expr_between_scalars(value, lower, upper, expected) -> None: + daft_table = MicroPartition.from_pydict({"value": value}) + daft_table = daft_table.eval_expression_list([col("value").between(lower, upper)]) + pydict = daft_table.to_pydict() + assert pydict["value"] == expected + + +@pytest.mark.parametrize( + "value,lower,upper,expected", + [ + pytest.param([1, 2, 3, 4], [1, 1, 1, 1], [2, 2, 2, 2], [True, True, False, False], id="IntIntInt"), + pytest.param( + [1, 2, 3, 4], [1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [True, True, False, False], id="IntFloatFloat" + ), + pytest.param([1, 2, 3, 4], [1, 1, 1, 1], [2.0, 2.0, 2.0, 2.0], [True, True, False, False], id="IntIntFloat"), + pytest.param( + [None, None, None, None], + [2.0, 2.0, 2.0, 2.0], + [2.0, 2.0, 2.0, 2.0], + [None, None, None, None], + id="NullFloatFloat", + ), + pytest.param( + [None, None, None, None], + [None, None, None, None], + [2.0, 2.0, 2.0, 2.0], + [None, None, None, None], + id="NullNullFloat", + ), + pytest.param( + [None, None, None, None], + [2.0, 2.0, 2.0, 2.0], + [None, None, None, None], + [None, None, None, None], + id="NullFloatNull", + ), + pytest.param( + [None, None, None, None], + [None, None, None, None], + [None, None, None, None], + [None, None, None, None], + id="NullNullNull", + ), + pytest.param( + [1.0, 2.0, 3.0, 4.0], + [1.0, 1.0, 1.0, 1.0], + [2.0, 2.0, 2.0, 2.0], + [True, True, False, False], + id="FloatFloatFloat", + ), + pytest.param([1.0, 2.0, 3.0, 4.0], [1, 1, 1, 1], [2, 2, 2, 2], [True, True, False, False], id="FloatIntInt"), + pytest.param( + [1.0, 2.0, 3.0, 4.0], [1, 1, 1, 1], [2.0, 2.0, 2.0, 2.0], [True, True, False, False], id="FloatIntFloat" + ), + pytest.param( + [datetime.datetime(2023, 1, 1), datetime.datetime(2022, 1, 1)], + [datetime.datetime(2022, 12, 30), datetime.datetime(2022, 12, 30)], + [datetime.datetime(2023, 1, 2), datetime.datetime(2023, 1, 2)], + [True, False], + id="Datetime", + ), + ], +) +def 
test_between_columns(value, lower, upper, expected) -> None: + table = {"value": value, "lower": lower, "upper": upper} + daft_table = MicroPartition.from_pydict(table) + daft_table = daft_table.eval_expression_list([col("value").between(col("lower"), col("upper"))]) + pydict = daft_table.to_pydict() + assert pydict["value"] == expected + + +@pytest.mark.parametrize( + "value,lower,upper", + [ + pytest.param(["str1", "str2"], 1, 2, id="StrIntInt"), + pytest.param([1, 2], "str", 1, id="IntStrInt"), + ], +) +def test_between_between_different_types(value, lower, upper) -> None: + daft_table = MicroPartition.from_pydict({"a": value}) + with pytest.raises(ValueError): + daft_table = daft_table.eval_expression_list([col("a").between(lower, upper)]) + + +def test_between_bad_input() -> None: + daft_table = MicroPartition.from_pydict({"a": [1, 2, 3]}) + with pytest.raises(TypeError): + daft_table = daft_table.eval_expression_list([col("a").between([1, 2, 3], 1)]) + + +@pytest.mark.parametrize( + "value,lower,upper,expected", + [ + pytest.param([1, 2, 3, 4], [1, 2, 3, 4], 2, [True, True, False, False], id="IntIntInt"), + pytest.param([1, 2, 3, 4], [1.0, 2.0, 3.0, 4.0], 2.0, [True, True, False, False], id="IntFloatFloat"), + pytest.param([1, 2, 3, 4], [1, 2, 3, 4], 2.0, [True, True, False, False], id="IntIntFloat"), + pytest.param([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0], 2.0, [True, True, False, False], id="FloatFloatFloat"), + pytest.param([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0], 2, [True, True, False, False], id="FloatIntInt"), + pytest.param([1.0, 2.0, 3.0, 4.0], [1, 2, 3, 4], 2.0, [True, True, False, False], id="FloatIntFloat"), + pytest.param([1.0, 2.0, 3.0, 4.0], [None, None, None, None], 1, [None, None, None, None], id="FloatNullInt"), + pytest.param([1.0, 2.0, 3.0, 4.0], [1, 2, 3, 4], None, [None, None, None, None], id="FloatIntNull"), + pytest.param( + [1.0, 2.0, 3.0, 4.0], [None, None, None, None], None, [None, None, None, None], id="FloatNullNull" + ), + 
pytest.param( + [None, None, None, None], [None, None, None, None], None, [None, None, None, None], id="NullNullNull" + ), + pytest.param([None, None, None, None], [1, 2, 3, 4], 1, [None, None, None, None], id="NullIntInt"), + pytest.param( + [datetime.datetime(2023, 1, 1), datetime.datetime(2022, 1, 1)], + [datetime.datetime(2022, 12, 30), datetime.datetime(2022, 12, 30)], + datetime.datetime(2023, 1, 2), + [True, False], + id="Datetime", + ), + ], +) +def test_table_expr_between_col_and_scalar(value, lower, upper, expected) -> None: + table = {"value": value, "lower": lower} + daft_table = MicroPartition.from_pydict(table) + daft_table = daft_table.eval_expression_list([col("value").between(col("lower"), upper)]) + pydict = daft_table.to_pydict() + assert pydict["value"] == expected