Skip to content

Commit

Permalink
[FEAT] Add count-distinct aggregation (#3455)
Browse files Browse the repository at this point in the history
# Overview
Add a count-distinct operation.
- counts the number of unique elements in a column
- will *not* count any `NULL`s, should they appear

## Example

```py
df = daft.from_pydict({'a': [1,2,3,None,4]})
df.agg(daft.col('a').count_distinct()).show()
```

This outputs `4`.

## Implementation

### Uni-partitioned
- calls `Series::count_distinct` which performs:
  - list_agg
  - list_unique_count

### Multi-partitioned
- transforms into:
  - list_agg (stage 1)
  - list_concat (stage 2)
  - list_unique_count (final projection)
  • Loading branch information
Raunak Bhagat authored Dec 3, 2024
1 parent 830f8b7 commit 75ad85a
Show file tree
Hide file tree
Showing 18 changed files with 261 additions and 6 deletions.
2 changes: 2 additions & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,7 @@ class PyExpr:
def cast(self, dtype: PyDataType) -> PyExpr: ...
def if_else(self, if_true: PyExpr, if_false: PyExpr) -> PyExpr: ...
def count(self, mode: CountMode) -> PyExpr: ...
def count_distinct(self) -> PyExpr: ...
def sum(self) -> PyExpr: ...
def approx_count_distinct(self) -> PyExpr: ...
def approx_percentiles(self, percentiles: float | list[float]) -> PyExpr: ...
Expand Down Expand Up @@ -1367,6 +1368,7 @@ class PySeries:
) -> PySeries: ...
def __invert__(self) -> PySeries: ...
def count(self, mode: CountMode) -> PySeries: ...
def count_distinct(self) -> PySeries: ...
def sum(self) -> PySeries: ...
def mean(self) -> PySeries: ...
def stddev(self) -> PySeries: ...
Expand Down
4 changes: 4 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,10 @@ def count(self, mode: Literal["all", "valid", "null"] | CountMode = CountMode.Va
expr = self._expr.count(mode)
return Expression._from_pyexpr(expr)

def count_distinct(self) -> Expression:
    """Counts the number of distinct non-null values in the expression"""
    return Expression._from_pyexpr(self._expr.count_distinct())

def sum(self) -> Expression:
"""Calculates the sum of the values in the expression"""
expr = self._expr.sum()
Expand Down
18 changes: 18 additions & 0 deletions src/daft-core/src/array/from_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use arrow2::{
array::{MutablePrimitiveArray, PrimitiveArray},
types::months_days_ns,
};
use common_error::DaftResult;

use super::DataArray;
use crate::{
Expand Down Expand Up @@ -40,6 +41,23 @@ where
let data_array: PrimitiveArray<_> = array.into();
Self::new(field, data_array.boxed()).unwrap()
}

/// Builds a primitive array for `field` from an iterator of optional native values.
///
/// `None` items become nulls in the resulting array. When the iterator reports an upper
/// bound through `size_hint`, that much capacity is reserved before the values are appended.
///
/// # Errors
/// Returns an error if the field's dtype cannot be converted to an Arrow type, or if
/// `Self::new` rejects the field/array pair.
pub fn from_regular_iter<F, I>(field: F, iter: I) -> DaftResult<Self>
where
    F: Into<Arc<Field>>,
    I: Iterator<Item = Option<T::Native>>,
{
    let field = field.into();
    let mut builder = MutablePrimitiveArray::<T::Native>::from(field.dtype.to_arrow()?);
    // Only the upper bound is used for the up-front reservation; an unbounded hint reserves nothing.
    if let (_, Some(max_len)) = iter.size_hint() {
        builder.reserve(max_len);
    }
    builder.extend(iter);
    Self::new(field, PrimitiveArray::from(builder).boxed())
}
}

