From ba181af99a966c763c7a5a99a5971d38bb6ef52c Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 25 Sep 2024 12:10:10 -0700 Subject: [PATCH] tests pass yay --- .vscode/launch.json | 36 +++ .vscode/settings.json | 7 +- daft/daft/__init__.pyi | 1 + daft/expressions/expressions.py | 11 + src/arrow2/src/array/map/mod.rs | 26 +- src/arrow2/src/offset.rs | 33 +-- .../src/array/fixed_size_list_array.rs | 1 + src/daft-core/src/array/list_array.rs | 2 + .../src/array/ops/arrow2/comparison.rs | 2 +- src/daft-core/src/array/ops/cast.rs | 2 +- src/daft-core/src/array/ops/from_arrow.rs | 4 +- src/daft-core/src/array/ops/list.rs | 239 +++++++++++++++++- src/daft-core/src/array/ops/map.rs | 21 +- src/daft-core/src/array/struct_array.rs | 2 + src/daft-core/src/datatypes/matching.rs | 2 +- src/daft-core/src/lib.rs | 1 + src/daft-core/src/series/from.rs | 3 +- src/daft-core/src/series/ops/list.rs | 16 ++ src/daft-core/src/series/ops/map.rs | 2 +- src/daft-core/src/series/serdes.rs | 2 +- src/daft-dsl/src/functions/map/get.rs | 16 +- src/daft-dsl/src/functions/utf8/split.rs | 1 + src/daft-functions/src/list/mod.rs | 5 + src/daft-functions/src/list/value_counts.rs | 70 +++++ src/daft-schema/src/dtype.rs | 119 +++++---- src/daft-schema/src/python/datatype.rs | 8 +- src/daft-sql/src/planner.rs | 2 +- src/daft-stats/src/column_stats/mod.rs | 2 +- src/daft-table/src/repr_html.rs | 2 +- tests/expressions/test_expressions.py | 58 +++++ 30 files changed, 583 insertions(+), 113 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 src/daft-functions/src/list/value_counts.rs diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000..84a59357e2 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,36 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Debug Daft Python/Rust", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/tests/expressions/test_expressions.py", + "args": [], + "console": "integratedTerminal", + "justMyCode": false, + "env": { + "PYTHONPATH": "${workspaceFolder}" + }, + "serverReadyAction": { + "pattern": "pID = ([0-9]+)", + "action": "startDebugging", + "name": "Daft Rust LLDB" + } + }, + { + "name": "Daft Rust LLDB", + "pid": "0", + "type": "lldb", + "request": "attach", + "program": "${workspaceFolder}/.venv/bin/python", + "stopOnEntry": false, + "sourceLanguages": [ + "rust" + ], + "presentation": { + "hidden": true + } + } + ] + } \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 2f8da01d92..1038ded445 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,5 +3,10 @@ "CARGO_TARGET_DIR": "target/analyzer" }, "rust-analyzer.check.features": "all", - "rust-analyzer.cargo.features": "all" + "rust-analyzer.cargo.features": "all", + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true } diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index e2cb5e1eaa..b4797c7709 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1268,6 +1268,7 @@ def dt_truncate(expr: PyExpr, interval: str, relative_to: PyExpr) -> PyExpr: ... # --- def explode(expr: PyExpr) -> PyExpr: ... def list_sort(expr: PyExpr, desc: PyExpr) -> PyExpr: ... +def list_value_counts(expr: PyExpr) -> PyExpr: ... def list_join(expr: PyExpr, delimiter: PyExpr) -> PyExpr: ... def list_count(expr: PyExpr, mode: CountMode) -> PyExpr: ... def list_get(expr: PyExpr, idx: PyExpr, default: PyExpr) -> PyExpr: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 1ae7e90dac..a97ac924e7 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -2922,6 +2922,17 @@ def join(self, delimiter: str | Expression) -> Expression: delimiter_expr = Expression._to_expression(delimiter) return Expression._from_pyexpr(native.list_join(self._expr, delimiter_expr._expr)) + def value_counts(self) -> Expression: + """Counts the occurrences of each unique value in the list. + + Returns: + Expression: A list of structs, where each struct contains a 'value' field + representing a unique element from the original list, and a 'count' field + representing the number of times that value appears in the list. + """ + return Expression._from_pyexpr(native.list_value_counts(self._expr)) + + def count(self, mode: CountMode = CountMode.Valid) -> Expression: """Counts the number of elements in each list diff --git a/src/arrow2/src/array/map/mod.rs b/src/arrow2/src/array/map/mod.rs index d0dcb46efb..a870de834f 100644 --- a/src/arrow2/src/array/map/mod.rs +++ b/src/arrow2/src/array/map/mod.rs @@ -41,20 +41,24 @@ impl MapArray { try_check_offsets_bounds(&offsets, field.len())?; let inner_field = Self::try_get_field(&data_type)?; - if let DataType::Struct(inner) = inner_field.data_type() { - if inner.len() != 2 { - return Err(Error::InvalidArgumentError( - "MapArray's inner `Struct` must have 2 fields (keys and maps)".to_string(), - )); - } - } else { + + let inner_data_type = inner_field.data_type(); + let DataType::Struct(inner) = inner_data_type else { return Err(Error::InvalidArgumentError( - "MapArray expects `DataType::Struct` as its inner logical type".to_string(), + format!("MapArray expects `DataType::Struct` as its inner logical type, but found {inner_data_type:?}"), )); + }; + + let inner_len = inner.len(); + if inner_len != 2 { + let msg = format!("MapArray's inner `Struct` must have 2 fields (keys and maps), but found {} fields", inner_len); + return Err(Error::InvalidArgumentError(msg)); } - if field.data_type() != inner_field.data_type() { + + let field_data_type = field.data_type(); + if field_data_type != inner_field.data_type() { return Err(Error::InvalidArgumentError( - "MapArray expects `field.data_type` to match its inner DataType".to_string(), + format!("MapArray expects `field.data_type` to match its inner DataType, but found \n{field_data_type:?}\nvs\n\n\n{inner_field:?}"), )); } @@ -66,7 +70,7 @@ impl MapArray { "validity mask length must match the number of values", )); } - + Ok(Self { data_type, field, diff --git a/src/arrow2/src/offset.rs b/src/arrow2/src/offset.rs index 80b45d6680..1ab6f15105 100644 --- a/src/arrow2/src/offset.rs +++ b/src/arrow2/src/offset.rs @@ -71,7 +71,7 @@ impl Offsets { /// Creates a new [`Offsets`] from an iterator of lengths #[inline] - pub fn try_from_iter>(iter: I) -> Result { + pub fn try_from_iter>(iter: I) -> Result { let iterator = iter.into_iter(); let (lower, _) = iterator.size_hint(); let mut offsets = Self::with_capacity(lower); @@ -144,10 +144,7 @@ impl Offsets { /// Returns the last offset of this container. #[inline] pub fn last(&self) -> &O { - match self.0.last() { - Some(element) => element, - None => unsafe { unreachable_unchecked() }, - } + self.0.last().unwrap_or_else(|| unsafe { unreachable_unchecked() }) } /// Returns a range (start, end) corresponding to the position `index` @@ -215,7 +212,7 @@ impl Offsets { /// # Errors /// This function errors iff this operation overflows for the maximum value of `O`. #[inline] - pub fn try_from_lengths>(lengths: I) -> Result { + pub fn try_from_lengths>(lengths: I) -> Result { let mut self_ = Self::with_capacity(lengths.size_hint().0); self_.try_extend_from_lengths(lengths)?; Ok(self_) @@ -225,7 +222,7 @@ impl Offsets { /// # Errors /// This function errors iff this operation overflows for the maximum value of `O`. #[inline] - pub fn try_extend_from_lengths>( + pub fn try_extend_from_lengths>( &mut self, lengths: I, ) -> Result<(), Error> { @@ -401,22 +398,26 @@ impl OffsetsBuffer { *self.last() - *self.first() } + pub fn ranges(&self) -> impl Iterator> + '_ { + self.0.windows(2).map(|w| { + let from = w[0]; + let to = w[1]; + debug_assert!(from <= to, "offsets must be monotonically increasing"); + from..to + }) + } + + /// Returns the first offset. #[inline] pub fn first(&self) -> &O { - match self.0.first() { - Some(element) => element, - None => unsafe { unreachable_unchecked() }, - } + self.0.first().unwrap_or_else(|| unsafe { unreachable_unchecked() }) } /// Returns the last offset. #[inline] pub fn last(&self) -> &O { - match self.0.last() { - Some(element) => element, - None => unsafe { unreachable_unchecked() }, - } + self.0.last().unwrap_or_else(|| unsafe { unreachable_unchecked() }) } /// Returns a range (start, end) corresponding to the position `index` @@ -460,7 +461,7 @@ impl OffsetsBuffer { /// Returns an iterator with the lengths of the offsets #[inline] - pub fn lengths(&self) -> impl Iterator + '_ { + pub fn lengths(&self) -> impl Iterator + '_ { self.0.windows(2).map(|w| (w[1] - w[0]).to_usize()) } diff --git a/src/daft-core/src/array/fixed_size_list_array.rs b/src/daft-core/src/array/fixed_size_list_array.rs index a8b5048b82..830c216177 100644 --- a/src/daft-core/src/array/fixed_size_list_array.rs +++ b/src/daft-core/src/array/fixed_size_list_array.rs @@ -11,6 +11,7 @@ use crate::{ #[derive(Clone, Debug)] pub struct FixedSizeListArray { pub field: Arc, + /// contains all the elements of the nested lists flattened into a single contiguous array. pub flat_child: Series, validity: Option, } diff --git a/src/daft-core/src/array/list_array.rs b/src/daft-core/src/array/list_array.rs index 538c24e716..416cad9840 100644 --- a/src/daft-core/src/array/list_array.rs +++ b/src/daft-core/src/array/list_array.rs @@ -12,6 +12,8 @@ use crate::{ pub struct ListArray { pub field: Arc, pub flat_child: Series, + + /// Where does each row start offsets: arrow2::offset::OffsetsBuffer, validity: Option, } diff --git a/src/daft-core/src/array/ops/arrow2/comparison.rs b/src/daft-core/src/array/ops/arrow2/comparison.rs index 37f7b2a37b..700ab4f8d0 100644 --- a/src/daft-core/src/array/ops/arrow2/comparison.rs +++ b/src/daft-core/src/array/ops/arrow2/comparison.rs @@ -49,7 +49,7 @@ fn build_is_equal_with_nan( } } -fn build_is_equal( +pub fn build_is_equal( left: &dyn Array, right: &dyn Array, nulls_equal: bool, diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index c3dbe0c209..9de75587a3 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -2091,7 +2091,7 @@ impl ListArray { } } } - DataType::Map(..) => Ok(MapArray::new( + DataType::Map { .. } => Ok(MapArray::new( Field::new(self.name(), dtype.clone()), self.clone(), ) diff --git a/src/daft-core/src/array/ops/from_arrow.rs b/src/daft-core/src/array/ops/from_arrow.rs index 1739b524a9..2575cc6f91 100644 --- a/src/daft-core/src/array/ops/from_arrow.rs +++ b/src/daft-core/src/array/ops/from_arrow.rs @@ -35,7 +35,7 @@ where // TODO: Consolidate Map to use the same .to_type conversion as other logical types // Currently, .to_type does not work for Map in Arrow2 because it requires physical types to be equivalent, // but the physical type of MapArray in Arrow2 is a MapArray, not a ListArray - DataType::Map(..) => arrow_arr, + DataType::Map { .. } => arrow_arr, _ => arrow_arr.to_type(data_array_field.dtype.to_arrow()?), }; let physical = ::ArrayType::from_arrow( @@ -98,7 +98,7 @@ impl FromArrow for ListArray { arrow_arr.validity().cloned(), )) } - (DataType::List(daft_child_dtype), arrow2::datatypes::DataType::Map(..)) => { + (DataType::List(daft_child_dtype), arrow2::datatypes::DataType::Map { .. }) => { let map_arr = arrow_arr .as_any() .downcast_ref::() diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index 4dd8cee2a8..0a6f25530b 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -2,16 +2,24 @@ use std::{iter::repeat, sync::Arc}; use arrow2::offset::OffsetsBuffer; use common_error::DaftResult; +use indexmap::{ + map::{raw_entry_v1::RawEntryMut, RawEntryApiV1}, + IndexMap, +}; use super::as_arrow::AsArrow; use crate::{ array::{ growable::{make_growable, Growable}, - FixedSizeListArray, ListArray, + ops::arrow2::comparison::build_is_equal, + FixedSizeListArray, ListArray, StructArray, }, count_mode::CountMode, datatypes::{BooleanArray, DataType, Field, Int64Array, UInt64Array, Utf8Array}, + kernels::search_sorted::build_is_valid, + prelude::MapArray, series::{IntoSeries, Series}, + utils::identity_hash_set::IdentityBuildHasher, }; fn join_arrow_list_of_utf8s( @@ -244,6 +252,130 @@ fn list_sort_helper_fixed_size( } impl ListArray { + pub fn value_counts(&self) -> DaftResult { + struct IndexRef { + index: usize, + hash: u64, + } + + impl std::hash::Hash for IndexRef { + fn hash(&self, state: &mut H) { + self.hash.hash(state); + } + } + + let original_name = self.name(); + + let hashes = self.flat_child.hash(None)?; + + let flat_child = self.flat_child.to_arrow(); + let flat_child = &*flat_child; + + let is_equal = build_is_equal( + flat_child, flat_child, true, // todo: should nulls and nans be considered equal? + true, + )?; + + let is_valid = build_is_valid(flat_child); + + let key_type = self.flat_child.data_type().clone(); + let count_type = DataType::UInt64; + + let mut include_mask = Vec::with_capacity(self.flat_child.len()); + let mut count_array = Vec::new(); + + let mut offsets = Vec::with_capacity(self.len()); + + offsets.push(0_i64); + + let mut map: IndexMap = IndexMap::default(); + for range in self.offsets().ranges() { + map.clear(); + + for index in range { + let index = index as usize; + if !is_valid(index) { + include_mask.push(false); + // skip nulls + continue; + } + + let hash = hashes.get(index).unwrap(); + + let entry = map + .raw_entry_mut_v1() + .from_hash(hash, |other| is_equal(other.index, index)); + + match entry { + RawEntryMut::Occupied(mut entry) => { + include_mask.push(false); + *entry.get_mut() += 1; + } + RawEntryMut::Vacant(vacant) => { + include_mask.push(true); + vacant.insert(IndexRef { hash, index }, 1); + } + } + } + + // indexmap so ordered + for v in map.values() { + count_array.push(*v); + } + + offsets.push(count_array.len() as i64); + } + + let values = UInt64Array::from(("count", count_array)).into_series(); + let boolean_array = BooleanArray::from(("boolean", include_mask.as_slice())); + + let keys = self.flat_child.filter(&boolean_array)?; + + // todo: probably inefficient + let keys = Series::try_from_field_and_arrow_array( + Field::new("key", key_type.clone()), + keys.to_arrow(), + )?; + + // todo: probably inefficient + let values = Series::try_from_field_and_arrow_array( + Field::new("value", count_type.clone()), + values.to_arrow(), + )?; + + let struct_type = DataType::Struct(vec![ + Field::new("key", key_type.clone()), + Field::new("value", count_type.clone()), + ]); + + let struct_array = StructArray::new( + Arc::new(Field::new("entries", struct_type.clone())), + vec![keys, values], + None, + ); + + let list_type = DataType::List(Box::new(struct_type)); + + let offsets = OffsetsBuffer::try_from(offsets)?; + + let list_array = Self::new( + Arc::new(Field::new("entries", list_type.clone())), + struct_array.into_series(), + offsets, + None, + ); + + let map_type = DataType::Map { + key: Box::new(key_type), + value: Box::new(count_type), + }; + + Ok(MapArray::new( + Field::new(original_name, map_type.clone()), + list_array, + )) + } + pub fn count(&self, mode: CountMode) -> DaftResult { let counts = match (mode, self.flat_child.validity()) { (CountMode::All, _) | (CountMode::Valid, None) => { @@ -472,6 +604,111 @@ impl ListArray { } impl FixedSizeListArray { + // this DaftResult? or something or Series or what + + pub fn value_counts(&self) -> DaftResult { + struct IndexRef { + index: usize, + hash: u64, + } + + impl std::hash::Hash for IndexRef { + fn hash(&self, state: &mut H) { + self.hash.hash(state); + } + } + + let hashes = self.flat_child.hash(None)?; + let is_equal = build_is_equal( + self.flat_child.to_arrow().as_ref(), + self.flat_child.to_arrow().as_ref(), + true, // todo: should nulls and nans be considered equal? + true, + )?; + + let key_type = self.flat_child.data_type().clone(); + let count_type = DataType::UInt64; + + let fixed_size = self.fixed_element_len(); + + let mut map: IndexMap = IndexMap::default(); + + let mut booleans = Vec::with_capacity(self.flat_child.len()); + let mut count_array = Vec::new(); + let mut offsets = Vec::with_capacity(self.len()); + + offsets.push(0_i64); + + for i in 0..self.len() { + map.clear(); + let start_index = i * fixed_size; + for j in 0..fixed_size { + let index = start_index + j; + + let hash = hashes.get(index).unwrap(); + + let entry = map + .raw_entry_mut_v1() + .from_hash(hash, |other| is_equal(other.index, index)); + + match entry { + RawEntryMut::Occupied(mut entry) => { + booleans.push(false); + *entry.get_mut() += 1; + } + RawEntryMut::Vacant(vacant) => { + booleans.push(true); + vacant.insert(IndexRef { hash, index }, 1); + } + } + } + + // indexmap so ordered + for v in map.values() { + count_array.push(*v); + } + + offsets.push(count_array.len() as i64); + } + + let values = UInt64Array::from(("count", count_array)).into_series(); + let boolean_array = BooleanArray::from(("boolean", booleans.as_slice())); + + let keys = self.flat_child.filter(&boolean_array)?; + + let struct_type = DataType::Struct(vec![ + Field::new("key", key_type.clone()), + Field::new("value", count_type.clone()), + ]); + + let struct_array = StructArray::new( + Arc::new(Field::new("entries", struct_type.clone())), + vec![keys, values], + None, + ); + + let list_type = DataType::List(Box::new(struct_type)); + + let offsets = OffsetsBuffer::try_from(offsets)?; + + let list_array = ListArray::new( + Arc::new(Field::new("entries", list_type.clone())), + struct_array.into_series(), + offsets, + None, + ); + + let map_type = DataType::Map { + key: Box::new(key_type), + value: Box::new(count_type), + }; + + Ok(MapArray::new( + Field::new("entries", map_type.clone()), + list_array, + )) + } + pub fn count(&self, mode: CountMode) -> DaftResult { let size = self.fixed_element_len(); let counts = match (mode, self.flat_child.validity()) { diff --git a/src/daft-core/src/array/ops/map.rs b/src/daft-core/src/array/ops/map.rs index 3b2f6ffd8c..a1613ce19c 100644 --- a/src/daft-core/src/array/ops/map.rs +++ b/src/daft-core/src/array/ops/map.rs @@ -24,19 +24,10 @@ fn single_map_get(structs: &Series, key_to_get: &Series) -> DaftResult { impl MapArray { pub fn map_get(&self, key_to_get: &Series) -> DaftResult { - let value_type = if let DataType::Map(inner_dtype) = self.data_type() { - match *inner_dtype.clone() { - DataType::Struct(fields) if fields.len() == 2 => { - fields[1].dtype.clone() - } - _ => { - return Err(DaftError::TypeError(format!( - "Expected inner type to be a struct type with two fields: key and value, got {:?}", - inner_dtype - ))) - } - } - } else { + let DataType::Map { + value: value_type, .. + } = self.data_type() + else { return Err(DaftError::TypeError(format!( "Expected input to be a map type, got {:?}", self.data_type() @@ -49,7 +40,7 @@ impl MapArray { for series in self.physical.into_iter() { match series { Some(s) if !s.is_empty() => result.push(single_map_get(&s, key_to_get)?), - _ => result.push(Series::full_null("value", &value_type, 1)), + _ => result.push(Series::full_null("value", value_type, 1)), } } Series::concat(&result.iter().collect::>()) @@ -59,7 +50,7 @@ impl MapArray { for (i, series) in self.physical.into_iter().enumerate() { match (series, key_to_get.slice(i, i + 1)?) { (Some(s), k) if !s.is_empty() => result.push(single_map_get(&s, &k)?), - _ => result.push(Series::full_null("value", &value_type, 1)), + _ => result.push(Series::full_null("value", value_type, 1)), } } Series::concat(&result.iter().collect::>()) diff --git a/src/daft-core/src/array/struct_array.rs b/src/daft-core/src/array/struct_array.rs index 996680ede5..8a228735e4 100644 --- a/src/daft-core/src/array/struct_array.rs +++ b/src/daft-core/src/array/struct_array.rs @@ -11,6 +11,8 @@ use crate::{ #[derive(Clone, Debug)] pub struct StructArray { pub field: Arc, + + /// Column representations pub children: Vec, validity: Option, len: usize, diff --git a/src/daft-core/src/datatypes/matching.rs b/src/daft-core/src/datatypes/matching.rs index b8b8e1660f..bae597393c 100644 --- a/src/daft-core/src/datatypes/matching.rs +++ b/src/daft-core/src/datatypes/matching.rs @@ -31,7 +31,7 @@ macro_rules! with_match_daft_types {( FixedSizeList(_, _) => __with_ty__! { FixedSizeListType }, List(_) => __with_ty__! { ListType }, Struct(_) => __with_ty__! { StructType }, - Map(_) => __with_ty__! { MapType }, + Map{..} => __with_ty__! { MapType }, Extension(_, _, _) => __with_ty__! { ExtensionType }, #[cfg(feature = "python")] Python => __with_ty__! { PythonType }, diff --git a/src/daft-core/src/lib.rs b/src/daft-core/src/lib.rs index 322a0db3ec..5892f75ffb 100644 --- a/src/daft-core/src/lib.rs +++ b/src/daft-core/src/lib.rs @@ -2,6 +2,7 @@ #![feature(int_roundings)] #![feature(iterator_try_reduce)] #![feature(if_let_guard)] +#![feature(hash_raw_entry)] pub mod array; pub mod count_mode; diff --git a/src/daft-core/src/series/from.rs b/src/daft-core/src/series/from.rs index 99776edf64..746ba136eb 100644 --- a/src/daft-core/src/series/from.rs +++ b/src/daft-core/src/series/from.rs @@ -12,9 +12,10 @@ use crate::{ impl Series { pub fn try_from_field_and_arrow_array( - field: Arc, + field: impl Into>, array: Box, ) -> DaftResult { + let field = field.into(); // TODO(Nested): Refactor this out with nested logical types in StructArray and ListArray // Corner-case nested logical types that have not yet been migrated to new Array formats // to hold only casted physical arrow arrays. diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs index d9a17dd087..81a4788067 100644 --- a/src/daft-core/src/series/ops/list.rs +++ b/src/daft-core/src/series/ops/list.rs @@ -7,6 +7,22 @@ use crate::{ }; impl Series { + pub fn list_value_counts(&self) -> DaftResult { + let series = match self.data_type() { + DataType::List(_) => self.list()?.value_counts(), + DataType::FixedSizeList(..) => self.fixed_size_list()?.value_counts(), + dt => { + return Err(DaftError::TypeError(format!( + "List contains not implemented for {}", + dt + ))) + } + }? + .into_series(); + + Ok(series) + } + pub fn explode(&self) -> DaftResult { match self.data_type() { DataType::List(_) => self.list()?.explode(), diff --git a/src/daft-core/src/series/ops/map.rs b/src/daft-core/src/series/ops/map.rs index b624cd8aac..0d4e54820d 100644 --- a/src/daft-core/src/series/ops/map.rs +++ b/src/daft-core/src/series/ops/map.rs @@ -5,7 +5,7 @@ use crate::{datatypes::DataType, series::Series}; impl Series { pub fn map_get(&self, key: &Self) -> DaftResult { match self.data_type() { - DataType::Map(_) => self.map()?.map_get(key), + DataType::Map { .. } => self.map()?.map_get(key), dt => Err(DaftError::TypeError(format!( "map.get not implemented for {}", dt diff --git a/src/daft-core/src/series/serdes.rs b/src/daft-core/src/series/serdes.rs index bf7e42a1e0..3ce3b6f881 100644 --- a/src/daft-core/src/series/serdes.rs +++ b/src/daft-core/src/series/serdes.rs @@ -163,7 +163,7 @@ impl<'d> serde::Deserialize<'d> for Series { .unwrap() .into_series()) } - DataType::Map(..) => { + DataType::Map { .. } => { let physical = map.next_value::()?; Ok(MapArray::new( Arc::new(field), diff --git a/src/daft-dsl/src/functions/map/get.rs b/src/daft-dsl/src/functions/map/get.rs index ab6eb148f8..bf5f9efdf0 100644 --- a/src/daft-dsl/src/functions/map/get.rs +++ b/src/daft-dsl/src/functions/map/get.rs @@ -13,18 +13,14 @@ impl FunctionEvaluator for GetEvaluator { fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { match inputs { + // what is input and what is key + // input is a map field [input, key] => match (input.to_field(schema), key.to_field(schema)) { (Ok(input_field), Ok(_)) => match input_field.dtype { - DataType::Map(inner) => match inner.as_ref() { - DataType::Struct(fields) if fields.len() == 2 => { - let value_dtype = &fields[1].dtype; - Ok(Field::new("value", value_dtype.clone())) - } - _ => Err(DaftError::TypeError(format!( - "Expected input map to have struct values with 2 fields, got {}", - inner - ))), - }, + DataType::Map { value, .. } => { + // todo: perhaps better naming + Ok(Field::new("value", *value)) + } _ => Err(DaftError::TypeError(format!( "Expected input to be a map, got {}", input_field.dtype diff --git a/src/daft-dsl/src/functions/utf8/split.rs b/src/daft-dsl/src/functions/utf8/split.rs index 0518786055..c0e121b393 100644 --- a/src/daft-dsl/src/functions/utf8/split.rs +++ b/src/daft-dsl/src/functions/utf8/split.rs @@ -8,6 +8,7 @@ pub(super) struct SplitEvaluator {} impl FunctionEvaluator for SplitEvaluator { fn fn_name(&self) -> &'static str { + println!("hi"); "split" } diff --git a/src/daft-functions/src/list/mod.rs b/src/daft-functions/src/list/mod.rs index 2ba3f197be..c0ad745b19 100644 --- a/src/daft-functions/src/list/mod.rs +++ b/src/daft-functions/src/list/mod.rs @@ -9,6 +9,7 @@ mod min; mod slice; mod sort; mod sum; +mod value_counts; pub use chunk::{list_chunk as chunk, ListChunk}; pub use count::{list_count as count, ListCount}; @@ -31,6 +32,10 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_function(wrap_pyfunction_bound!(count::py_list_count, parent)?)?; parent.add_function(wrap_pyfunction_bound!(get::py_list_get, parent)?)?; parent.add_function(wrap_pyfunction_bound!(join::py_list_join, parent)?)?; + parent.add_function(wrap_pyfunction_bound!( + value_counts::py_list_value_counts, + parent + )?)?; parent.add_function(wrap_pyfunction_bound!(max::py_list_max, parent)?)?; parent.add_function(wrap_pyfunction_bound!(min::py_list_min, parent)?)?; diff --git a/src/daft-functions/src/list/value_counts.rs b/src/daft-functions/src/list/value_counts.rs new file mode 100644 index 0000000000..83607fc9d7 --- /dev/null +++ b/src/daft-functions/src/list/value_counts.rs @@ -0,0 +1,70 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::prelude::{DataType, Field, Schema, Series}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + python::PyExpr, + ExprRef, +}; +use pyo3::{pyfunction, PyResult}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +struct ListValueCountsFunction; + +#[typetag::serde] +impl ScalarUDF for ListValueCountsFunction { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "list_value_counts" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + let [data] = inputs else { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + }; + + let data_field = data.to_field(schema)?; + + let DataType::List(inner_type) = &data_field.dtype else { + return Err(DaftError::TypeError(format!( + "Expected list, got {}", + data_field.dtype + ))); + }; + + let map_type = DataType::Map { + key: inner_type.clone(), + value: Box::new(DataType::UInt64), + }; + + Ok(Field::new(data_field.name, map_type)) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + let [data] = inputs else { + return Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + }; + + data.list_value_counts() + } +} + +pub fn list_value_counts(expr: ExprRef) -> ExprRef { + ScalarFunction::new(ListValueCountsFunction, vec![expr]).into() +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "list_value_counts")] +pub fn py_list_value_counts(expr: PyExpr) -> PyResult { + Ok(list_value_counts(expr.into()).into()) +} diff --git a/src/daft-schema/src/dtype.rs b/src/daft-schema/src/dtype.rs index 65cf8f808e..e7c4d5ec23 100644 --- a/src/daft-schema/src/dtype.rs +++ b/src/daft-schema/src/dtype.rs @@ -107,8 +107,11 @@ pub enum DataType { Struct(Vec), /// A nested [`DataType`] that is represented as List>. - #[display("Map[{_0}]")] - Map(Box), + #[display("Map[{key}: {value}]")] + Map { + key: Box, + value: Box, + }, /// Extension type. #[display("{_1}")] @@ -233,14 +236,31 @@ impl DataType { Self::List(field) => Ok(ArrowType::LargeList(Box::new( arrow2::datatypes::Field::new("item", field.to_arrow()?, true), ))), - Self::Map(field) => Ok(ArrowType::Map( - Box::new(arrow2::datatypes::Field::new( - "item", - field.to_arrow()?, - true, - )), - false, - )), + Self::Map { key, value } => { + let struct_type = ArrowType::Struct(vec![ + // We never allow null keys in maps for several reasons: + // 1. Null typically represents the absence of a value, which doesn't make sense for a key. + // 2. Null comparisons can be problematic (similar to how f64::NAN != f64::NAN). + // 3. It maintains consistency with common map implementations in arrow (no null keys). + // 4. It simplifies map operations + // + // This decision aligns with the thoughts of team members like Jay and Sammy, who argue that: + // - Nulls in keys could lead to unintuitive behavior + // - If users need to count or group by null values, they can use other constructs like + // group_by operations on non-map types, which offer more explicit control. + // + // By disallowing null keys, we encourage more robust data modeling practices and + // provide a clearer semantic meaning for map types in our system. + arrow2::datatypes::Field::new("key", key.to_arrow()?, true), + arrow2::datatypes::Field::new("value", value.to_arrow()?, true), + ]); + + // entries + let struct_field = + arrow2::datatypes::Field::new("entries", struct_type.clone(), true); + + Ok(ArrowType::Map(Box::new(struct_field), false)) + } Self::Struct(fields) => Ok({ let fields = fields .iter() @@ -288,7 +308,10 @@ impl DataType { FixedSizeList(child_dtype, size) => { FixedSizeList(Box::new(child_dtype.to_physical()), *size) } - Map(child_dtype) => List(Box::new(child_dtype.to_physical())), + Map { key, value } => List(Box::new(Struct(vec![ + Field::new("key", key.to_physical()), + Field::new("value", value.to_physical()), + ]))), Embedding(dtype, size) => FixedSizeList(Box::new(dtype.to_physical()), *size), Image(mode) => Struct(vec![ Field::new( @@ -328,20 +351,6 @@ impl DataType { } } - #[inline] - pub fn nested_dtype(&self) -> Option<&Self> { - match self { - Self::Map(dtype) - | Self::List(dtype) - | Self::FixedSizeList(dtype, _) - | Self::FixedShapeTensor(dtype, _) - | Self::SparseTensor(dtype) - | Self::FixedShapeSparseTensor(dtype, _) - | Self::Tensor(dtype) => Some(dtype), - _ => None, - } - } - #[inline] pub fn is_arrow(&self) -> bool { self.to_arrow().is_ok() @@ -350,21 +359,21 @@ impl DataType { #[inline] pub fn is_numeric(&self) -> bool { match self { - Self::Int8 - | Self::Int16 - | Self::Int32 - | Self::Int64 - | Self::Int128 - | Self::UInt8 - | Self::UInt16 - | Self::UInt32 - | Self::UInt64 - // DataType::Float16 - | Self::Float32 - | Self::Float64 => true, - Self::Extension(_, inner, _) => inner.is_numeric(), - _ => false - } + Self::Int8 + | Self::Int16 + | Self::Int32 + | Self::Int64 + | Self::Int128 + | Self::UInt8 + | Self::UInt16 + | Self::UInt32 + | Self::UInt64 + // DataType::Float16 + | Self::Float32 + | Self::Float64 => true, + Self::Extension(_, inner, _) => inner.is_numeric(), + _ => false + } } #[inline] @@ -453,7 +462,7 @@ impl DataType { #[inline] pub fn is_map(&self) -> bool { - matches!(self, Self::Map(..)) + matches!(self, Self::Map { .. }) } #[inline] @@ -576,7 +585,7 @@ impl DataType { | Self::FixedShapeTensor(..) | Self::SparseTensor(..) | Self::FixedShapeSparseTensor(..) - | Self::Map(..) + | Self::Map { .. } ) } @@ -590,7 +599,7 @@ impl DataType { let p: Self = self.to_physical(); matches!( p, - Self::List(..) | Self::FixedSizeList(..) | Self::Struct(..) | Self::Map(..) + Self::List(..) | Self::FixedSizeList(..) | Self::Struct(..) | Self::Map { .. } ) } @@ -638,7 +647,29 @@ impl From<&ArrowType> for DataType { ArrowType::FixedSizeList(field, size) => { Self::FixedSizeList(Box::new(field.as_ref().data_type().into()), *size) } - ArrowType::Map(field, ..) => Self::Map(Box::new(field.as_ref().data_type().into())), + ArrowType::Map(field, ..) => { + // todo: TryFrom in future? want in second pass maybe + + // field should be a struct + let ArrowType::Struct(fields) = &field.data_type else { + panic!("Map should have a struct as its key") + }; + + let [key, value] = fields.as_slice() else { + panic!("Map should have two fields") + }; + + let key = &key.data_type; + let value = &value.data_type; + + let key = Self::from(key); + let value = Self::from(value); + + let key = Box::new(key); + let value = Box::new(value); + + Self::Map { key, value } + } ArrowType::Struct(fields) => { let fields: Vec = fields.iter().map(|fld| fld.into()).collect(); Self::Struct(fields) diff --git a/src/daft-schema/src/python/datatype.rs b/src/daft-schema/src/python/datatype.rs index ceff5e18f3..edacbfbdad 100644 --- a/src/daft-schema/src/python/datatype.rs +++ b/src/daft-schema/src/python/datatype.rs @@ -209,10 +209,10 @@ impl PyDataType { #[staticmethod] pub fn map(key_type: Self, value_type: Self) -> PyResult { - Ok(DataType::Map(Box::new(DataType::Struct(vec![ - Field::new("key", key_type.dtype), - Field::new("value", value_type.dtype), - ]))) + Ok(DataType::Map { + key: Box::new(key_type.dtype), + value: Box::new(value_type.dtype), + } .into()) } diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 2aefab9f96..9de1853696 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -937,7 +937,7 @@ impl SQLPlanner { invalid_operation_err!("Index must be a string literal") } } - DataType::Map(_) => Ok(daft_dsl::functions::map::get(expr, index)), + DataType::Map { .. } => Ok(daft_dsl::functions::map::get(expr, index)), dtype => { invalid_operation_err!("nested access on column with type: {}", dtype) } diff --git a/src/daft-stats/src/column_stats/mod.rs b/src/daft-stats/src/column_stats/mod.rs index df96daa373..d72ba7cb9c 100644 --- a/src/daft-stats/src/column_stats/mod.rs +++ b/src/daft-stats/src/column_stats/mod.rs @@ -71,7 +71,7 @@ impl ColumnRangeStatistics { // UNSUPPORTED TYPES: // Types that don't support comparisons and can't be used as ColumnRangeStatistics - DataType::List(..) | DataType::FixedSizeList(..) | DataType::Image(..) | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::SparseTensor(..) | DataType::FixedShapeSparseTensor(..) | DataType::FixedShapeTensor(..) | DataType::Struct(..) | DataType::Map(..) | DataType::Extension(..) | DataType::Embedding(..) | DataType::Unknown => false, + DataType::List(..) | DataType::FixedSizeList(..) | DataType::Image(..) | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::SparseTensor(..) | DataType::FixedShapeSparseTensor(..) | DataType::FixedShapeTensor(..) | DataType::Struct(..) | DataType::Map { .. } | DataType::Extension(..) | DataType::Embedding(..) | DataType::Unknown => false, #[cfg(feature = "python")] DataType::Python => false, } diff --git a/src/daft-table/src/repr_html.rs b/src/daft-table/src/repr_html.rs index 79ecaf063a..0e46bb80b2 100644 --- a/src/daft-table/src/repr_html.rs +++ b/src/daft-table/src/repr_html.rs @@ -102,7 +102,7 @@ pub fn html_value(s: &Series, idx: usize) -> String { let arr = s.struct_().unwrap(); arr.html_value(idx) } - DataType::Map(_) => { + DataType::Map { .. } => { let arr = s.map().unwrap(); arr.html_value(idx) } diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index d3727c2ac3..23e4674ba3 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -508,3 +508,61 @@ def test_repr_series_lit() -> None: s = lit(Series.from_pylist([1, 2, 3])) output = repr(s) assert output == "lit([1, 2, 3])" + + +def test_list_value_counts(): + # Create a MicroPartition with a list column + mp = MicroPartition.from_pydict({ + "list_col": [ + ["a", "b", "a", "c"], + ["b", "b", "c"], + ["a", "a", "a"], + [], + ["d", None, "d"] + ] + }) + + print("mp is ", mp) + + # mp = MicroPartition.from_pydict({ + # "list_col": [ + # ["a", "b", "a", "c"], + # ["b", "b", "c"], + # ["a", "a", "a"], + # [], + # ["d", "d"] + # ] + # }) + + # # Apply list_value_counts operation + result = mp.eval_expression_list([col("list_col").list.value_counts().alias("value_counts")]) + value_counts = result.to_pydict()["value_counts"] + + # Expected output + expected = [ + [("a", 2), ("b", 1), ("c", 1)], + [("b", 2), ("c", 1)], + [("a", 3)], + [], + [("d", 2)] + ] + + assert value_counts == expected + + # # Check the result + # value_counts = result.to_pydict()["value_counts"] + # print(value_counts) + # assert value_counts == expected + + # # Test with empty input + # empty_mp = MicroPartition.from_pydict({"list_col": []}) + # empty_result = empty_mp.eval_expression_list([col("list_col").list.value_counts().alias("value_counts")]) + # assert empty_result.to_pydict()["value_counts"] == [] + + # # Test with all None input + # none_mp = MicroPartition.from_pydict({"list_col": [None, None, None]}) + # none_result = none_mp.eval_expression_list([col("list_col").list.value_counts().alias("value_counts")]) + # assert none_result.to_pydict()["value_counts"] == [None, None, None] + + +