Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Disable Numeric and String comparison #2019

Merged
merged 4 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/daft-core/src/datatypes/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ impl DataType {
use DataType::*;
match (self, other) {
(s, o) if s == o => Ok((Boolean, None, s.to_physical())),
(Utf8, o) | (o, Utf8) if o.is_numeric() => Err(DaftError::TypeError(format!(
"Cannot perform comparison on Utf8 and numeric type.\ntypes: {}, {}",
self, other
))),
(s, o) if s.is_physical() && o.is_physical() => {
Ok((Boolean, None, try_physical_supertype(s, o)?))
}
Expand Down
5 changes: 1 addition & 4 deletions src/daft-core/src/series/ops/is_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ impl Series {
}

let (output_type, intermediate, comp_type) =
match self.data_type().membership_op(items.data_type()) {
Ok(types) => types,
Err(_) => return default(self.name(), self.len()),
};
self.data_type().membership_op(items.data_type())?;

let (lhs, rhs) = if let Some(ref it) = intermediate {
(self.cast(it)?, items.cast(it)?)
Expand Down
8 changes: 7 additions & 1 deletion src/daft-dsl/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,13 @@ impl Expr {
}
IsNull(expr) => Ok(Field::new(expr.name()?, DataType::Boolean)),
NotNull(expr) => Ok(Field::new(expr.name()?, DataType::Boolean)),
IsIn(expr, ..) => Ok(Field::new(expr.name()?, DataType::Boolean)),
IsIn(left, right) => {
let left_field = left.to_field(schema)?;
let right_field = right.to_field(schema)?;
let (result_type, _intermediate, _comp_type) =
left_field.dtype.membership_op(&right_field.dtype)?;
Ok(Field::new(left_field.name.as_str(), result_type))
}
Literal(value) => Ok(Field::new("literal", value.get_type())),
Function { func, inputs } => func.to_field(inputs.as_slice(), schema, self),
BinaryOp { op, left, right } => {
Expand Down
8 changes: 7 additions & 1 deletion tests/expressions/typing/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@
assert_typing_resolve_vs_runtime_behavior,
has_supertype,
is_comparable,
is_numeric,
)


def comparable_type_validation(lhs: DataType, rhs: DataType) -> bool:
return is_comparable(lhs) and is_comparable(rhs) and has_supertype(lhs, rhs)
return (
is_comparable(lhs)
and is_comparable(rhs)
and has_supertype(lhs, rhs)
and not ((is_numeric(lhs) and rhs == DataType.string()) or (is_numeric(rhs) and lhs == DataType.string()))
)


@pytest.mark.parametrize("op", [ops.eq, ops.ne, ops.lt, ops.le, ops.gt, ops.ge])
Expand Down
26 changes: 21 additions & 5 deletions tests/series/test_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
arrow_binary_types = [pa.binary(), pa.large_binary()]


@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_string_types, repeat=2))
VALID_INT_STRING_COMPARISONS = list(itertools.product(arrow_int_types, repeat=2)) + list(
itertools.product(arrow_string_types, repeat=2)
)


@pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS)
def test_comparisons_int_and_str(l_dtype, r_dtype) -> None:
l_arrow = pa.array([1, 2, 3, None, 5, None])
r_arrow = pa.array([1, 3, 1, 5, None, None])
Expand All @@ -43,7 +48,7 @@ def test_comparisons_int_and_str(l_dtype, r_dtype) -> None:
assert gt == [False, False, True, None, None, None]


@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_string_types, repeat=2))
@pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS)
def test_comparisons_int_and_str_left_scalar(l_dtype, r_dtype) -> None:
l_arrow = pa.array([2])
r_arrow = pa.array([1, 2, 3, None])
Expand Down Expand Up @@ -71,7 +76,7 @@ def test_comparisons_int_and_str_left_scalar(l_dtype, r_dtype) -> None:
assert gt == [True, False, False, None]


@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_string_types, repeat=2))
@pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS)
def test_comparisons_int_and_str_right_scalar(l_dtype, r_dtype) -> None:
l_arrow = pa.array([1, 2, 3, None, 5, None])
r_arrow = pa.array([2])
Expand All @@ -98,7 +103,7 @@ def test_comparisons_int_and_str_right_scalar(l_dtype, r_dtype) -> None:
assert gt == [False, False, True, None, True, None]