impl Utf8Array {
Expand Down
66 changes: 64 additions & 2 deletions src/daft-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ mod ops;
mod serdes;
mod series_like;
mod utils;
use std::{ops::Sub, sync::Arc};
use std::{
collections::{hash_map::RawEntryMut, HashMap},
ops::Sub,
sync::Arc,
};

pub use array_impl::IntoSeries;
use common_display::table_display::{make_comfy_table, StrValue};
Expand All @@ -15,10 +19,14 @@ pub use ops::cast_series_to_supertype;
pub(crate) use self::series_like::SeriesLike;
use crate::{
array::{
ops::{from_arrow::FromArrow, full::FullNull, DaftCompare},
ops::{
arrow2::comparison::build_is_equal, from_arrow::FromArrow, full::FullNull, DaftCompare,
},
DataArray,
},
datatypes::{DaftDataType, DaftNumericType, DataType, Field, FieldRef, NumericNative},
prelude::AsArrow,
utils::identity_hash_set::{IdentityBuildHasher, IndexHash},
with_match_daft_types,
};

Expand All @@ -38,6 +46,60 @@ impl PartialEq for Series {
}

impl Series {
/// Collect one [`IndexHash`] per distinct, non-`NULL` element of this [`Series`].
///
/// The result can be probed to test whether an element exists in this [`Series`], and its
/// length is the *exact* number of unique elements.
///
/// # Note
/// 1. The return type is a `HashMap<X, ()>` rather than a `HashSet<X>`; the two are functionally equivalent.
///
/// 2. `NULL`s are never inserted, so they do not contribute to the number of unique elements.
pub fn build_probe_table_without_nulls(
    &self,
) -> DaftResult<HashMap<IndexHash, (), IdentityBuildHasher>> {
    const INITIAL_CAPACITY: usize = 20;

    // `build_is_equal(..)` fails on a series of type `NULL`, so answer with an empty
    // table before a comparator is ever requested.
    if matches!(self.data_type(), DataType::Null) {
        return Ok(HashMap::default());
    }

    let hashes = self.hash_with_validity(None)?;
    let arrow_array = self.to_arrow();
    let is_equal = build_is_equal(&*arrow_array, &*arrow_array, true, false)?;

    let mut table = HashMap::<IndexHash, (), IdentityBuildHasher>::with_capacity_and_hasher(
        INITIAL_CAPACITY,
        Default::default(),
    );

    // Rows are enumerated *before* the hash-less ones are dropped, so `row` is always the
    // element's position in the original series.
    let hashed_rows = hashes
        .as_arrow()
        .iter()
        .enumerate()
        .filter_map(|(row, hash)| hash.map(|&hash| (row, hash)));

    for (row, hash) in hashed_rows {
        // An equal hash alone is not enough: the comparator confirms the values really match.
        let slot = table
            .raw_entry_mut()
            .from_hash(hash, |seen| (hash == seen.hash) && is_equal(row, seen.idx as _));
        if let RawEntryMut::Vacant(slot) = slot {
            let key = IndexHash {
                idx: row as u64,
                hash,
            };
            slot.insert_hashed_nocheck(hash, key, ());
        }
    }

    Ok(table)
}

/// Exports this Series into an Arrow array that is corrected for the Arrow type system.
/// For example, Daft's TimestampArray is a logical type that is backed by an Int64Array Physical array.
/// If we were to call `.as_arrow()` or `.physical` on the TimestampArray, we would get an Int64Array that represented the time units.
Expand Down
5 changes: 5 additions & 0 deletions src/daft-core/src/series/ops/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ impl Series {
})
}

/// Counts the distinct values of this series, optionally per group.
///
/// Values are first gathered into lists with `agg_list` (using `groups` when provided),
/// and each resulting list is then reduced with `list_unique_count`.
pub fn count_distinct(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
    self.agg_list(groups)?.list_unique_count()
}

pub fn sum(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
match self.data_type() {
// intX -> int64 (in line with numpy)
Expand Down
42 changes: 42 additions & 0 deletions src/daft-core/src/series/ops/list.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use common_error::{DaftError, DaftResult};
use daft_schema::field::Field;

use crate::{
datatypes::{DataType, UInt64Array, Utf8Array},
Expand Down Expand Up @@ -175,4 +176,45 @@ impl Series {
))),
}
}

