Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
joseph-isaacs committed Nov 25, 2024
1 parent 609cee2 commit b9835b9
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 11 deletions.
4 changes: 2 additions & 2 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ where

pub struct ScalarFunctionArgs<'a> {
// The evaluated arguments to the function
pub args: Vec<ColumnarValue>,
pub args: &'a [ColumnarValue],
// The number of rows in record batch being evaluated
pub number_rows: usize,
// The return type of the scalar function returned (from `return_type` or `return_type_from_exprs`)
Expand Down Expand Up @@ -543,7 +543,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
self.invoke_batch(args.args.as_slice(), args.number_rows)
self.invoke_batch(args.args, args.number_rows)
}

/// Invoke the function without `args`, instead the number of rows are provided,
Expand Down
6 changes: 3 additions & 3 deletions datafusion/functions/src/datetime/date_bin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ mod tests {

use crate::datetime::date_bin::{date_bin_nanos_interval, DateBinFunc};
use arrow::array::types::TimestampNanosecondType;
use arrow::array::{IntervalDayTimeArray, TimestampNanosecondArray};
use arrow::array::{Array, IntervalDayTimeArray, TimestampNanosecondArray};
use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos;
use arrow::datatypes::{DataType, TimeUnit};

Expand Down Expand Up @@ -545,10 +545,10 @@ mod tests {
milliseconds: 1,
},
))),
ColumnarValue::Array(timestamps),
ColumnarValue::Array(timestamps.clone()),
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)),
],
1,
timestamps.len(),
);
assert!(res.is_ok());

Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/datetime/to_local_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ mod tests {
fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) {
let res = ToLocalTimeFunc::new()
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Scalar(input)],
args: &[ColumnarValue::Scalar(input)],
number_rows: 1,
return_type: &expected.data_type(),
})
Expand Down
6 changes: 2 additions & 4 deletions datafusion/functions/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,6 @@ pub mod test {
let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
let func = $FUNC;

let args_vec = $ARGS.iter().cloned().collect::<Vec<_>>();

let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
let cardinality = $ARGS
.iter()
Expand All @@ -151,7 +149,7 @@ pub mod test {
let return_type = return_type.unwrap();
assert_eq!(return_type, $EXPECTED_DATA_TYPE);

let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: args_vec, number_rows: cardinality, return_type: &return_type});
let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type});
assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err());

let result = result.unwrap().clone().into_array(cardinality).expect("Failed to convert to array");
Expand All @@ -172,7 +170,7 @@ pub mod test {
}
else {
// invoke is expected error - cannot use .expect_err() due to Debug not being implemented
match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: args_vec, number_rows: cardinality, return_type: &return_type.unwrap()}) {
match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type.unwrap()}) {
Ok(_) => assert!(false, "expected error"),
Err(error) => {
assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace()));
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl PhysicalExpr for ScalarFunctionExpr {

// evaluate the function
let output = self.fun.invoke_with_args(ScalarFunctionArgs {
args: inputs,
args: inputs.as_slice(),
number_rows: batch.num_rows(),
return_type: &self.return_type,
})?;
Expand Down

0 comments on commit b9835b9

Please sign in to comment.