Skip to content

Commit

Permalink
[BUG] Fix if_else series naming from predicate broadcast (#2051)
Browse files Browse the repository at this point in the history
When predicate is broadcasted in the if_else kernel and the predicate is
false, the if_false series is cloned and the name is unchanged, which
causes a
```
ValueError: DaftError::ComputeError Mismatch of expected expression name and name from computed series
```
at expression evaluation time because the expected field name is the
[if_true series name.
](https://github.com/Eventual-Inc/Daft/blob/main/src/daft-dsl/src/expr.rs#L530)

Example code to reproduce:
```
df = daft.from_pydict({"predicate": [False], "if_true": ["true"], "if_false": ["false"]})
df = df.select((df["predicate"] == True).if_else(df["if_true"], df["if_false"]))
df.show()
```
  • Loading branch information
colin-ho authored Mar 29, 2024
1 parent d2e7277 commit de48128
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/daft-core/src/array/ops/if_else.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ fn generic_if_else<T: GrowableArray + FullNull + Clone + IntoSeries>(
None => Ok(T::full_null(name, dtype, lhs_len).into_series()),
Some(predicate_scalar_value) => {
if predicate_scalar_value {
Ok(lhs.clone().into_series())
Ok(lhs.clone().into_series().rename(name))
} else {
Ok(rhs.clone().into_series())
Ok(rhs.clone().into_series().rename(name))
}
}
};
Expand Down
25 changes: 25 additions & 0 deletions tests/table/test_if_else.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

import pytest

from daft.expressions import col
from daft.table.micropartition import MicroPartition


@pytest.mark.parametrize(
["predicate", "if_true", "if_false", "expected"],
[
# Single row
([True], [1], [2], [1]),
([False], [1], [2], [2]),
# Multiple rows
([True, False, True], [1, 2, 3], [4, 5, 6], [1, 5, 3]),
([False, False, False], [1, 2, 3], [4, 5, 6], [4, 5, 6]),
],
)
def test_table_expr_if_else(predicate, if_true, if_false, expected) -> None:
daft_table = MicroPartition.from_pydict({"predicate": predicate, "if_true": if_true, "if_false": if_false})
daft_table = daft_table.eval_expression_list([col("predicate").if_else(col("if_true"), col("if_false"))])
pydict = daft_table.to_pydict()

assert pydict["if_true"] == expected

0 comments on commit de48128

Please sign in to comment.