/// Given a series of `List` or `FixedSizeList`, return the count of distinct elements in each list.
///
/// # Note
/// - `NULL` elements inside a list are not counted.
/// - A `NULL` list yields a `NULL` count.
///
/// # Errors
/// Returns a [`DaftError::TypeError`] if this series is neither a `List` nor a `FixedSizeList`.
/// Any error raised while building a sub-list's probe table is propagated to the caller
/// rather than causing a panic.
///
/// # Example
/// ```txt
/// [[1, 2, 3], [1, 1, 1], [NULL, NULL, 5]] -> [3, 1, 1]
/// ```
pub fn list_unique_count(&self) -> DaftResult<Self> {
    // Shared by both list flavours: count the distinct non-null elements of one sub-list.
    // A missing sub-list stays `None`; a probe-table failure becomes an `Err`.
    let count_unique = |sub_series: Option<Self>| -> DaftResult<Option<u64>> {
        sub_series
            .map(|sub_series| {
                sub_series
                    .build_probe_table_without_nulls()
                    .map(|probe_table| probe_table.len() as u64)
            })
            .transpose()
    };

    // Collecting into `DaftResult<Vec<_>>` stops at the first failing sub-list.
    let counts = match self.data_type() {
        DataType::List(..) => self
            .list()?
            .into_iter()
            .map(count_unique)
            .collect::<DaftResult<Vec<_>>>()?,
        DataType::FixedSizeList(..) => self
            .fixed_size_list()?
            .into_iter()
            .map(count_unique)
            .collect::<DaftResult<Vec<_>>>()?,
        other => {
            return Err(DaftError::TypeError(format!(
                "List count distinct not implemented for {}",
                other
            )))
        }
    };

    let field = Field::new(self.name(), DataType::UInt64);
    Ok(UInt64Array::from_regular_iter(field, counts.into_iter())?.into_series())
}
}
18 changes: 16 additions & 2 deletions src/daft-dsl/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ pub enum AggExpr {
#[display("count({_0}, {_1})")]
Count(ExprRef, CountMode),

#[display("count_distinct({_0})")]
CountDistinct(ExprRef),

#[display("sum({_0})")]
Sum(ExprRef),

Expand Down Expand Up @@ -247,6 +250,7 @@ impl AggExpr {
pub fn name(&self) -> &str {
match self {
Self::Count(expr, ..)
| Self::CountDistinct(expr)
| Self::Sum(expr)
| Self::ApproxPercentile(ApproxPercentileParams { child: expr, .. })
| Self::ApproxCountDistinct(expr)
Expand All @@ -269,6 +273,10 @@ impl AggExpr {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_count({mode})"))
}
Self::CountDistinct(expr) => {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_count_distinct()"))
}
Self::Sum(expr) => {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_sum()"))
Expand Down Expand Up @@ -337,6 +345,7 @@ impl AggExpr {
pub fn children(&self) -> Vec<ExprRef> {
match self {
Self::Count(expr, ..)
| Self::CountDistinct(expr)
| Self::Sum(expr)
| Self::ApproxPercentile(ApproxPercentileParams { child: expr, .. })
| Self::ApproxCountDistinct(expr)
Expand All @@ -361,7 +370,8 @@ impl AggExpr {
}
let mut first_child = || children.pop().unwrap();
match self {
Self::Count(_, count_mode) => Self::Count(first_child(), *count_mode),
&Self::Count(_, count_mode) => Self::Count(first_child(), count_mode),
Self::CountDistinct(_) => Self::CountDistinct(first_child()),
Self::Sum(_) => Self::Sum(first_child()),
Self::Mean(_) => Self::Mean(first_child()),
Self::Stddev(_) => Self::Stddev(first_child()),
Expand Down Expand Up @@ -391,7 +401,7 @@ impl AggExpr {

pub fn to_field(&self, schema: &Schema) -> DaftResult<Field> {
match self {
Self::Count(expr, ..) => {
Self::Count(expr, ..) | Self::CountDistinct(expr) => {
let field = expr.to_field(schema)?;
Ok(Field::new(field.name.as_str(), DataType::UInt64))
}
Expand Down Expand Up @@ -539,6 +549,10 @@ impl Expr {
Self::Agg(AggExpr::Count(self, mode)).into()
}

/// Wraps this expression in an [`AggExpr::CountDistinct`] aggregation.
pub fn count_distinct(self: ExprRef) -> ExprRef {
    let agg = AggExpr::CountDistinct(self);
    Self::Agg(agg).into()
}

pub fn sum(self: ExprRef) -> ExprRef {
Self::Agg(AggExpr::Sum(self)).into()
}
Expand Down
4 changes: 4 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,10 @@ impl PyExpr {
Ok(self.expr.clone().count(mode).into())
}

/// Returns a new expression that applies `count_distinct` to a clone of the wrapped expression.
pub fn count_distinct(&self) -> PyResult<Self> {
    let counted = self.expr.clone().count_distinct();
    Ok(counted.into())
}

pub fn sum(&self) -> PyResult<Self> {
Ok(self.expr.clone().sum().into())
}
Expand Down
2 changes: 2 additions & 0 deletions src/daft-functions/src/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod min;
mod slice;
mod sort;
mod sum;
mod unique_count;
mod value_counts;

pub use chunk::{list_chunk as chunk, ListChunk};
Expand All @@ -22,4 +23,5 @@ pub use min::{list_min as min, ListMin};
pub use slice::{list_slice as slice, ListSlice};
pub use sort::{list_sort as sort, ListSort};
pub use sum::{list_sum as sum, ListSum};
pub use unique_count::{list_unique_count as unique_count, ListUniqueCount};
pub use value_counts::list_value_counts as value_counts;
54 changes: 54 additions & 0 deletions src/daft-functions/src/list/unique_count.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use std::any::Any;

use common_error::{DaftError, DaftResult};
use daft_core::{
prelude::{DataType, Field, Schema},
series::Series,
};
use daft_dsl::{
functions::{ScalarFunction, ScalarUDF},
ExprRef,
};
use serde::{Deserialize, Serialize};

/// Scalar function that evaluates [`Series::list_unique_count`] on its single input.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct ListUniqueCount;

#[typetag::serde]
impl ScalarUDF for ListUniqueCount {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &'static str {
        "list_unique_count"
    }

    /// The output keeps the input field's name and is always typed `UInt64`.
    ///
    /// Exactly one input is required; any other arity is a `SchemaMismatch`.
    fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult<Field> {
        let [input] = inputs else {
            return Err(DaftError::SchemaMismatch(format!(
                "Expected 1 input arg, got {}",
                inputs.len()
            )));
        };
        let input_field = input.to_field(schema)?;
        Ok(Field::new(input_field.name, DataType::UInt64))
    }

    /// Delegates to `Series::list_unique_count` on the single input series.
    ///
    /// Exactly one input is required; any other arity is a `SchemaMismatch`.
    fn evaluate(&self, inputs: &[Series]) -> DaftResult<Series> {
        let [input] = inputs else {
            return Err(DaftError::SchemaMismatch(format!(
                "Expected 1 input arg, got {}",
                inputs.len()
            )));
        };
        input.list_unique_count()
    }
}

/// Builds a [`ListUniqueCount`] scalar-function expression with `expr` as its only argument.
#[must_use]
pub fn list_unique_count(expr: ExprRef) -> ExprRef {
    let inputs = vec![expr];
    ScalarFunction::new(ListUniqueCount, inputs).into()
}
1 change: 1 addition & 0 deletions src/daft-functions/src/python/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use daft_dsl::python::PyExpr;
use pyo3::{pyfunction, PyResult};

simple_python_wrapper!(list_chunk, crate::list::chunk, [expr: PyExpr, size: usize]);
simple_python_wrapper!(list_unique_count, crate::list::unique_count, [expr: PyExpr]);
simple_python_wrapper!(list_count, crate::list::count, [expr: PyExpr, mode: CountMode]);
simple_python_wrapper!(explode, crate::list::explode, [expr: PyExpr]);
simple_python_wrapper!(list_get, crate::list::get, [expr: PyExpr, idx: PyExpr, default_value: PyExpr]);
Expand Down
1 change: 1 addition & 0 deletions src/daft-functions/src/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ pub fn register(parent: &Bound<PyModule>) -> PyResult<()> {
add!(list::list_slice);
add!(list::list_sort);
add!(list::list_sum);
add!(list::list_unique_count);
add!(list::list_value_counts);

add!(misc::to_struct);
Expand Down
4 changes: 4 additions & 0 deletions src/daft-logical-plan/src/ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,10 @@ fn replace_column_with_semantic_id_aggexpr(
|_| e,
)
}
AggExpr::CountDistinct(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::CountDistinct, |_| e)
}
AggExpr::Sum(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Sum, |_| e)
Expand Down
Loading

0 comments on commit 75ad85a

Please sign in to comment.