@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_string_types, repeat=2))
@pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS)
def test_comparisons_int_and_str_right_null_scalar(l_dtype, r_dtype) -> None:
l_arrow = pa.array([1, 2, 3, None, 5, None])
r_arrow = pa.array([None], type=r_dtype)
Expand Down Expand Up @@ -578,7 +583,7 @@ def test_comparisons_binary_right_scalar(l_dtype, r_dtype) -> None:
assert gt == [False, False, True, None, True, None]


@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_string_types, repeat=2))
@pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS)
def test_comparisons_int_and_str_right_null_scalar(l_dtype, r_dtype) -> None:
l_arrow = pa.array([1, 2, 3, None, 5, None])
r_arrow = pa.array([None], type=r_dtype)
Expand Down Expand Up @@ -744,3 +749,14 @@ def test_compare_timestamps_diff_tz(tu1, tu2):
tz1 = Series.from_pylist([utc]).cast(DataType.timestamp(tu1, "UTC"))
tz2 = Series.from_pylist([eastern]).cast(DataType.timestamp(tu1, "US/Eastern"))
assert (tz1 == tz2).to_pylist() == [True]


@pytest.mark.parametrize("op", [operator.eq, operator.ne, operator.lt, operator.gt, operator.le, operator.ge])
def test_numeric_and_string_compare_raises_error(op):
left = Series.from_pylist([1, 2, 3])
right = Series.from_pylist(["1", "2", "3"])
with pytest.raises(ValueError, match="Cannot perform comparison on types:"):
op(left, right)

with pytest.raises(ValueError, match="Cannot perform comparison on types:"):
op(right, left)
18 changes: 11 additions & 7 deletions tests/table/test_is_in.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,16 @@ def test_table_expr_is_in_same_types(input, items, expected) -> None:
"input,items,expected",
[
# Int
pytest.param([-1, 2, 3, 4], ["-1", "2"], [True, True, False, False], id="IntWithString"),
pytest.param([-1, 2, 3, 4], ["-1", "2"], None, id="IntWithString"),
pytest.param([1, 2, 3, 4], [1.0, 2.0], [True, True, False, False], id="IntWithFloat"),
pytest.param([0, 1, 2, 3], [True], [False, True, False, False], id="IntWithBool"),
# Float
pytest.param([-1.0, 2.0, 3.0, 4.0], ["-1.0", "2.0"], [True, True, False, False], id="FloatWithString"),
pytest.param([-1.0, 2.0, 3.0, 4.0], ["-1.0", "2.0"], None, id="FloatWithString"),
pytest.param([1.0, 2.0, 3.0, 4.0], [1, 2], [True, True, False, False], id="FloatWithInt"),
pytest.param([0.0, 1.0, 2.0, 3.0], [True], [False, True, False, False], id="FloatWithBool"),
# String
pytest.param(["1", "2", "3", "4"], [1, 2], [True, True, False, False], id="StringWithInt"),
pytest.param(["1.0", "2.0", "3.0", "4.0"], [1.0, 2.0], [True, True, False, False], id="StringWithFloat"),
pytest.param(["1", "2", "3", "4"], [1, 2], None, id="StringWithInt"),
pytest.param(["1.0", "2.0", "3.0", "4.0"], [1.0, 2.0], None, id="StringWithFloat"),
# Bool
pytest.param([True, False, None], [1, 0], [True, True, None], id="BoolWithInt"),
pytest.param([True, False, None], [1.0], [True, False, None], id="BoolWithFloat"),
Expand All @@ -104,10 +104,14 @@ def test_table_expr_is_in_same_types(input, items, expected) -> None:
)
def test_table_expr_is_in_different_types_castable(input, items, expected) -> None:
daft_table = MicroPartition.from_pydict({"input": input})
daft_table = daft_table.eval_expression_list([col("input").is_in(items)])
pydict = daft_table.to_pydict()

assert pydict["input"] == expected
if expected is None:
with pytest.raises(ValueError, match="Cannot perform comparison on types:"):
daft_table = daft_table.eval_expression_list([col("input").is_in(items)])
else:
daft_table = daft_table.eval_expression_list([col("input").is_in(items)])
pydict = daft_table.to_pydict()
assert pydict["input"] == expected


@pytest.mark.parametrize(
Expand Down
Loading