Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
advancedxy committed Dec 20, 2024
1 parent cae1b42 commit cf2e504
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/daft-core/src/array/ops/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ fn list_sort_helper_fixed_size(

fn general_list_fill_helper(element: &Series, num_array: &Int64Array) -> DaftResult<Vec<Series>> {
let num_iter = create_iter(num_array, element.len());
let mut result = vec![];
let mut result = Vec::with_capacity(element.len());
let element_data = element.as_physical()?;
for (row_index, num) in num_iter.enumerate() {
let list_arr = if element.is_valid(row_index) {
Expand Down
133 changes: 132 additions & 1 deletion src/daft-functions/src/list/list_fill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl ScalarUDF for ListFill {
}

Check warning on line 19 in src/daft-functions/src/list/list_fill.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-functions/src/list/list_fill.rs#L17-L19

Added lines #L17 - L19 were not covered by tests

fn name(&self) -> &'static str {
"fill"
"list_fill"
}

fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult<Field> {
Expand Down Expand Up @@ -61,3 +61,134 @@ impl ScalarUDF for ListFill {
/// Builds the expression that applies the [`ListFill`] scalar UDF.
///
/// The UDF receives its inputs in the order `[n, elem]`: `n` is the count
/// expression and `elem` is the expression whose value is repeated.
pub fn list_fill(n: ExprRef, elem: ExprRef) -> ExprRef {
    let args = vec![n, elem];
    let call = ScalarFunction::new(ListFill {}, args);
    call.into()
}

#[cfg(test)]
mod tests {
use arrow2::offset::OffsetsBuffer;
use daft_core::{
array::ListArray,
datatypes::{Int8Array, Utf8Array},
series::IntoSeries,
};
use daft_dsl::{lit, null_lit};

use super::*;

// Checks `ListFill::to_field`: argument-count validation, the numeric-type
// requirement on the first (count) input, and the resulting list field.
#[test]
fn test_to_field() {
// Literal expressions aliased to the two schema column names; the first
// input plays the role of the count, the second the element to repeat.
let col0_null = null_lit().alias("c0");
let col0_num = lit(10).alias("c0");
let col1_null = null_lit().alias("c1");
let col1_str = lit("abc").alias("c1");

let schema = Schema::new(vec![
Field::new("c0", DataType::Int32),
Field::new("c1", DataType::Utf8),
])
.unwrap();

let fill = ListFill {};
// Case 1: a single input must be rejected with a SchemaMismatch error.
let DaftError::SchemaMismatch(e) =
fill.to_field(&[col0_null.clone()], &schema).unwrap_err()
else {
panic!("Expected SchemaMismatch error");

Check warning on line 94 in src/daft-functions/src/list/list_fill.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-functions/src/list/list_fill.rs#L94

Added line #L94 was not covered by tests
};
assert_eq!(e, "Expected 2 input args, got 1");
// Case 2: a null (non-numeric) count input must be rejected with a TypeError.
let DaftError::TypeError(e) = fill
.to_field(&[col0_null.clone(), col1_str.clone()], &schema)
.unwrap_err()
else {
panic!("Expected TypeError error");

Check warning on line 101 in src/daft-functions/src/list/list_fill.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-functions/src/list/list_fill.rs#L101

Added line #L101 was not covered by tests
};
assert_eq!(
e,
"Expected num field to be of numeric type, received: Null"
);

// Case 3: a numeric count with a null element yields a field named after the
// element expression ("c1") with type List(Null).
let list_of_null = fill
.to_field(&[col0_num.clone(), col1_null.clone()], &schema)
.unwrap();
let expected = Field::new("c1", DataType::List(Box::new(DataType::Null)));
assert_eq!(list_of_null, expected);
// Case 4: a numeric count with a string element yields List(Utf8), again
// named "c1".
let list_of_str = fill
.to_field(&[col0_num.clone(), col1_str.clone()], &schema)
.unwrap();
let expected = Field::new("c1", DataType::List(Box::new(DataType::Utf8)));
assert_eq!(list_of_str, expected);
}

/// `evaluate` must reject a call that does not supply exactly two inputs.
///
/// Only the count series is passed, and the call is expected to return an
/// error whose rendered message names the argument-count mismatch.
#[test]
fn test_evaluate_with_invalid_input() {
    let fill = ListFill {};
    let num = Int8Array::from_iter(
        Field::new("s0", DataType::Int8),
        vec![Some(1), Some(0), Some(10)].into_iter(),
    )
    .into_series();

    // Deliberately a one-element slice: the element series is omitted so the
    // argument-count check is what fails.
    let error = fill.evaluate(&[num]).unwrap_err();
    assert_eq!(
        error.to_string(),
        "DaftError::ValueError Expected 2 input args, got 1"
    );
}

/// Evaluating with a five-row count series against a three-row element
/// series must not succeed: the unwrapped call is expected to panic.
#[test]
fn test_evaluate_mismatched_len() {
    let fill = ListFill {};

    // Five counts ...
    let counts = Int8Array::from_iter(
        Field::new("s0", DataType::Int8),
        vec![Some(1), Some(0), Some(10), Some(11), Some(7)].into_iter(),
    )
    .into_series();
    // ... but only three elements.
    let elems = Utf8Array::from_iter("s2", vec![None, Some("hello"), Some("world")].into_iter())
        .into_series();

    // `unwrap` turns an `Err` into a panic as well, so either failure mode is
    // observed as `Err` from `catch_unwind`.
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        fill.evaluate(&[counts.clone(), elems.clone()]).unwrap()
    }));
    assert!(outcome.is_err());
}

/// Happy-path check for `ListFill::evaluate`.
///
/// Counts `[1, 0, 3]` paired with elements `[null, "hello", "world"]` must
/// produce the list column `[[null], [], ["world", "world", "world"]]`.
#[test]
fn test_evaluate() -> DaftResult<()> {
    let fill = ListFill {};
    let counts = Int8Array::from_iter(
        Field::new("s0", DataType::Int8),
        vec![Some(1), Some(0), Some(3)].into_iter(),
    )
    .into_series();
    let elems = Utf8Array::from_iter("s2", vec![None, Some("hello"), Some("world")].into_iter())
        .into_series();

    let actual = fill.evaluate(&[counts.clone(), elems.clone()])?;

    // Expected list array: flattened values plus the offsets that slice them
    // into rows of length 1, 0 and 3.
    let expected_values = Utf8Array::from_iter(
        "s2",
        vec![None, Some("world"), Some("world"), Some("world")].into_iter(),
    )
    .into_series();
    let expected_offsets = OffsetsBuffer::try_from(vec![0, 1, 1, 4]).unwrap();
    let expected = ListArray::new(
        Field::new("s2", DataType::List(Box::new(DataType::Utf8))),
        expected_values,
        expected_offsets,
        None,
    );

    // Compare field, length, offsets and validity before the child values.
    assert_eq!(actual.field(), expected.field.as_ref());
    assert_eq!(actual.len(), expected.len());
    let actual_list = actual.list()?;
    assert_eq!(actual_list.offsets(), expected.offsets());
    assert_eq!(actual_list.validity(), expected.validity());

    let got: Vec<_> = actual_list.flat_child.utf8()?.into_iter().collect();
    let want: Vec<_> = expected.flat_child.utf8()?.into_iter().collect();
    assert_eq!(got, want);
    Ok(())
}
}
51 changes: 27 additions & 24 deletions src/daft-logical-plan/src/ops/set_operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ fn check_structurally_equal(
Ok(())
}

// Names of the helper ("virtual") columns used while building the
// intersect plan from a union-all followed by an aggregation.

// Marker column aliased to `lit(true)` on left-side rows and to a null
// Boolean on right-side rows.
const V_COL_L: &str = "__v_col_l";
// Alias of the valid-value count of `V_COL_L` in the aggregate.
const V_L_CNT: &str = "__v_l_cnt";
// Marker column aliased to `lit(true)` on right-side rows and to a null
// Boolean on left-side rows.
const V_COL_R: &str = "__v_col_r";
// Alias of the valid-value count of `V_COL_R` in the aggregate.
const V_R_CNT: &str = "__v_r_cnt";
// Alias of the smaller of `V_L_CNT` and `V_R_CNT`; passed to `list_fill` as
// the repeat count before exploding.
const V_MIN_COUNT: &str = "__min_count";

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Intersect {
// Upstream nodes.
Expand Down Expand Up @@ -164,15 +170,13 @@ impl Intersect {
.zip(left_cols.iter())
.map(|(r, l)| r.alias(l.name()))
.collect::<Vec<ExprRef>>();
let virtual_col_l = "__v_col_l";
let virtual_col_r = "__v_col_r";
let left_v_cols = vec![
lit(true).alias(virtual_col_l),
null_lit().cast(&DataType::Boolean).alias(virtual_col_r),
lit(true).alias(V_COL_L),
null_lit().cast(&DataType::Boolean).alias(V_COL_R),
];
let right_v_cols = vec![
null_lit().cast(&DataType::Boolean).alias(virtual_col_l),
lit(true).alias(virtual_col_r),
null_lit().cast(&DataType::Boolean).alias(V_COL_L),
lit(true).alias(V_COL_R),
];
let left_v_cols = [left_v_cols, left_cols.clone()].concat();
let right_v_cols = [right_v_cols, right_cols].concat();
Expand All @@ -183,35 +187,32 @@ impl Intersect {
right_v_cols,
)?;
let one_lit = lit(1);
let left_v_cnt = col(virtual_col_l)
.count(CountMode::Valid)
.alias("__v_l_cnt");
let right_v_cnt = col(virtual_col_r)
.count(CountMode::Valid)
.alias("__v_r_cnt");
let count_name = "__min_count";
let min_count = col("__v_l_cnt")
.gt(col("__v_r_cnt"))
.if_else(col("__v_r_cnt"), col("__v_l_cnt"))
.alias(count_name);
let left_v_cnt = col(V_COL_L).count(CountMode::Valid).alias(V_L_CNT);
let right_v_cnt = col(V_COL_R).count(CountMode::Valid).alias(V_R_CNT);
let min_count = col(V_L_CNT)
.gt(col(V_R_CNT))
.if_else(col(V_R_CNT), col(V_L_CNT))
.alias(V_MIN_COUNT);
let aggregate_plan = Aggregate::try_new(
union_all.into(),
vec![left_v_cnt, right_v_cnt],
left_cols.clone(),
)?;
let filter_plan = Filter::try_new(
aggregate_plan.into(),
col("__v_l_cnt")
col(V_L_CNT)
.gt_eq(one_lit.clone())
.and(col("__v_r_cnt").gt_eq(one_lit)),
.and(col(V_R_CNT).gt_eq(one_lit)),
)?;
let min_count_plan = Project::try_new(
filter_plan.into(),
[vec![min_count], left_cols.clone()].concat(),
)?;
let fill_and_explodes = left_cols
.iter()
.map(|column| explode(list_fill(col(count_name), column.clone())))
.map(|column| {
explode(list_fill(col(V_MIN_COUNT), column.clone())).alias(column.name())
})
.collect::<Vec<_>>();
let project_plan = Project::try_new(min_count_plan.into(), fill_and_explodes)?;
Ok(project_plan.into())
Expand Down Expand Up @@ -393,6 +394,7 @@ impl Except {
.map(|(r, l)| r.alias(l.name()))
.collect::<Vec<ExprRef>>();
let virtual_col = "__v_col";
let virtual_sum = "__sum";
let left_v_cols = vec![lit(1).alias(virtual_col)];
let right_v_cols = vec![lit(-1).alias(virtual_col)];
let left_v_cols = [left_v_cols, left_cols.clone()].concat();
Expand All @@ -403,14 +405,15 @@ impl Except {
left_v_cols,
right_v_cols,
)?;
let sum_name = "__sum";
let sum = col(virtual_col).sum().alias(sum_name);
let sum = col(virtual_col).sum().alias(virtual_sum);
let aggregate_plan =
Aggregate::try_new(union_all.into(), vec![sum], left_cols.clone())?;
let filter_plan = Filter::try_new(aggregate_plan.into(), col(sum_name).gt(lit(0)))?;
let filter_plan = Filter::try_new(aggregate_plan.into(), col(virtual_sum).gt(lit(0)))?;
let fill_and_explodes = left_cols
.iter()
.map(|column| explode(list_fill(col(sum_name), column.clone())))
.map(|column| {
explode(list_fill(col(virtual_sum), column.clone())).alias(column.name())
})
.collect::<Vec<_>>();
let project_plan = Project::try_new(filter_plan.into(), fill_and_explodes)?;
Ok(project_plan.into())
Expand Down
21 changes: 21 additions & 0 deletions tests/dataframe/test_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,24 @@ def test_intersect_with_nulls(make_df, op, left, right, expected):
)
def test_except_with_nulls(make_df, op, left, right, expected):
helper(make_df, op, left, right, expected)


# Set operations that keep duplicates, exercised on two-column frames whose
# column names differ between the left and right inputs. The expected output
# uses the left frame's column names.
@pytest.mark.parametrize(
"op, left, right, expected",
[
(
# Row (2, 3) occurs twice on both sides, so it is expected twice.
"intersect_all",
{"foo": [1, 2, 2], "bar": [2, 3, 3]},
{"a": [2, 2, 4], "b": [3, 3, 4]},
{"foo": [2, 2], "bar": [3, 3]},
),
(
# Both (2, 3) rows on the left are matched by the right; only (1, 2) is
# expected to remain.
"except_all",
{"foo": [1, 2, 2], "bar": [2, 3, 3]},
{"a": [2, 2, 4], "b": [3, 3, 4]},
{"foo": [1], "bar": [2]},
),
],
)
def test_multiple_fields(make_df, op, left, right, expected):
# NOTE(review): `helper` is defined outside this view; presumably it applies
# `op` to the two frames and compares the result with `expected` — confirm.
helper(make_df, op, left, right, expected)

0 comments on commit cf2e504

Please sign in to comment.