Skip to content

Commit

Permalink
[minor]: Update median implementation (#13554)
Browse files Browse the repository at this point in the history
* Initial commit

* Minor changes

* Update implementation to remove Option

* Use max_by api

* Fix error

* Update datafusion/functions-aggregate/src/median.rs

---------

Co-authored-by: Oleks V <[email protected]>
  • Loading branch information
akurmustafa and comphead authored Nov 29, 2024
1 parent 55a0040 commit 8ae36b7
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions datafusion/functions-aggregate/src/median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::cmp::Ordering;
use std::fmt::{Debug, Formatter};
use std::mem::{size_of, size_of_val};
use std::sync::{Arc, OnceLock};
Expand All @@ -30,7 +31,7 @@ use arrow::{

use arrow::array::Array;
use arrow::array::ArrowNativeTypeOp;
use arrow::datatypes::ArrowNativeType;
use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType};

use datafusion_common::{DataFusionError, HashSet, Result, ScalarValue};
use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
Expand Down Expand Up @@ -310,6 +311,21 @@ impl<T: ArrowNumericType> Accumulator for DistinctMedianAccumulator<T> {
}
}

/// Returns the largest entry of `array`.
///
/// Entries are ranked with [`ArrowNativeTypeOp::compare`] instead of
/// `partial_cmp`. `partial_cmp` yields `None` as soon as a floating-point NaN
/// is involved, and replacing that `None` with a fixed ordering makes the
/// winner depend on where the NaN happens to sit in the slice. The caller
/// passes the left side of a `select_nth_unstable_by` partition, whose element
/// order is unspecified, so a position-dependent answer would make the median
/// itself non-deterministic.
///
/// NOTE(review): this presumes the partitioning comparator in
/// `calculate_median` is also `compare` — confirm against that function.
///
/// # Panics
///
/// Panics if `array` is empty.
fn slice_max<T>(array: &[T::Native]) -> T::Native
where
    T: ArrowPrimitiveType,
    T::Native: PartialOrd, // Bound retained so the signature is unchanged for callers
{
    // Callers must never hand in an empty slice; catch violations early in debug builds.
    debug_assert!(!array.is_empty());

    // Total order over the native values: unlike `partial_cmp` it never fails,
    // so the maximum is well defined even when NaNs are present.
    let total_order = |x: &T::Native, y: &T::Native| -> Ordering { x.compare(*y) };

    array
        .iter()
        .copied()
        .max_by(total_order)
        .expect("slice_max requires a non-empty slice")
}

fn calculate_median<T: ArrowNumericType>(
mut values: Vec<T::Native>,
) -> Option<T::Native> {
Expand All @@ -320,8 +336,11 @@ fn calculate_median<T: ArrowNumericType>(
None
} else if len % 2 == 0 {
let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
let (_, low, _) = low.select_nth_unstable_by(low.len() - 1, cmp);
let median = low.add_wrapping(*high).div_wrapping(T::Native::usize_as(2));
// Get the maximum of the low (left side after bi-partitioning)
let left_max = slice_max::<T>(low);
let median = left_max
.add_wrapping(*high)
.div_wrapping(T::Native::usize_as(2));
Some(median)
} else {
let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);
Expand Down

0 comments on commit 8ae36b7

Please sign in to comment.