From 1f70823cfc71a3f66b8c6ebd6a7c15e0c0ea045e Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 1 Feb 2024 11:26:37 -0500 Subject: [PATCH 1/7] Update minimum rust version to 1.72 (#8997) * Update minimum rust version to 1.72 --- .github/workflows/rust.yml | 6 +++++- Cargo.toml | 2 +- benchmarks/Cargo.toml | 2 +- datafusion-cli/Cargo.toml | 3 ++- datafusion/core/Cargo.toml | 5 ++++- datafusion/expr/src/expr.rs | 11 +++-------- datafusion/physical-expr/src/aggregate/hyperloglog.rs | 8 +++----- datafusion/proto/Cargo.toml | 3 ++- datafusion/proto/gen/Cargo.toml | 2 +- datafusion/substrait/Cargo.toml | 3 ++- datafusion/wasmtest/Cargo.toml | 2 +- docs/Cargo.toml | 2 +- 12 files changed, 26 insertions(+), 23 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 501b05c25d8e..c94137ebd1f9 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -492,7 +492,11 @@ jobs: ./dev/update_config_docs.sh git diff --exit-code - # Verify MSRV for the crates which are directly used by other projects. 
+ # Verify MSRV for the crates which are directly used by other projects: + # - datafusion + # - datafusion-substrait + # - datafusion-proto + # - datafusion-cli msrv: name: Verify MSRV (Min Supported Rust Version) runs-on: ubuntu-latest diff --git a/Cargo.toml b/Cargo.toml index d56d37ad2b35..cccca0174113 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,7 @@ homepage = "https://github.com/apache/arrow-datafusion" license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/arrow-datafusion" -rust-version = "1.70" +rust-version = "1.72" version = "35.0.0" [workspace.dependencies] diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 50b79b4b0661..ced77c73f593 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -24,7 +24,7 @@ authors = ["Apache Arrow "] homepage = "https://github.com/apache/arrow-datafusion" repository = "https://github.com/apache/arrow-datafusion" license = "Apache-2.0" -rust-version = "1.70" +rust-version = { workspace = true } [features] ci = [] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 07ee65e3f6cd..79a1f0162e6a 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -25,7 +25,8 @@ keywords = ["arrow", "datafusion", "query", "sql"] license = "Apache-2.0" homepage = "https://github.com/apache/arrow-datafusion" repository = "https://github.com/apache/arrow-datafusion" -rust-version = "1.70" +# Specify MSRV here as `cargo msrv` doesn't support workspace version +rust-version = "1.72" readme = "README.md" [dependencies] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index f9a4c54b7dc6..2d795d0f8369 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -27,7 +27,10 @@ homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = "1.70" +# Specify MSRV here as `cargo msrv` doesn't support workspace version and fails with +# 
"Unable to find key 'package.rust-version' (or 'package.metadata.msrv') in 'arrow-datafusion/Cargo.toml'" +# https://github.com/foresterre/cargo-msrv/issues/590 +rust-version = "1.72" [lib] name = "datafusion" diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 9da1f4bb4df7..0000f3df033a 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -33,7 +33,7 @@ use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use std::collections::HashSet; use std::fmt; use std::fmt::{Display, Formatter, Write}; -use std::hash::{BuildHasher, Hash, Hasher}; +use std::hash::Hash; use std::str::FromStr; use std::sync::Arc; @@ -853,13 +853,8 @@ const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0); impl PartialOrd for Expr { fn partial_cmp(&self, other: &Self) -> Option { - let mut hasher = SEED.build_hasher(); - self.hash(&mut hasher); - let s = hasher.finish(); - - let mut hasher = SEED.build_hasher(); - other.hash(&mut hasher); - let o = hasher.finish(); + let s = SEED.hash_one(self); + let o = SEED.hash_one(other); Some(s.cmp(&o)) } diff --git a/datafusion/physical-expr/src/aggregate/hyperloglog.rs b/datafusion/physical-expr/src/aggregate/hyperloglog.rs index a0d55ca71db1..657a7b9f7f21 100644 --- a/datafusion/physical-expr/src/aggregate/hyperloglog.rs +++ b/datafusion/physical-expr/src/aggregate/hyperloglog.rs @@ -34,8 +34,8 @@ //! //! This module also borrows some code structure from [pdatastructs.rs](https://github.com/crepererum/pdatastructs.rs/blob/3997ed50f6b6871c9e53c4c5e0f48f431405fc63/src/hyperloglog.rs). -use ahash::{AHasher, RandomState}; -use std::hash::{BuildHasher, Hash, Hasher}; +use ahash::RandomState; +use std::hash::Hash; use std::marker::PhantomData; /// The greater is P, the smaller the error. @@ -102,9 +102,7 @@ where /// reasonable performance. 
#[inline] fn hash_value(&self, obj: &T) -> u64 { - let mut hasher: AHasher = SEED.build_hasher(); - obj.hash(&mut hasher); - hasher.finish() + SEED.hash_one(obj) } /// Adds an element to the HyperLogLog. diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 2eaf25198734..cdd464f38a76 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -26,7 +26,8 @@ homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = "1.70" +# Specify MSRV here as `cargo msrv` doesn't support workspace version +rust-version = "1.72" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index 8b3f3f98a8a1..c80bd50af287 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.64" +rust-version = "1.72" authors = ["Apache Arrow "] homepage = "https://github.com/apache/arrow-datafusion" repository = "https://github.com/apache/arrow-datafusion" diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index a4d18e0d35fd..38414fc5e67a 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -25,7 +25,8 @@ homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = "1.70" +# Specify MSRV here as `cargo msrv` doesn't support workspace version +rust-version = "1.72" [dependencies] async-recursion = "1.0" diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml index 91af15a6ea62..c47dcf83c84b 100644 --- a/datafusion/wasmtest/Cargo.toml +++ b/datafusion/wasmtest/Cargo.toml @@ -25,7 +25,7 @@ homepage = { workspace = true } repository = { workspace 
= true } license = { workspace = true } authors = { workspace = true } -rust-version = "1.70" +rust-version = { workspace = true } [lib] crate-type = ["cdylib", "rlib"] diff --git a/docs/Cargo.toml b/docs/Cargo.toml index 3a8c90cae085..7eecd11df80b 100644 --- a/docs/Cargo.toml +++ b/docs/Cargo.toml @@ -26,7 +26,7 @@ homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = "1.70" +rust-version = { workspace = true } [dependencies] datafusion = { path = "../datafusion/core", version = "35.0.0", default-features = false } From 402625f0353883d54aa68122a37b617179e908eb Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Thu, 1 Feb 2024 18:39:01 +0200 Subject: [PATCH 2/7] . (#9099) --- datafusion/physical-plan/src/filter.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 56a1b4e17821..362fa10efc9f 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -89,7 +89,7 @@ impl FilterExec { default_selectivity: u8, ) -> Result { if default_selectivity > 100 { - return plan_err!("Default flter selectivity needs to be less than 100"); + return plan_err!("Default filter selectivity needs to be less than 100"); } self.default_selectivity = default_selectivity; Ok(self) From a9fda5eb912ab4279ed5bcd6beed4d8f1ae70a6e Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 1 Feb 2024 11:39:53 -0500 Subject: [PATCH 3/7] Update InfluxDB links in Known Users (#9092) --- docs/source/user-guide/introduction.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 3229d0d50591..ae2684699726 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -104,7 +104,7 @@ Here are some active projects using DataFusion: - 
[GreptimeDB](https://github.com/GreptimeTeam/greptimedb) Open Source & Cloud Native Distributed Time Series Database - [GlareDB](https://github.com/GlareDB/glaredb) Fast SQL database for querying and analyzing distributed data. - [HoraeDB](https://github.com/apache/incubator-horaedb) Distributed Time-Series Database -- [InfluxDB IOx](https://github.com/influxdata/influxdb_iox) Time Series Database +- [InfluxDB](https://github.com/influxdata/influxdb) Time Series Database - [Kamu](https://github.com/kamu-data/kamu-cli/) Planet-scale streaming data pipeline - [LakeSoul](https://github.com/lakesoul-io/LakeSoul) Open source LakeHouse framework with native IO in Rust. - [Lance](https://github.com/lancedb/lance) Modern columnar data format for ML From 038763c7db0961e183b481f0d890abe8da008562 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Fri, 2 Feb 2024 00:42:41 +0800 Subject: [PATCH 4/7] Support `FixedSizeList` type coercion (#8902) * Add support for parsing FixedSizeList type * support cast fixedsizelist from list * support FixedSizeList type coercion * fix conflict * add test for [] * support fixedsizelist for functions with array-element or element-array argument * add tests for array_element * add tests for array_remove * add tests for array_remove_n * add tests for array_has * fix comment * fix comment * add tests for array_positions * test chore * test chore * remove useless logic * refactor coerced_type_with_base_type_only function * add null test for array_has * refactor: put fixedsizelist in coerce_arguments_for_fun * chore * refactor type signature * add array_and_index function * put all type coercion in coerce_arguments_for_signature * add comment Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- datafusion/common/src/scalar.rs | 5 +- datafusion/common/src/utils.rs | 54 +- datafusion/expr/src/built_in_function.rs | 46 +- datafusion/expr/src/signature.rs | 57 +- .../expr/src/type_coercion/functions.rs | 62 +- 
.../optimizer/src/analyzer/type_coercion.rs | 1 - datafusion/sqllogictest/test_files/array.slt | 604 +++++++++++++++++- .../sqllogictest/test_files/arrow_typeof.slt | 2 +- 8 files changed, 741 insertions(+), 90 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 2f9e374bd7f4..36b00b65e285 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -1483,7 +1483,9 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) => { build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano) } - DataType::List(_) | DataType::LargeList(_) => build_list_array(scalars)?, + DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) => build_list_array(scalars)?, DataType::Struct(fields) => { // Initialize a Vector to store the ScalarValues for each column let mut columns: Vec> = @@ -1595,7 +1597,6 @@ impl ScalarValue { | DataType::Time64(TimeUnit::Second) | DataType::Time64(TimeUnit::Millisecond) | DataType::Duration(_) - | DataType::FixedSizeList(_, _) | DataType::Union(_, _) | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) => { diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index d21bd464f850..a12b71c17dcf 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -440,9 +440,9 @@ pub fn arrays_into_list_array( /// ``` pub fn base_type(data_type: &DataType) -> DataType { match data_type { - DataType::List(field) | DataType::LargeList(field) => { - base_type(field.data_type()) - } + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => base_type(field.data_type()), _ => data_type.to_owned(), } } @@ -464,31 +464,23 @@ pub fn coerced_type_with_base_type_only( base_type: &DataType, ) -> DataType { match data_type { - DataType::List(field) => { - let data_type = match field.data_type() { - DataType::List(_) => { - coerced_type_with_base_type_only(field.data_type(), 
base_type) - } - _ => base_type.to_owned(), - }; + DataType::List(field) | DataType::FixedSizeList(field, _) => { + let field_type = + coerced_type_with_base_type_only(field.data_type(), base_type); DataType::List(Arc::new(Field::new( field.name(), - data_type, + field_type, field.is_nullable(), ))) } DataType::LargeList(field) => { - let data_type = match field.data_type() { - DataType::LargeList(_) => { - coerced_type_with_base_type_only(field.data_type(), base_type) - } - _ => base_type.to_owned(), - }; + let field_type = + coerced_type_with_base_type_only(field.data_type(), base_type); DataType::LargeList(Arc::new(Field::new( field.name(), - data_type, + field_type, field.is_nullable(), ))) } @@ -497,6 +489,32 @@ pub fn coerced_type_with_base_type_only( } } +/// Recursively coerce and `FixedSizeList` elements to `List` +pub fn coerced_fixed_size_list_to_list(data_type: &DataType) -> DataType { + match data_type { + DataType::List(field) | DataType::FixedSizeList(field, _) => { + let field_type = coerced_fixed_size_list_to_list(field.data_type()); + + DataType::List(Arc::new(Field::new( + field.name(), + field_type, + field.is_nullable(), + ))) + } + DataType::LargeList(field) => { + let field_type = coerced_fixed_size_list_to_list(field.data_type()); + + DataType::LargeList(Arc::new(Field::new( + field.name(), + field_type, + field.is_nullable(), + ))) + } + + _ => data_type.clone(), + } +} + /// Compute the number of dimensions in a list data type. 
pub fn list_ndims(data_type: &DataType) -> u64 { match data_type { diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index b7bb17c86be7..20b7df46e387 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -599,10 +599,11 @@ impl BuiltinScalarFunction { } BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] { - List(field) => Ok(field.data_type().clone()), - LargeList(field) => Ok(field.data_type().clone()), + List(field) + | LargeList(field) + | FixedSizeList(field, _) => Ok(field.data_type().clone()), _ => plan_err!( - "The {self} function can only accept list or largelist as the first argument" + "The {self} function can only accept List, LargeList or FixedSizeList as the first argument" ), }, BuiltinScalarFunction::ArrayLength => Ok(UInt64), @@ -922,10 +923,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArraySort => { Signature::variadic_any(self.volatility()) } - BuiltinScalarFunction::ArrayAppend => Signature { - type_signature: ArrayAndElement, - volatility: self.volatility(), - }, + BuiltinScalarFunction::ArrayAppend => { + Signature::array_and_element(self.volatility()) + } BuiltinScalarFunction::MakeArray => { // 0 or more arguments of arbitrary type Signature::one_of(vec![VariadicEqual, Any(0)], self.volatility()) @@ -937,12 +937,17 @@ impl BuiltinScalarFunction { } BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayEmpty => Signature::any(1, self.volatility()), - BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayElement => { + Signature::array_and_index(self.volatility()) + } BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()), BuiltinScalarFunction::Flatten => Signature::any(1, self.volatility()), - BuiltinScalarFunction::ArrayHasAll - 
| BuiltinScalarFunction::ArrayHasAny - | BuiltinScalarFunction::ArrayHas => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayHasAll | BuiltinScalarFunction::ArrayHasAny => { + Signature::any(2, self.volatility()) + } + BuiltinScalarFunction::ArrayHas => { + Signature::array_and_element(self.volatility()) + } BuiltinScalarFunction::ArrayLength => { Signature::variadic_any(self.volatility()) } @@ -951,15 +956,20 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayPosition => { Signature::variadic_any(self.volatility()) } - BuiltinScalarFunction::ArrayPositions => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayPrepend => Signature { - type_signature: ElementAndArray, - volatility: self.volatility(), - }, + BuiltinScalarFunction::ArrayPositions => { + Signature::array_and_element(self.volatility()) + } + BuiltinScalarFunction::ArrayPrepend => { + Signature::element_and_array(self.volatility()) + } BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayRemove => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayRemove => { + Signature::array_and_element(self.volatility()) + } BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()), - BuiltinScalarFunction::ArrayRemoveAll => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayRemoveAll => { + Signature::array_and_element(self.volatility()) + } BuiltinScalarFunction::ArrayReplace => Signature::any(3, self.volatility()), BuiltinScalarFunction::ArrayReplaceN => Signature::any(4, self.volatility()), BuiltinScalarFunction::ArrayReplaceAll => { diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 729131bd95e1..48f4c996cb5d 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -116,6 +116,12 @@ pub enum TypeSignature { /// Function `make_array` takes 0 or more arguments with arbitrary types, its `TypeSignature` 
/// is `OneOf(vec![Any(0), VariadicAny])`. OneOf(Vec), + /// Specifies Signatures for array functions + ArraySignature(ArrayFunctionSignature), +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ArrayFunctionSignature { /// Specialized Signature for ArrayAppend and similar functions /// The first argument should be List/LargeList, and the second argument should be non-list or list. /// The second argument's list dimension should be one dimension less than the first argument's list dimension. @@ -126,6 +132,23 @@ pub enum TypeSignature { /// The first argument should be non-list or list, and the second argument should be List/LargeList. /// The first argument's list dimension should be one dimension less than the second argument's list dimension. ElementAndArray, + ArrayAndIndex, +} + +impl std::fmt::Display for ArrayFunctionSignature { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ArrayFunctionSignature::ArrayAndElement => { + write!(f, "array, element") + } + ArrayFunctionSignature::ElementAndArray => { + write!(f, "element, array") + } + ArrayFunctionSignature::ArrayAndIndex => { + write!(f, "array, index") + } + } + } } impl TypeSignature { @@ -156,11 +179,8 @@ impl TypeSignature { TypeSignature::OneOf(sigs) => { sigs.iter().flat_map(|s| s.to_string_repr()).collect() } - TypeSignature::ArrayAndElement => { - vec!["ArrayAndElement(List, T)".to_string()] - } - TypeSignature::ElementAndArray => { - vec!["ElementAndArray(T, List)".to_string()] + TypeSignature::ArraySignature(array_signature) => { + vec![array_signature.to_string()] } } } @@ -263,6 +283,33 @@ impl Signature { volatility, } } + /// Specialized Signature for ArrayAppend and similar functions + pub fn array_and_element(volatility: Volatility) -> Self { + Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::ArrayAndElement, + ), + volatility, + } + } + /// Specialized Signature for ArrayPrepend and similar functions + 
pub fn element_and_array(volatility: Volatility) -> Self { + Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::ElementAndArray, + ), + volatility, + } + } + /// Specialized Signature for ArrayElement and similar functions + pub fn array_and_index(volatility: Volatility) -> Self { + Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::ArrayAndIndex, + ), + volatility, + } + } } /// Monotonicity of the `ScalarFunctionExpr` with respect to its arguments. diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 63908d539bd0..806fdaaa5246 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -15,14 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::signature::TIMEZONE_WILDCARD; +use crate::signature::{ArrayFunctionSignature, TIMEZONE_WILDCARD}; use crate::{Signature, TypeSignature}; use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::utils::list_ndims; -use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; +use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims}; +use datafusion_common::{ + internal_datafusion_err, internal_err, plan_err, DataFusionError, Result, +}; use super::binary::comparison_coercion; @@ -48,7 +50,6 @@ pub fn data_types( ); } } - let valid_types = get_valid_types(&signature.type_signature, current_types)?; if valid_types @@ -104,12 +105,11 @@ fn get_valid_types( let elem_base_type = datafusion_common::utils::base_type(elem_type); let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); - if new_base_type.is_none() { - return internal_err!( + let new_base_type = new_base_type.ok_or_else(|| { + internal_datafusion_err!( "Coercion from {array_base_type:?} to {elem_base_type:?} not supported." 
- ); - } - let new_base_type = new_base_type.unwrap(); + ) + })?; let array_type = datafusion_common::utils::coerced_type_with_base_type_only( array_type, @@ -117,10 +117,12 @@ fn get_valid_types( ); match array_type { - DataType::List(ref field) | DataType::LargeList(ref field) => { + DataType::List(ref field) + | DataType::LargeList(ref field) + | DataType::FixedSizeList(ref field, _) => { let elem_type = field.data_type(); if is_append { - Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]) + Ok(vec![vec![array_type.clone(), elem_type.clone()]]) } else { Ok(vec![vec![elem_type.to_owned(), array_type.clone()]]) } @@ -128,6 +130,23 @@ fn get_valid_types( _ => Ok(vec![vec![]]), } } + fn array_and_index(current_types: &[DataType]) -> Result>> { + if current_types.len() != 2 { + return Ok(vec![vec![]]); + } + + let array_type = ¤t_types[0]; + + match array_type { + DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) => { + let array_type = coerced_fixed_size_list_to_list(array_type); + Ok(vec![vec![array_type, DataType::Int64]]) + } + _ => Ok(vec![vec![]]), + } + } let valid_types = match signature { TypeSignature::Variadic(valid_types) => valid_types .iter() @@ -160,12 +179,19 @@ fn get_valid_types( } TypeSignature::Exact(valid_types) => vec![valid_types.clone()], - TypeSignature::ArrayAndElement => { - return array_append_or_prepend_valid_types(current_types, true) - } - TypeSignature::ElementAndArray => { - return array_append_or_prepend_valid_types(current_types, false) - } + TypeSignature::ArraySignature(ref function_signature) => match function_signature + { + ArrayFunctionSignature::ArrayAndElement => { + return array_append_or_prepend_valid_types(current_types, true) + } + ArrayFunctionSignature::ArrayAndIndex => { + return array_and_index(current_types) + } + ArrayFunctionSignature::ElementAndArray => { + return array_append_or_prepend_valid_types(current_types, false) + } + }, + TypeSignature::Any(number) => { if 
current_types.len() != *number { return plan_err!( @@ -311,6 +337,8 @@ fn coerced_from<'a>( Utf8 | LargeUtf8 => Some(type_into.clone()), Null if can_cast_types(type_from, type_into) => Some(type_into.clone()), + List(_) if matches!(type_from, FixedSizeList(_, _)) => Some(type_into.clone()), + // Only accept list and largelist with the same number of dimensions unless the type is Null. // List or LargeList with different dimensions should be handled in TypeSignature or other places before this. List(_) | LargeList(_) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 8710249e1294..d804edb0c52f 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -589,7 +589,6 @@ fn coerce_arguments_for_fun( if expressions.is_empty() { return Ok(vec![]); } - let mut expressions: Vec = expressions.to_vec(); // Cast Fixedsizelist to List for array functions diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index e6a8181be1ac..4fdc428d7a9c 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -77,6 +77,19 @@ AS FROM arrays ; +#TODO: create FixedSizeList with NULL column +statement ok +CREATE TABLE fixed_size_arrays +AS VALUES + (arrow_cast(make_array(make_array(NULL, 2),make_array(3, NULL)), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(1.1, 2.2, 3.3), 'FixedSizeList(3, Float64)'), arrow_cast(make_array('L', 'o', 'r', 'e', 'm'), 'FixedSizeList(5, Utf8)')), + (arrow_cast(make_array(make_array(3, 4),make_array(5, 6)), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(NULL, 5.5, 6.6), 'FixedSizeList(3, Float64)'), arrow_cast(make_array('i', 'p', NULL, 'u', 'm'), 'FixedSizeList(5, Utf8)')), + (arrow_cast(make_array(make_array(5, 6),make_array(7, 8)), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(7.7, 8.8, 9.9), 
'FixedSizeList(3, Float64)'), arrow_cast(make_array('d', NULL, 'l', 'o', 'r'), 'FixedSizeList(5, Utf8)')), + (arrow_cast(make_array(make_array(7, NULL),make_array(9, 10)), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(10.1, NULL, 12.2), 'FixedSizeList(3, Float64)'), arrow_cast(make_array('s', 'i', 't', 'a', 'b'), 'FixedSizeList(5, Utf8)')), + (arrow_cast(make_array(make_array(7, NULL),make_array(9, 10)), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(13.3, 14.4, 15.5), 'FixedSizeList(3, Float64)'), arrow_cast(make_array('a', 'm', 'e', 't', 'x'), 'FixedSizeList(5, Utf8)')), + (arrow_cast(make_array(make_array(11, 12),make_array(13, 14)), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(13.3, 14.4, 15.5), 'FixedSizeList(3, Float64)'), arrow_cast(make_array(',','a','b','c','d'), 'FixedSizeList(5, Utf8)')), + (arrow_cast(make_array(make_array(15, 16),make_array(NULL, 18)), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(16.6, 17.7, 18.8), 'FixedSizeList(3, Float64)'), arrow_cast(make_array(',','a','b','c','d'), 'FixedSizeList(5, Utf8)')) +; + statement ok CREATE TABLE slices AS VALUES @@ -89,6 +102,17 @@ AS VALUES (make_array(51, 52, NULL, 54, 55, 56, 57, 58, 59, 60), 5, NULL) ; +statement ok +CREATE TABLE fixed_slices +AS VALUES + (arrow_cast(make_array(NULL, 2, 3, 4, 5, 6, 7, 8, 9, 10), 'FixedSizeList(10, Int64)'), 1, 1), + (arrow_cast(make_array(11, 12, 13, 14, 15, 16, 17, 18, NULL, 20), 'FixedSizeList(10, Int64)'), 2, -4), + (arrow_cast(make_array(21, 22, 23, NULL, 25, 26, 27, 28, 29, 30), 'FixedSizeList(10, Int64)'), 0, 0), + (arrow_cast(make_array(31, 32, 33, 34, 35, NULL, 37, 38, 39, 40), 'FixedSizeList(10, Int64)'), -4, -7), + (arrow_cast(make_array(41, 42, 43, 44, 45, 46, 47, 48, 49, 50), 'FixedSizeList(10, Int64)'), NULL, 6), + (arrow_cast(make_array(51, 52, NULL, 54, 55, 56, 57, 58, 59, 60),'FixedSizeList(10, Int64)'), 5, NULL) +; + statement ok CREATE TABLE arrayspop AS VALUES @@ -119,6 +143,13 @@ AS FROM nested_arrays ; 
+statement ok +CREATE TABLE fixed_size_nested_arrays +AS VALUES + (arrow_cast(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6)), 'FixedSizeList(6, List(Int64))'), arrow_cast(make_array(7, 8, 9), 'FixedSizeList(3, Int64)'), 2, arrow_cast(make_array([[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array(11, 12, 13), 'FixedSizeList(3, Int64)')), + (arrow_cast(make_array(make_array(4, 5, 6), make_array(10, 11, 12), make_array(4, 9, 8), make_array(7, 8, 9), make_array(10, 11, 12), make_array(1, 8, 7)), 'FixedSizeList(6, List(Int64))'), arrow_cast(make_array(10, 11, 12), 'FixedSizeList(3, Int64)'), 3, arrow_cast(make_array([[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array(121, 131, 141), 'FixedSizeList(3, Int64)')) +; + statement ok CREATE TABLE arrays_values AS VALUES @@ -178,6 +209,13 @@ AS VALUES (make_array(3, 4, 5), 2, make_array(1,2,3,4), make_array(2,5), make_array(2,4,6), make_array(1,3,5)) ; +statement ok +CREATE TABLE fixed_size_array_has_table_1D +AS VALUES + (arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 1, arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)'), arrow_cast(make_array(1,3), 'FixedSizeList(2, Int64)'), arrow_cast(make_array(1,3,5), 'FixedSizeList(3, Int64)'), arrow_cast(make_array(2, 4, 6, 8, 1, 3, 5), 'FixedSizeList(7, Int64)')), + (arrow_cast(make_array(3, 4, 5), 'FixedSizeList(3, Int64)'), 2, arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)'), arrow_cast(make_array(2,5), 'FixedSizeList(2, Int64)'), arrow_cast(make_array(2,4,6), 'FixedSizeList(3, Int64)'), arrow_cast(make_array(1, 3, 5, 7, 9, 11, 13), 'FixedSizeList(7, Int64)')) +; + statement ok CREATE TABLE array_has_table_1D_Float AS VALUES @@ -185,6 +223,13 @@ AS VALUES (make_array(3.0, 4.0, 5.0), 2.0, make_array(1.0,2.0,3.0,4.0), 
make_array(2.0,5.0), make_array(2.22, 1.11), make_array(1.11, 3.33)) ; +statement ok +CREATE TABLE fixed_size_array_has_table_1D_Float +AS VALUES + (arrow_cast(make_array(1.0, 2.0, 3.0), 'FixedSizeList(3, Float64)'), 1.0, arrow_cast(make_array(1.0, 2.0, 3.0, 4.0), 'FixedSizeList(4, Float64)'), arrow_cast(make_array(1.0,3.0), 'FixedSizeList(2, Float64)'), arrow_cast(make_array(1.11, 2.22), 'FixedSizeList(2, Float64)'), arrow_cast(make_array(2.22, 3.33), 'FixedSizeList(2, Float64)')), + (arrow_cast(make_array(3.0, 4.0, 5.0), 'FixedSizeList(3, Float64)'), 2.0, arrow_cast(make_array(1.0, 2.0, 3.0, 4.0), 'FixedSizeList(4, Float64)'), arrow_cast(make_array(2.0,5.0), 'FixedSizeList(2, Float64)'), arrow_cast(make_array(2.22, 1.11), 'FixedSizeList(2, Float64)'), arrow_cast(make_array(1.11, 3.33), 'FixedSizeList(2, Float64)')) +; + statement ok CREATE TABLE array_has_table_1D_Boolean AS VALUES @@ -192,6 +237,13 @@ AS VALUES (make_array(false, false, false), false, make_array(true, false, true), make_array(true, true), make_array(true, true), make_array(false,false,true)) ; +statement ok +CREATE TABLE fixed_size_array_has_table_1D_Boolean +AS VALUES + (arrow_cast(make_array(true, true, true), 'FixedSizeList(3, Boolean)'), false, arrow_cast(make_array(true, true, false, true, false), 'FixedSizeList(5, Boolean)'), arrow_cast(make_array(true, false, true), 'FixedSizeList(3, Boolean)'), arrow_cast(make_array(false, true), 'FixedSizeList(2, Boolean)'), arrow_cast(make_array(true, false, true), 'FixedSizeList(3, Boolean)')), + (arrow_cast(make_array(false, false, false), 'FixedSizeList(3, Boolean)'), false, arrow_cast(make_array(true, false, true, true, false), 'FixedSizeList(5, Boolean)'), arrow_cast(make_array(true, true, false), 'FixedSizeList(3, Boolean)'), arrow_cast(make_array(true, true), 'FixedSizeList(2, Boolean)'), arrow_cast(make_array(false,false,true), 'FixedSizeList(3, Boolean)')) +; + statement ok CREATE TABLE array_has_table_1D_UTF8 AS VALUES @@ -199,6 +251,13 @@ AS 
VALUES (make_array('a', 'bc', 'def'), 'defg', make_array('datafusion', 'rust', 'arrow'), make_array('datafusion', 'rust', 'arrow', 'python'), make_array('rust', 'arrow'), make_array('datafusion', 'rust', 'arrow')) ; +statement ok +CREATE TABLE fixed_size_array_has_table_1D_UTF8 +AS VALUES + (arrow_cast(make_array('a', 'bc', 'def'), 'FixedSizeList(3, Utf8)'), 'bc', arrow_cast(make_array('datafusion', 'rust', 'arrow'), 'FixedSizeList(3, Utf8)'), arrow_cast(make_array('rust', 'arrow', 'datafusion', 'rust'), 'FixedSizeList(4, Utf8)'), arrow_cast(make_array('rust', 'arrow', 'python'), 'FixedSizeList(3, Utf8)'), arrow_cast(make_array('data', 'fusion', 'rust'), 'FixedSizeList(3, Utf8)')), + (arrow_cast(make_array('a', 'bc', 'def'), 'FixedSizeList(3, Utf8)'), 'defg', arrow_cast(make_array('datafusion', 'rust', 'arrow'), 'FixedSizeList(3, Utf8)'), arrow_cast(make_array('datafusion', 'rust', 'arrow', 'python'), 'FixedSizeList(4, Utf8)'), arrow_cast(make_array('rust', 'arrow', 'python'), 'FixedSizeList(3, Utf8)'), arrow_cast(make_array('datafusion', 'rust', 'arrow'), 'FixedSizeList(3, Utf8)')) +; + statement ok CREATE TABLE array_has_table_2D AS VALUES @@ -206,6 +265,13 @@ AS VALUES (make_array([3,4], [5]), make_array(5), make_array([1,2,3,4], [5,6,7], [8,9,10]), make_array([1,2,3], [5,6,7], [8,9,10])) ; +statement ok +CREATE TABLE fixed_size_array_has_table_2D +AS VALUES + (arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(1,3), 'FixedSizeList(2, Int64)'), arrow_cast(make_array([1,2,3], [4,5], [6,7]), 'FixedSizeList(3, List(Int64))'), arrow_cast(make_array([4,5], [6,7], [1,2]), 'FixedSizeList(3, List(Int64))')), + (arrow_cast(make_array([3,4], [5]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(5, 3), 'FixedSizeList(2, Int64)'), arrow_cast(make_array([1,2,3,4], [5,6,7], [8,9,10]), 'FixedSizeList(3, List(Int64))'), arrow_cast(make_array([1,2,3], [5,6,7], [8,9,10]), 'FixedSizeList(3, List(Int64))')) +; + statement ok CREATE 
TABLE array_has_table_2D_float AS VALUES @@ -213,6 +279,13 @@ AS VALUES (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.0], [1.1, 2.2], [3.3])) ; +statement ok +CREATE TABLE fixed_size_array_has_table_2D_Float +AS VALUES + (arrow_cast(make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), 'FixedSizeList(3, List(Float64))'), arrow_cast(make_array([1.1, 2.2], [3.3], [4.4]), 'FixedSizeList(3, List(Float64))')), + (arrow_cast(make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), 'FixedSizeList(3, List(Float64))'), arrow_cast(make_array([1.0], [1.1, 2.2], [3.3]), 'FixedSizeList(3, List(Float64))')) +; + statement ok CREATE TABLE array_has_table_3D AS VALUES @@ -225,6 +298,18 @@ AS VALUES (make_array([[1], [2]], [[2], [3]]), make_array([1], [2])) ; +statement ok +CREATE TABLE fixed_size_array_has_table_3D +AS VALUES + (arrow_cast(make_array([[1,2]], [[3, 4]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1], [2]), 'FixedSizeList(2, List(Int64))')), + (arrow_cast(make_array([[1,2]], [[4, 4]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1,2], [3, 4]), 'FixedSizeList(2, List(Int64))')), + (arrow_cast(make_array([[1,2]], [[4, 4]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1,2,3], [1]), 'FixedSizeList(2, List(Int64))')), + (arrow_cast(make_array([[1], [2]], []), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([2], [3]), 'FixedSizeList(2, List(Int64))')), + (arrow_cast(make_array([[1], [2]], []), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1], [2]), 'FixedSizeList(2, List(Int64))')), + (arrow_cast(make_array([[1], [2]], [[2], [3]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1], [2]), 'FixedSizeList(2, List(Int64))')), + (arrow_cast(make_array([[1], [2]], [[2], [3]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1], [2]), 'FixedSizeList(2, List(Int64))')) +; + statement ok CREATE TABLE array_distinct_table_1D AS VALUES @@ -407,6 +492,15 @@ AS 
SELECT FROM arrays_values_without_nulls ; +statement ok +CREATE TABLE fixed_size_arrays_values_without_nulls +AS VALUES + (arrow_cast(make_array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), 'FixedSizeList(10, Int64)'), 1, 1, ',', [2,3]), + (arrow_cast(make_array(11, 12, 13, 14, 15, 16, 17, 18, 19, 20), 'FixedSizeList(10, Int64)'), 12, 2, '.', [4,5]), + (arrow_cast(make_array(21, 22, 23, 24, 25, 26, 27, 28, 29, 30), 'FixedSizeList(10, Int64)'), 23, 3, '-', [6,7]), + (arrow_cast(make_array(31, 32, 33, 34, 35, 26, 37, 38, 39, 40), 'FixedSizeList(10, Int64)'), 34, 4, 'ok', [8,9]) +; + statement ok CREATE TABLE arrays_range AS VALUES @@ -434,6 +528,15 @@ AS FROM arrays_with_repeating_elements ; +statement ok +CREATE TABLE fixed_arrays_with_repeating_elements +AS VALUES + (arrow_cast(make_array(1, 2, 1, 3, 2, 2, 1, 3, 2, 3), 'FixedSizeList(10, Int64)'), 2, 4, 3), + (arrow_cast(make_array(4, 4, 5, 5, 6, 5, 5, 5, 4, 4), 'FixedSizeList(10, Int64)'), 4, 7, 2), + (arrow_cast(make_array(7, 7, 7, 8, 7, 9, 7, 8, 7, 7), 'FixedSizeList(10, Int64)'), 7, 10, 5), + (arrow_cast(make_array(10, 11, 12, 10, 11, 12, 10, 11, 12, 10), 'FixedSizeList(10, Int64)'), 10, 13, 10) +; + statement ok CREATE TABLE nested_arrays_with_repeating_elements AS VALUES @@ -454,6 +557,15 @@ AS FROM nested_arrays_with_repeating_elements ; +statement ok +CREATE TABLE fixed_size_nested_arrays_with_repeating_elements +AS VALUES + (arrow_cast(make_array([1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]), 'FixedSizeList(10, List(Int64))'), [4, 5, 6], [10, 11, 12], 3), + (arrow_cast(make_array([10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]), 'FixedSizeList(10, List(Int64))'), [10, 11, 12], [19, 20, 21], 2), + (arrow_cast(make_array([19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]), 
'FixedSizeList(10, List(Int64))'), [19, 20, 21], [28, 29, 30], 5), + (arrow_cast(make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), 'FixedSizeList(10, List(Int64))'), [28, 29, 30], [28, 29, 30], 10) +; + # Array literal ## boolean coercion is not supported @@ -941,9 +1053,17 @@ from arrays_values_without_nulls; ## array_element (aliases: array_extract, list_extract, list_element) # array_element error -query error DataFusion error: Error during planning: The array_element function can only accept list or largelist as the first argument +query error DataFusion error: Error during planning: No function matches the given name and argument types 'array_element\(Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tarray_element\(array, index\) select array_element(1, 2); +# array_element with null +query I +select array_element([1, 2], NULL); +---- +NULL + +query error +select array_element(NULL, 2); # array_element scalar function #1 (with positive index) query IT @@ -956,6 +1076,11 @@ select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), ---- 2 l +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)'), 2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'FixedSizeList(5, Utf8)'), 3); +---- +2 l + # array_element scalar function #2 (with positive index; out of bounds) query IT select array_element(make_array(1, 2, 3, 4, 5), 7), array_element(make_array('h', 'e', 'l', 'l', 'o'), 11); @@ -967,6 +1092,11 @@ select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)'), 7), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'FixedSizeList(5, Utf8)'), 11); +---- +NULL NULL + # array_element scalar function #3 (with zero) query IT select 
array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h', 'e', 'l', 'l', 'o'), 0); @@ -978,12 +1108,26 @@ select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)'), 0), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'FixedSizeList(5, Utf8)'), 0); +---- +NULL NULL + # array_element scalar function #4 (with NULL) -query error +query IT select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL); +---- +NULL NULL -query error +query IT select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL); +---- +NULL NULL + +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)'), NULL), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'FixedSizeList(5, Utf8)'), NULL); +---- +NULL NULL # array_element scalar function #5 (with negative index) query IT @@ -996,6 +1140,11 @@ select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), ---- 4 l +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)'), -2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'FixedSizeList(5, Utf8)'), -3); +---- +4 l + # array_element scalar function #6 (with negative index; out of bounds) query IT select array_element(make_array(1, 2, 3, 4, 5), -11), array_element(make_array('h', 'e', 'l', 'l', 'o'), -7); @@ -1007,6 +1156,11 @@ select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)'), -11), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'FixedSizeList(5, Utf8)'), -7); +---- +NULL NULL + # array_element 
scalar function #7 (nested array) query ? select array_element(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1); @@ -1018,6 +1172,11 @@ select array_element(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array ---- [1, 2, 3, 4, 5] +query ? +select array_element(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'FixedSizeList(2, List(Int64))'), 1); +---- +[1, 2, 3, 4, 5] + # array_extract scalar function #8 (function alias `array_element`) query IT select array_extract(make_array(1, 2, 3, 4, 5), 2), array_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); @@ -1029,6 +1188,11 @@ select array_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), ---- 2 l +query IT +select array_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'FixedSizeList(5, Utf8)'), 3); +---- +2 l + # list_element scalar function #9 (function alias `array_element`) query IT select list_element(make_array(1, 2, 3, 4, 5), 2), list_element(make_array('h', 'e', 'l', 'l', 'o'), 3); @@ -1040,6 +1204,11 @@ select list_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2 ---- 2 l +query IT +select list_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)'), 2), list_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'FixedSizeList(5, Utf8)'), 3); +---- +2 l + # list_extract scalar function #10 (function alias `array_element`) query IT select list_extract(make_array(1, 2, 3, 4, 5), 2), list_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); @@ -1047,7 +1216,12 @@ select list_extract(make_array(1, 2, 3, 4, 5), 2), list_extract(make_array('h', 2 l query IT -select list_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +select list_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), 
list_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + +query IT +select list_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)'), 2), list_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'FixedSizeList(5, Utf8)'), 3); ---- 2 l @@ -1074,6 +1248,16 @@ NULL NULL 55 +query I +select array_element(column1, column2) from fixed_slices; +---- +NULL +12 +NULL +37 +NULL +55 + # array_element with columns and scalars query II select array_element(make_array(1, 2, 3, 4, 5), column2), array_element(column1, 3) from slices; @@ -1097,6 +1281,16 @@ NULL 23 NULL 43 5 NULL +query II +select array_element(make_array(1, 2, 3, 4, 5), column2), array_element(column1, 3) from fixed_slices; +---- +1 3 +2 13 +NULL 23 +2 33 +NULL 43 +5 NULL + ## array_pop_back (aliases: `list_pop_back`) # array_pop_back scalar function #1 @@ -2354,6 +2548,12 @@ select array_concat(make_array(column3), column1, column2) from arrays_values_v2 ## array_position (aliases: `list_position`, `array_indexof`, `list_indexof`) +## array_position with NULL (follow PostgreSQL) +#query II +#select array_position([1, 2, 3, 4, 5], null), array_position(NULL, 1); +#---- +#NULL NULL + # array_position scalar function #1 query III select array_position(['h', 'e', 'l', 'l', 'o'], 'l'), array_position([1, 2, 3, 4, 5], 5), array_position([1, 1, 1], 1); @@ -2488,6 +2688,12 @@ NULL 1 NULL ## array_positions (aliases: `list_positions`) +# array_positions with NULL (follow PostgreSQL) +query ? +select array_positions([1, 2, 3, 4, 5], null); +---- +[] + # array_positions scalar function #1 query ??? select array_positions(['h', 'e', 'l', 'l', 'o'], 'l'), array_positions([1, 2, 3, 4, 5], 5), array_positions([1, 1, 1], 1); @@ -2499,6 +2705,11 @@ select array_positions(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ---- [3, 4] [5] [1, 2, 3] +query ??? 
+select array_positions(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'FixedSizeList(5, Utf8)'), 'l'), array_positions(arrow_cast([1, 2, 3, 4, 5], 'FixedSizeList(5, Int64)'), 5), array_positions(arrow_cast([1, 1, 1], 'FixedSizeList(3, Int64)'), 1); +---- +[3, 4] [5] [1, 2, 3] + # array_positions scalar function #2 (element is list) query ? select array_positions(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2, 1, 3], [4, 5, 6]), [2, 1, 3]); @@ -2510,6 +2721,11 @@ select array_positions(arrow_cast(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2 ---- [2, 4] +query ? +select array_positions(arrow_cast(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2, 1, 3], [4, 5, 6]), 'FixedSizeList(5, List(Int64))'), [2, 1, 3]); +---- +[2, 4] + # list_positions scalar function #3 (function alias `array_positions`) query ??? select list_positions(['h', 'e', 'l', 'l', 'o'], 'l'), list_positions([1, 2, 3, 4, 5], 5), list_positions([1, 1, 1], 1); @@ -2521,6 +2737,13 @@ select list_positions(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ---- [3, 4] [5] [1, 2, 3] +query ??? +select list_positions(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'FixedSizeList(5, Utf8)'), 'l'), + list_positions(arrow_cast([1, 2, 3, 4, 5], 'FixedSizeList(5, Int64)'), 5), + list_positions(arrow_cast([1, 1, 1], 'FixedSizeList(3, Int64)'), 1); +---- +[3, 4] [5] [1, 2, 3] + # array_positions with columns #1 query ? select array_positions(column1, column2) from arrays_values_without_nulls; @@ -2538,6 +2761,14 @@ select array_positions(arrow_cast(column1, 'LargeList(Int64)'), column2) from ar [3] [4] +query ? +select array_positions(arrow_cast(column1, 'LargeList(Int64)'), column2) from fixed_size_arrays_values_without_nulls; +---- +[1] +[2] +[3] +[4] + # array_positions with columns #2 (element is list) query ? select array_positions(column1, column2) from nested_arrays; @@ -2551,6 +2782,12 @@ select array_positions(arrow_cast(column1, 'LargeList(List(Int64))'), column2) f [3] [2, 5] +query ? 
+select array_positions(column1, column2) from fixed_size_nested_arrays; +---- +[3] +[2, 5] + # array_positions with columns and scalars #1 query ?? select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 45], column2) from arrays_values_without_nulls; @@ -2568,6 +2805,14 @@ select array_positions(arrow_cast(column1, 'LargeList(Int64)'), 4), array_positi [] [3] [] [] +query ?? +select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 45], column2) from fixed_size_arrays_values_without_nulls; +---- +[4] [1] +[] [] +[] [3] +[] [] + # array_positions with columns and scalars #2 (element is list) query ?? select array_positions(column1, make_array(4, 5, 6)), array_positions(make_array([1, 2, 3], [11, 12, 13], [4, 5, 6]), column2) from nested_arrays; @@ -2581,6 +2826,12 @@ select array_positions(arrow_cast(column1, 'LargeList(List(Int64))'), make_array [6] [] [1] [] +query ?? +select array_positions(column1, make_array(4, 5, 6)), array_positions(make_array([1, 2, 3], [11, 12, 13], [4, 5, 6]), column2) from fixed_size_nested_arrays; +---- +[6] [] +[1] [] + ## array_replace (aliases: `list_replace`) # array_replace scalar function #1 @@ -3515,6 +3766,13 @@ select array_remove(make_array(1, 2, 2, 1, 1), 2), array_remove(make_array(1.0, ---- [1, 2, 1, 1] [2.0, 2.0, 1.0, 1.0] [h, e, l, o] +query ??? +select array_remove(arrow_cast(make_array(1, 2, 2, 1, 1), 'FixedSizeList(5, Int64)'), 2), + array_remove(arrow_cast(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 'FixedSizeList(5, Float64)'), 1.0), + array_remove(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'FixedSizeList(5, Utf8)'), 'l'); +---- +[1, 2, 1, 1] [2.0, 2.0, 1.0, 1.0] [h, e, l, o] + query ??? 
select array_remove(make_array(1, null, 2, 3), 2), @@ -3523,11 +3781,20 @@ select ---- [1, , 3] [, 2.2, 3.3] [, bc] -# TODO: https://github.com/apache/arrow-datafusion/issues/7142 -# query -# select -# array_remove(make_array(1, null, 2), null), -# array_remove(make_array(1, null, 2, null), null); +query ??? +select + array_remove(arrow_cast(make_array(1, null, 2, 3), 'FixedSizeList(4, Int64)'), 2), + array_remove(arrow_cast(make_array(1.1, null, 2.2, 3.3), 'FixedSizeList(4, Float64)'), 1.1), + array_remove(arrow_cast(make_array('a', null, 'bc'), 'FixedSizeList(3, Utf8)'), 'a'); +---- +[1, , 3] [, 2.2, 3.3] [, bc] + +query ?? +select + array_remove(make_array(1, null, 2), null), + array_remove(make_array(1, null, 2, null), null); +---- +[1, 2] [1, 2, ] # array_remove scalar function #2 (element is list) query ?? @@ -3535,12 +3802,24 @@ select array_remove(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8 ---- [[1, 2, 3], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [2, 3, 4], [5, 3, 1], [1, 3, 2]] +query ?? +select array_remove(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'FixedSizeList(5, List(Int64))'), [4, 5, 6]), + array_remove(arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'FixedSizeList(5, List(Int64))'), [2, 3, 4]); +---- +[[1, 2, 3], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [2, 3, 4], [5, 3, 1], [1, 3, 2]] + # list_remove scalar function #3 (function alias `array_remove`) query ??? select list_remove(make_array(1, 2, 2, 1, 1), 2), list_remove(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 1.0), list_remove(make_array('h', 'e', 'l', 'l', 'o'), 'l'); ---- [1, 2, 1, 1] [2.0, 2.0, 1.0, 1.0] [h, e, l, o] +query ?? 
+select list_remove(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'FixedSizeList(5, List(Int64))'), [4, 5, 6]), + list_remove(arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'FixedSizeList(5, List(Int64))'), [2, 3, 4]); +---- +[[1, 2, 3], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [2, 3, 4], [5, 3, 1], [1, 3, 2]] + # array_remove scalar function with columns #1 query ? select array_remove(column1, column2) from arrays_with_repeating_elements; @@ -3550,6 +3829,14 @@ select array_remove(column1, column2) from arrays_with_repeating_elements; [7, 7, 8, 7, 9, 7, 8, 7, 7] [11, 12, 10, 11, 12, 10, 11, 12, 10] +query ? +select array_remove(column1, column2) from fixed_arrays_with_repeating_elements; +---- +[1, 1, 3, 2, 2, 1, 3, 2, 3] +[4, 5, 5, 6, 5, 5, 5, 4, 4] +[7, 7, 8, 7, 9, 7, 8, 7, 7] +[11, 12, 10, 11, 12, 10, 11, 12, 10] + # array_remove scalar function with columns #2 (element is list) query ? select array_remove(column1, column2) from nested_arrays_with_repeating_elements; @@ -3559,6 +3846,14 @@ select array_remove(column1, column2) from nested_arrays_with_repeating_elements [[19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] +query ? 
+select array_remove(column1, column2) from fixed_size_nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + # array_remove scalar function with columns and scalars #1 query ?? select array_remove(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2), array_remove(column1, 1) from arrays_with_repeating_elements; @@ -3568,9 +3863,27 @@ select array_remove(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2), a [1, 2, 2, 4, 5, 4, 4, 7, 10, 7, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [1, 2, 2, 4, 5, 4, 4, 7, 7, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] +query ?? +select array_remove(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2), array_remove(column1, 1) from fixed_arrays_with_repeating_elements; +---- +[1, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8] [2, 1, 3, 2, 2, 1, 3, 2, 3] +[1, 2, 2, 5, 4, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 7, 10, 7, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] + # array_remove scalar function with columns and scalars #2 (element is list) query ?? 
-select array_remove(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), column2), array_remove(column1, make_array(1, 2, 3)) from nested_arrays_with_repeating_elements; +select array_remove(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), column2), + array_remove(column1, make_array(1, 2, 3)) from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + +query ?? 
+select array_remove(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), column2), + array_remove(column1, make_array(1, 2, 3)) from fixed_size_nested_arrays_with_repeating_elements; ---- [[1, 2, 3], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] @@ -3635,24 +3948,47 @@ select array_remove_n(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], ## array_remove_all (aliases: `list_removes`) +# array_remove_all with NULL elements +query ? +select array_remove_all(make_array(1, 2, 2, 1, 1), NULL); +---- +[1, 2, 2, 1, 1] + # array_remove_all scalar function #1 query ??? select array_remove_all(make_array(1, 2, 2, 1, 1), 2), array_remove_all(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 1.0), array_remove_all(make_array('h', 'e', 'l', 'l', 'o'), 'l'); ---- [1, 1, 1] [2.0, 2.0] [h, e, o] +query ??? +select array_remove_all(arrow_cast(make_array(1, 2, 2, 1, 1), 'FixedSizeList(5, Int64)'), 2), array_remove_all(arrow_cast(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 'FixedSizeList(5, Float64)'), 1.0), array_remove_all(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'FixedSizeList(5, Utf8)'), 'l'); +---- +[1, 1, 1] [2.0, 2.0] [h, e, o] + # array_remove_all scalar function #2 (element is list) query ?? 
select array_remove_all(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_remove_all(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]); ---- [[1, 2, 3], [5, 5, 5], [7, 8, 9]] [[1, 3, 2], [5, 3, 1], [1, 3, 2]] +query ?? +select array_remove_all(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'FixedSizeList(5, List(Int64))'), [4, 5, 6]), + array_remove_all(arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'FixedSizeList(5, List(Int64))'), [2, 3, 4]); +---- +[[1, 2, 3], [5, 5, 5], [7, 8, 9]] [[1, 3, 2], [5, 3, 1], [1, 3, 2]] + # list_remove_all scalar function #3 (function alias `array_remove_all`) query ??? select list_remove_all(make_array(1, 2, 2, 1, 1), 2), list_remove_all(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 1.0), list_remove_all(make_array('h', 'e', 'l', 'l', 'o'), 'l'); ---- [1, 1, 1] [2.0, 2.0] [h, e, o] +query ?? +select list_remove_all(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'FixedSizeList(5, List(Int64))'), [4, 5, 6]), + list_remove_all(arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'FixedSizeList(5, List(Int64))'), [2, 3, 4]); +---- +[[1, 2, 3], [5, 5, 5], [7, 8, 9]] [[1, 3, 2], [5, 3, 1], [1, 3, 2]] + # array_remove_all scalar function with columns #1 query ? select array_remove_all(column1, column2) from arrays_with_repeating_elements; @@ -3662,6 +3998,14 @@ select array_remove_all(column1, column2) from arrays_with_repeating_elements; [8, 9, 8] [11, 12, 11, 12, 11, 12] +query ? +select array_remove_all(column1, column2) from fixed_arrays_with_repeating_elements; +---- +[1, 1, 3, 1, 3, 3] +[5, 5, 6, 5, 5, 5] +[8, 9, 8] +[11, 12, 11, 12, 11, 12] + # array_remove_all scalar function with columns #2 (element is list) query ? 
select array_remove_all(column1, column2) from nested_arrays_with_repeating_elements; @@ -3671,6 +4015,14 @@ select array_remove_all(column1, column2) from nested_arrays_with_repeating_elem [[22, 23, 24], [25, 26, 27], [22, 23, 24]] [[31, 32, 33], [34, 35, 36], [31, 32, 33], [34, 35, 36], [31, 32, 33], [34, 35, 36]] +query ? +select array_remove_all(column1, column2) from fixed_size_nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [1, 2, 3], [7, 8, 9], [1, 2, 3], [7, 8, 9], [7, 8, 9]] +[[13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15]] +[[22, 23, 24], [25, 26, 27], [22, 23, 24]] +[[31, 32, 33], [34, 35, 36], [31, 32, 33], [34, 35, 36], [31, 32, 33], [34, 35, 36]] + # array_remove_all scalar function with columns and scalars #1 query ?? select array_remove_all(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2), array_remove_all(column1, 1) from arrays_with_repeating_elements; @@ -3680,6 +4032,14 @@ select array_remove_all(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2 [1, 2, 2, 4, 5, 4, 4, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [1, 2, 2, 4, 5, 4, 4, 7, 7, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] +query ?? +select array_remove_all(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2), array_remove_all(column1, 1) from fixed_arrays_with_repeating_elements; +---- +[1, 4, 5, 4, 4, 7, 7, 10, 7, 8] [2, 3, 2, 2, 3, 2, 3] +[1, 2, 2, 5, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] + # array_remove_all scalar function with columns and scalars #2 (element is list) query ?? 
select array_remove_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), column2), array_remove_all(column1, make_array(1, 2, 3)) from nested_arrays_with_repeating_elements; @@ -3689,6 +4049,15 @@ select array_remove_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] +query ?? +select array_remove_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), column2), + array_remove_all(column1, make_array(1, 2, 3)) from fixed_size_nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[4, 5, 6], [7, 8, 9], [4, 5, 6], [4, 5, 6], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [13, 14, 15], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 
20, 21]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + ## trim_array (deprecated) ## array_length (aliases: `list_length`) @@ -4009,6 +4378,21 @@ NULL 1 1 ## array_has/array_has_all/array_has_any +query BB +select array_has([], null), + array_has([1, 2, 3], null); +---- +false false + +#TODO: array_has_all and array_has_any cannot handle NULL +#query BBBB +#select array_has_any([], null), +# array_has_any([1, 2, 3], null), +# array_has_all([], null), +# array_has_all([1, 2, 3], null); +#---- +#false false false false + query BBBBBBBBBBBB select array_has(make_array(1,2), 1), array_has(make_array(1,2,NULL), 1), @@ -4043,6 +4427,23 @@ select array_has(arrow_cast(make_array(1,2), 'LargeList(Int64)'), 1), ---- true true true true true false true false true false true false +query BBBBBBBBBBBB +select array_has(arrow_cast(make_array(1,2), 'FixedSizeList(2, Int64)'), 1), + array_has(arrow_cast(make_array(1,2,NULL), 'FixedSizeList(3, Int64)'), 1), + array_has(arrow_cast(make_array([2,3], [3,4]), 'FixedSizeList(2, List(Int64))'), make_array(2,3)), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'FixedSizeList(2, List(List(Int64)))'), make_array([1], [2,3])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'FixedSizeList(2, List(List(Int64)))'), make_array([4,5], [6])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'FixedSizeList(2, List(List(Int64)))'), make_array([1])), + array_has(arrow_cast(make_array([[[1]]]), 'FixedSizeList(1, List(List(List(Int64))))'), make_array([[1]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'FixedSizeList(2, List(List(List(Int64))))'), make_array([[2]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'FixedSizeList(2, 
List(List(List(Int64))))'), make_array([[1], [2]])), + list_has(arrow_cast(make_array(1,2,3), 'FixedSizeList(3, Int64)'), 4), + array_contains(arrow_cast(make_array(1,2,3), 'FixedSizeList(3, Int64)'), 3), + list_contains(arrow_cast(make_array(1,2,3), 'FixedSizeList(3, Int64)'), 0) +; +---- +true true true true true false true false true false true false + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -4061,6 +4462,22 @@ from array_has_table_1D; true true true false false false +query B +select array_has(column1, column2) +from fixed_size_array_has_table_1D; +---- +true +false + +#TODO: array_has_all and array_has_any cannot handle FixedSizeList +#query BB +#select array_has_all(column3, column4), +# array_has_any(column5, column6) +#from fixed_size_array_has_table_1D; +#---- +#true true +#false false + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -4079,6 +4496,22 @@ from array_has_table_1D_Float; true true false false false true +query B +select array_has(column1, column2) +from fixed_size_array_has_table_1D_Float; +---- +true +false + +#TODO: array_has_all and array_has_any cannot handle FixedSizeList +#query BB +#select array_has_all(column3, column4), +# array_has_any(column5, column6) +#from fixed_size_array_has_table_1D_Float; +#---- +#true true +#false true + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -4097,6 +4530,22 @@ from array_has_table_1D_Boolean; false true true true true true +query B +select array_has(column1, column2) +from fixed_size_array_has_table_1D_Boolean; +---- +false +true + +#TODO: array_has_all and array_has_any cannot handle FixedSizeList +#query BB +#select array_has_all(column3, column4), +# array_has_any(column5, column6) +#from fixed_size_array_has_table_1D_Boolean; +#---- +#true true +#true true + query BBB select array_has(column1, column2), array_has_all(column3, column4), @@ -4115,6 +4564,13 @@ from 
array_has_table_1D_UTF8; true true false false false true +query B +select array_has(column1, column2) +from fixed_size_array_has_table_1D_UTF8; +---- +true +false + query BB select array_has(column1, column2), array_has_all(column3, column4) @@ -4131,6 +4587,21 @@ from array_has_table_2D; false true true false +query B +select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), column2) +from fixed_size_array_has_table_2D; +---- +false +false + +#TODO: array_has_all and array_has_any cannot handle FixedSizeList +#query B +#select array_has_all(arrow_cast(column3, 'LargeList(List(Int64))'), arrow_cast(column4, 'LargeList(List(Int64))')) +#from fixed_size_array_has_table_2D; +#---- +#true +#false + query B select array_has_all(column1, column2) from array_has_table_2D_float; @@ -4145,6 +4616,14 @@ from array_has_table_2D_float; true false +#TODO: array_has_all and array_has_any cannot handle FixedSizeList +#query B +#select array_has_all(column1, column2) +#from fixed_size_array_has_table_2D_float; +#---- +#false +#false + query B select array_has(column1, column2) from array_has_table_3D; ---- @@ -4167,6 +4646,17 @@ true false true +query B +select array_has(column1, column2) from fixed_size_array_has_table_3D; +---- +false +false +false +false +true +true +true + query BBBB select array_has(column1, make_array(5, 6)), array_has(column1, make_array(7, NULL)), @@ -4195,6 +4685,21 @@ false true false false false false false false false false false false +query BBBB +select array_has(column1, make_array(5, 6)), + array_has(column1, make_array(7, NULL)), + array_has(column2, 5.5), + array_has(column3, 'o') +from fixed_size_arrays; +---- +false false false true +true false true false +true false false true +false true false false +false true false false +false false false false +false false false false + query BBBBBBBBBBBBB select array_has_all(make_array(1,2,3), make_array(1,3)), array_has_all(make_array(1,2,3), make_array(1,4)), @@ -4231,23 +4736,24 @@ select 
array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_ca ---- true false true false false false true true false false true false true -query BBBBBBBBBBBBB -select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')), - array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')), - array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')), - array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')), - array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')), - array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')), - array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')), - array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')), - array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')), - array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')), - array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')), - array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')), - array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))')) -; ----- 
-true false true false false false true true false false true false true +#TODO: array_has_all and array_has_any cannot handle FixedSizeList +#query BBBBBBBBBBBBB +#select array_has_all(arrow_cast(make_array(1,2,3), 'FixedSizeList(3, Int64)'), arrow_cast(make_array(1, 3), 'FixedSizeList(2, Int64)')), +# array_has_all(arrow_cast(make_array(1,2,3),'FixedSizeList(3, Int64)'), arrow_cast(make_array(1, 4), 'FixedSizeList(2, Int64)')), +# array_has_all(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,2]), 'FixedSizeList(1, List(Int64))')), +# array_has_all(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,3]), 'FixedSizeList(1, List(Int64))')), +# array_has_all(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'FixedSizeList(3, List(Int64))')), +# array_has_all(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1]]), 'FixedSizeList(1, List(List(Int64)))')), +# array_has_all(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))')), +# array_has_any(arrow_cast(make_array(1,2,3),'FixedSizeList(3, Int64)'), arrow_cast(make_array(1,10,100), 'FixedSizeList(3, Int64)')), +# array_has_any(arrow_cast(make_array(1,2,3),'FixedSizeList(3, Int64)'), arrow_cast(make_array(10, 100),'FixedSizeList(2, Int64)')), +# array_has_any(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'FixedSizeList(2, List(Int64))')), +# array_has_any(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'FixedSizeList(2, List(Int64))')), +# array_has_any(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'FixedSizeList(1, 
List(List(Int64)))')), +# array_has_any(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'FixedSizeList(2, List(List(Int64)))')) +#; +#---- +#true false true false false false true true false false true false true ## array_distinct @@ -5102,15 +5608,24 @@ drop table nested_arrays; statement ok drop table large_nested_arrays; +statement ok +drop table fixed_size_nested_arrays; + statement ok drop table arrays; statement ok drop table large_arrays; +statement ok +drop table fixed_size_arrays; + statement ok drop table slices; +statement ok +drop table fixed_slices; + statement ok drop table arrayspop; @@ -5187,7 +5702,25 @@ statement ok drop table large_array_intersect_table_3D; statement ok -drop table arrays_values_without_nulls; +drop table fixed_size_array_has_table_1D; + +statement ok +drop table fixed_size_array_has_table_1D_Float; + +statement ok +drop table fixed_size_array_has_table_1D_Boolean; + +statement ok +drop table fixed_size_array_has_table_1D_UTF8; + +statement ok +drop table fixed_size_array_has_table_2D; + +statement ok +drop table fixed_size_array_has_table_2D_float; + +statement ok +drop table fixed_size_array_has_table_3D; statement ok drop table arrays_range; @@ -5198,11 +5731,26 @@ drop table arrays_with_repeating_elements; statement ok drop table large_arrays_with_repeating_elements; +statement ok +drop table fixed_arrays_with_repeating_elements; + statement ok drop table nested_arrays_with_repeating_elements; statement ok drop table large_nested_arrays_with_repeating_elements; +statement ok +drop table fixed_size_nested_arrays_with_repeating_elements; + statement ok drop table flatten_table; + +statement ok +drop table arrays_values_without_nulls; + +statement ok +drop table large_arrays_values_without_nulls; + +statement ok +drop table fixed_size_arrays_values_without_nulls; diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt 
b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 8b3bd7eac95d..8e2a091423da 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -421,4 +421,4 @@ FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0 query ? select arrow_cast([1, 2, 3], 'FixedSizeList(3, Int64)'); ---- -[1, 2, 3] +[1, 2, 3] \ No newline at end of file From e566329707e910ce6f8a32822dbafef993a6faaa Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 1 Feb 2024 12:05:36 -0500 Subject: [PATCH 5/7] Improve Canonicalize API (#8983) * Improve Canonicalize API * Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs Co-authored-by: Junhao Liu * Simplify the API * fix comment * add more flavor in comments --------- Co-authored-by: Junhao Liu --- .../simplify_expressions/expr_simplifier.rs | 73 +++++++++++++++++-- .../simplify_expressions/simplify_exprs.rs | 55 +++++++------- 2 files changed, 91 insertions(+), 37 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 1c1228949171..30140101df7b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -56,6 +56,9 @@ pub struct ExprSimplifier { /// Guarantees about the values of columns. This is provided by the user /// in [ExprSimplifier::with_guarantees()]. guarantees: Vec<(Expr, NullableInterval)>, + /// Should expressions be canonicalized before simplification? 
Defaults to + /// true + canonicalize: bool, } pub const THRESHOLD_INLINE_INLIST: usize = 3; @@ -70,6 +73,7 @@ impl ExprSimplifier { Self { info, guarantees: vec![], + canonicalize: true, } } @@ -137,6 +141,12 @@ impl ExprSimplifier { let mut inlist_simplifier = InListSimplifier::new(); let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); + let expr = if self.canonicalize { + expr.rewrite(&mut Canonicalizer::new())? + } else { + expr + }; + // TODO iterate until no changes are made during rewrite // (evaluating constants can enable new simplifications and // simplifications can enable new constant evaluation) @@ -151,10 +161,6 @@ impl ExprSimplifier { .rewrite(&mut simplifier) } - pub fn canonicalize(&self, expr: Expr) -> Result { - let mut canonicalizer = Canonicalizer::new(); - expr.rewrite(&mut canonicalizer) - } /// Apply type coercion to an [`Expr`] so that it can be /// evaluated as a [`PhysicalExpr`](datafusion_physical_expr::PhysicalExpr). /// @@ -229,6 +235,60 @@ impl ExprSimplifier { self.guarantees = guarantees; self } + + /// Should [`Canonicalizer`] be applied before simplification? + /// + /// If true (the default), the expression will be rewritten to canonical + /// form before simplification. This is useful to ensure that the simplifier + /// can apply all possible simplifications. + /// + /// Some expressions, such as those in some Joins, can not be canonicalized + /// without changing their meaning. In these cases, canonicalization should + /// be disabled. 
+ /// + /// ```rust + /// use arrow::datatypes::{DataType, Field, Schema}; + /// use datafusion_expr::{col, lit, Expr}; + /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; + /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; + /// use datafusion_physical_expr::execution_props::ExecutionProps; + /// use datafusion_optimizer::simplify_expressions::{ + /// ExprSimplifier, SimplifyContext}; + /// + /// let schema = Schema::new(vec![ + /// Field::new("a", DataType::Int64, false), + /// Field::new("b", DataType::Int64, false), + /// Field::new("c", DataType::Int64, false), + /// ]) + /// .to_dfschema_ref().unwrap(); + /// + /// // Create the simplifier + /// let props = ExecutionProps::new(); + /// let context = SimplifyContext::new(&props) + /// .with_schema(schema); + /// let simplifier = ExprSimplifier::new(context); + /// + /// // Expression: a = c AND 1 = b + /// let expr = col("a").eq(col("c")).and(lit(1).eq(col("b"))); + /// + /// // With canonicalization, the expression is rewritten to canonical form + /// // (though it is no simpler in this case): + /// let canonical = simplifier.simplify(expr.clone()).unwrap(); + /// // Expression has been rewritten to: (c = a AND b = 1) + /// assert_eq!(canonical, col("c").eq(col("a")).and(col("b").eq(lit(1)))); + /// + /// // If canonicalization is disabled, the expression is not changed + /// let non_canonicalized = simplifier + /// .with_canonicalize(false) + /// .simplify(expr.clone()) + /// .unwrap(); + /// + /// assert_eq!(non_canonicalized, expr); + /// ``` + pub fn with_canonicalize(mut self, canonicalize: bool) -> Self { + self.canonicalize = canonicalize; + self + } } /// Canonicalize any BinaryExprs that are not in canonical form @@ -236,7 +296,7 @@ impl ExprSimplifier { /// ` ` is rewritten to ` ` /// /// ` ` is rewritten so that the name of `col1` sorts higher -/// than `col2` (`b > a` would be canonicalized to `a < b`) +/// than `col2` (`a > b` would be canonicalized to `b < 
a`) struct Canonicalizer {} impl Canonicalizer { @@ -2889,8 +2949,7 @@ mod tests { let simplifier = ExprSimplifier::new( SimplifyContext::new(&execution_props).with_schema(schema), ); - let cano = simplifier.canonicalize(expr)?; - simplifier.simplify(cano) + simplifier.simplify(expr) } fn simplify(expr: Expr) -> Expr { diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index d68474dcde0b..f36cd8f838fb 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -85,44 +85,39 @@ impl SimplifyExpressions { }; let info = SimplifyContext::new(execution_props).with_schema(schema); - let simplifier = ExprSimplifier::new(info); - let new_inputs = plan .inputs() .iter() .map(|input| Self::optimize_internal(input, execution_props)) .collect::>>()?; - let expr = match plan { - // Canonicalize step won't reorder expressions in a Join on clause. - // The left and right expressions in a Join on clause are not commutative, - // since the order of the columns must match the order of the children. - LogicalPlan::Join(_) => { - plan.expressions() - .into_iter() - .map(|e| { - // TODO: unify with `rewrite_preserving_name` - let original_name = e.name_for_alias()?; - let new_e = simplifier.simplify(e)?; - new_e.alias_if_changed(original_name) - }) - .collect::>>()? - } - _ => { - plan.expressions() - .into_iter() - .map(|e| { - // TODO: unify with `rewrite_preserving_name` - let original_name = e.name_for_alias()?; - let cano_e = simplifier.canonicalize(e)?; - let new_e = simplifier.simplify(cano_e)?; - new_e.alias_if_changed(original_name) - }) - .collect::>>()? - } + let simplifier = ExprSimplifier::new(info); + + // The left and right expressions in a Join on clause are not + // commutative, for reasons that are not entirely clear. Thus, do not + // reorder expressions in Join while simplifying. 
+ // + // This is likely related to the fact that order of the columns must + // match the order of the children. see + // https://github.com/apache/arrow-datafusion/pull/8780 for more details + let simplifier = if let LogicalPlan::Join(_) = plan { + simplifier.with_canonicalize(false) + } else { + simplifier }; - plan.with_new_exprs(expr, new_inputs) + let exprs = plan + .expressions() + .into_iter() + .map(|e| { + // TODO: unify with `rewrite_preserving_name` + let original_name = e.name_for_alias()?; + let new_e = simplifier.simplify(e)?; + new_e.alias_if_changed(original_name) + }) + .collect::>>()?; + + plan.with_new_exprs(exprs, new_inputs) } } From 968c05f1c0fe56b528a80d5c22cbdacc862f59e1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 1 Feb 2024 16:04:18 -0500 Subject: [PATCH 6/7] Update env_logger requirement from 0.10 to 0.11 (#8944) * Update env_logger requirement from 0.10 to 0.11 Updates the requirements on [env_logger](https://github.com/rust-cli/env_logger) to permit the latest version. - [Release notes](https://github.com/rust-cli/env_logger/releases) - [Changelog](https://github.com/rust-cli/env_logger/blob/main/CHANGELOG.md) - [Commits](https://github.com/rust-cli/env_logger/compare/v0.10.0...v0.11.0) --- updated-dependencies: - dependency-name: env_logger dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] * Update cargo.lock --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Andrew Lamb --- Cargo.toml | 2 +- datafusion-cli/Cargo.lock | 55 +++++++++++++++++++++------------ datafusion/optimizer/Cargo.toml | 2 +- test-utils/Cargo.toml | 2 +- 4 files changed, 38 insertions(+), 23 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cccca0174113..4c0e7bde26b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,7 +58,7 @@ datafusion-sql = { path = "datafusion/sql", version = "35.0.0" } datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "35.0.0" } datafusion-substrait = { path = "datafusion/substrait", version = "35.0.0" } doc-comment = "0.3" -env_logger = "0.10" +env_logger = "0.11" futures = "0.3" half = "2.2.1" indexmap = "2.0.0" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index e89a8f172f74..4d5b3b711d33 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -270,7 +270,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.2.1", + "indexmap 2.2.2", "lexical-core", "num", "serde", @@ -1125,7 +1125,7 @@ dependencies = [ "glob", "half", "hashbrown 0.14.3", - "indexmap 2.2.1", + "indexmap 2.2.2", "itertools", "log", "num-traits", @@ -1274,7 +1274,7 @@ dependencies = [ "half", "hashbrown 0.14.3", "hex", - "indexmap 2.2.1", + "indexmap 2.2.2", "itertools", "log", "md-5", @@ -1305,7 +1305,7 @@ dependencies = [ "futures", "half", "hashbrown 0.14.3", - "indexmap 2.2.1", + "indexmap 2.2.2", "itertools", "log", "once_cell", @@ -1677,7 +1677,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap 2.2.1", + "indexmap 2.2.2", "slab", "tokio", "tokio-util", @@ -1894,9 +1894,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.1" +version = "2.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"433de089bd45971eecf4668ee0ee8f4cec17db4f8bd8f7bc3197a6ce37aa7d9b" +checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" dependencies = [ "equivalent", "hashbrown 0.14.3", @@ -2028,9 +2028,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.152" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libflate" @@ -2236,6 +2236,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.45" @@ -2447,7 +2453,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" dependencies = [ "fixedbitset", - "indexmap 2.2.1", + "indexmap 2.2.2", ] [[package]] @@ -2723,9 +2729,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "reqwest" -version = "0.11.23" +version = "0.11.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37b1ae8d9ac08420c66222fb9096fc5de435c3c48542bc5336c51892cffafb41" +checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251" dependencies = [ "base64", "bytes", @@ -2750,6 +2756,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", + "sync_wrapper", "system-configuration", "tokio", "tokio-rustls 0.24.1", @@ -2841,9 +2848,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.30" +version = "0.38.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "322394588aaf33c24007e8bb3238ee3e4c5c09c084ab32bc73890b99ff326bca" +checksum = 
"6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" dependencies = [ "bitflags 2.4.2", "errno", @@ -3269,6 +3276,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + [[package]] name = "system-configuration" version = "0.5.1" @@ -3357,11 +3370,12 @@ dependencies = [ [[package]] name = "time" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f657ba42c3f86e7680e53c8cd3af8abbe56b5491790b46e22e19c0d57463583e" +checksum = "fe80ced77cbfb4cb91a94bf72b378b4b6791a0d9b7f09d0be747d1bdff4e68bd" dependencies = [ "deranged", + "num-conv", "powerfmt", "serde", "time-core", @@ -3376,10 +3390,11 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26197e33420244aeb70c3e8c78376ca46571bc4e701e4791c2cd9f57dcb3a43f" +checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774" dependencies = [ + "num-conv", "time-core", ] @@ -3777,9 +3792,9 @@ checksum = "4d91413b1c31d7539ba5ef2451af3f0b833a005eb27a631cec32bc0635a8602b" [[package]] name = "wasm-streams" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4609d447824375f43e1ffbc051b50ad8f4b3ae8219680c94452ea05eb240ac7" +checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" dependencies = [ "futures-util", "js-sys", diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 6aec52ad70d1..e4e9660f93b4 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -53,4 +53,4 @@ regex-syntax = "0.8.0" [dev-dependencies] ctor = { workspace = 
true } datafusion-sql = { path = "../sql", version = "35.0.0" } -env_logger = "0.10.0" +env_logger = "0.11.0" diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index b9c4db17c098..de6310312b5a 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -25,5 +25,5 @@ edition = { workspace = true } [dependencies] arrow = { workspace = true } datafusion-common = { path = "../datafusion/common" } -env_logger = "0.10.0" +env_logger = "0.11.0" rand = { workspace = true } From 8b50774d6afa04f511b56019febc62fd60d26234 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 1 Feb 2024 16:04:35 -0500 Subject: [PATCH 7/7] Split count_distinct.rs into separate modules (#9087) * Split count_distinct.rs into separate modules * Remove unnecessary typedef * Rename * improve module comments --- .../src/aggregate/count_distinct/mod.rs | 257 +++--------------- .../src/aggregate/count_distinct/native.rs | 215 +++++++++++++++ .../src/aggregate/count_distinct/strings.rs | 6 +- 3 files changed, 261 insertions(+), 217 deletions(-) create mode 100644 datafusion/physical-expr/src/aggregate/count_distinct/native.rs diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs index 41f7c8729ee3..8baea511c776 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs @@ -15,39 +15,36 @@ // specific language governing permissions and limitations // under the License.
+mod native; mod strings; use std::any::Any; -use std::cmp::Eq; use std::collections::HashSet; use std::fmt::Debug; -use std::hash::Hash; use std::sync::Arc; use ahash::RandomState; use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field, TimeUnit}; use arrow_array::types::{ - ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, - Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, - Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float16Type, Float32Type, + Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, + Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; -use arrow_array::PrimitiveArray; -use datafusion_common::cast::{as_list_array, as_primitive_array}; -use datafusion_common::utils::array_into_list_array; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; +use crate::aggregate::count_distinct::native::{ + FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator, +}; use crate::aggregate::count_distinct::strings::StringDistinctCountAccumulator; -use crate::aggregate::utils::{down_cast_any_ref, Hashable}; +use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -type DistinctScalarValues = ScalarValue; - /// Expression for a COUNT(DISTINCT) aggregation. 
#[derive(Debug)] pub struct DistinctCount { @@ -101,46 +98,46 @@ impl AggregateExpr for DistinctCount { use TimeUnit::*; Ok(match &self.state_data_type { - Int8 => Box::new(NativeDistinctCountAccumulator::::new()), - Int16 => Box::new(NativeDistinctCountAccumulator::::new()), - Int32 => Box::new(NativeDistinctCountAccumulator::::new()), - Int64 => Box::new(NativeDistinctCountAccumulator::::new()), - UInt8 => Box::new(NativeDistinctCountAccumulator::::new()), - UInt16 => Box::new(NativeDistinctCountAccumulator::::new()), - UInt32 => Box::new(NativeDistinctCountAccumulator::::new()), - UInt64 => Box::new(NativeDistinctCountAccumulator::::new()), + Int8 => Box::new(PrimitiveDistinctCountAccumulator::::new()), + Int16 => Box::new(PrimitiveDistinctCountAccumulator::::new()), + Int32 => Box::new(PrimitiveDistinctCountAccumulator::::new()), + Int64 => Box::new(PrimitiveDistinctCountAccumulator::::new()), + UInt8 => Box::new(PrimitiveDistinctCountAccumulator::::new()), + UInt16 => Box::new(PrimitiveDistinctCountAccumulator::::new()), + UInt32 => Box::new(PrimitiveDistinctCountAccumulator::::new()), + UInt64 => Box::new(PrimitiveDistinctCountAccumulator::::new()), Decimal128(_, _) => { - Box::new(NativeDistinctCountAccumulator::::new()) + Box::new(PrimitiveDistinctCountAccumulator::::new()) } Decimal256(_, _) => { - Box::new(NativeDistinctCountAccumulator::::new()) + Box::new(PrimitiveDistinctCountAccumulator::::new()) } - Date32 => Box::new(NativeDistinctCountAccumulator::::new()), - Date64 => Box::new(NativeDistinctCountAccumulator::::new()), - Time32(Millisecond) => { - Box::new(NativeDistinctCountAccumulator::::new()) - } + Date32 => Box::new(PrimitiveDistinctCountAccumulator::::new()), + Date64 => Box::new(PrimitiveDistinctCountAccumulator::::new()), + Time32(Millisecond) => Box::new(PrimitiveDistinctCountAccumulator::< + Time32MillisecondType, + >::new()), Time32(Second) => { - Box::new(NativeDistinctCountAccumulator::::new()) - } - Time64(Microsecond) => { - 
Box::new(NativeDistinctCountAccumulator::::new()) + Box::new(PrimitiveDistinctCountAccumulator::::new()) } + Time64(Microsecond) => Box::new(PrimitiveDistinctCountAccumulator::< + Time64MicrosecondType, + >::new()), Time64(Nanosecond) => { - Box::new(NativeDistinctCountAccumulator::::new()) + Box::new(PrimitiveDistinctCountAccumulator::::new()) } - Timestamp(Microsecond, _) => Box::new(NativeDistinctCountAccumulator::< + Timestamp(Microsecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< TimestampMicrosecondType, >::new()), - Timestamp(Millisecond, _) => Box::new(NativeDistinctCountAccumulator::< + Timestamp(Millisecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< TimestampMillisecondType, >::new()), - Timestamp(Nanosecond, _) => { - Box::new(NativeDistinctCountAccumulator::::new()) - } + Timestamp(Nanosecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< + TimestampNanosecondType, + >::new()), Timestamp(Second, _) => { - Box::new(NativeDistinctCountAccumulator::::new()) + Box::new(PrimitiveDistinctCountAccumulator::::new()) } Float16 => Box::new(FloatDistinctCountAccumulator::::new()), @@ -175,9 +172,13 @@ impl PartialEq for DistinctCount { } } +/// General purpose distinct accumulator that works for any DataType by using +/// [`ScalarValue`]. 
Some types have specialized accumulators that are (much) +/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and +/// [`StringDistinctCountAccumulator`] #[derive(Debug)] struct DistinctCountAccumulator { - values: HashSet, + values: HashSet, state_data_type: DataType, } @@ -186,7 +187,7 @@ impl DistinctCountAccumulator { // This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types fn fixed_size(&self) -> usize { std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) + + (std::mem::size_of::() * self.values.capacity()) + self .values .iter() @@ -199,7 +200,7 @@ impl DistinctCountAccumulator { // calculates the size as accurate as possible, call to this method is expensive fn full_size(&self) -> usize { std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) + + (std::mem::size_of::() * self.values.capacity()) + self .values .iter() @@ -260,182 +261,6 @@ impl Accumulator for DistinctCountAccumulator { } } -#[derive(Debug)] -struct NativeDistinctCountAccumulator -where - T: ArrowPrimitiveType + Send, - T::Native: Eq + Hash, -{ - values: HashSet, -} - -impl NativeDistinctCountAccumulator -where - T: ArrowPrimitiveType + Send, - T::Native: Eq + Hash, -{ - fn new() -> Self { - Self { - values: HashSet::default(), - } - } -} - -impl Accumulator for NativeDistinctCountAccumulator -where - T: ArrowPrimitiveType + Send + Debug, - T::Native: Eq + Hash, -{ - fn state(&mut self) -> Result> { - let arr = Arc::new(PrimitiveArray::::from_iter_values( - self.values.iter().cloned(), - )) as ArrayRef; - let list = Arc::new(array_into_list_array(arr)); - Ok(vec![ScalarValue::List(list)]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let arr = as_primitive_array::(&values[0])?; - arr.iter().for_each(|value| { - if let Some(value) = value { - self.values.insert(value); - } - }); - - 
Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - assert_eq!( - states.len(), - 1, - "count_distinct states must be single array" - ); - - let arr = as_list_array(&states[0])?; - arr.iter().try_for_each(|maybe_list| { - if let Some(list) = maybe_list { - let list = as_primitive_array::(&list)?; - self.values.extend(list.values()) - }; - Ok(()) - }) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Int64(Some(self.values.len() as i64))) - } - - fn size(&self) -> usize { - let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) - / 7) - .next_power_of_two(); - - // Size of accumulator - // + size of entry * number of buckets - // + 1 byte for each bucket - // + fixed size of HashSet - std::mem::size_of_val(self) - + std::mem::size_of::() * estimated_buckets - + estimated_buckets - + std::mem::size_of_val(&self.values) - } -} - -#[derive(Debug)] -struct FloatDistinctCountAccumulator -where - T: ArrowPrimitiveType + Send, -{ - values: HashSet, RandomState>, -} - -impl FloatDistinctCountAccumulator -where - T: ArrowPrimitiveType + Send, -{ - fn new() -> Self { - Self { - values: HashSet::default(), - } - } -} - -impl Accumulator for FloatDistinctCountAccumulator -where - T: ArrowPrimitiveType + Send + Debug, -{ - fn state(&mut self) -> Result> { - let arr = Arc::new(PrimitiveArray::::from_iter_values( - self.values.iter().map(|v| v.0), - )) as ArrayRef; - let list = Arc::new(array_into_list_array(arr)); - Ok(vec![ScalarValue::List(list)]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let arr = as_primitive_array::(&values[0])?; - arr.iter().for_each(|value| { - if let Some(value) = value { - self.values.insert(Hashable(value)); - } - }); - - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - assert_eq!( - 
states.len(), - 1, - "count_distinct states must be single array" - ); - - let arr = as_list_array(&states[0])?; - arr.iter().try_for_each(|maybe_list| { - if let Some(list) = maybe_list { - let list = as_primitive_array::(&list)?; - self.values - .extend(list.values().iter().map(|v| Hashable(*v))); - }; - Ok(()) - }) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Int64(Some(self.values.len() as i64))) - } - - fn size(&self) -> usize { - let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) - / 7) - .next_power_of_two(); - - // Size of accumulator - // + size of entry * number of buckets - // + 1 byte for each bucket - // + fixed size of HashSet - std::mem::size_of_val(self) - + std::mem::size_of::() * estimated_buckets - + estimated_buckets - + std::mem::size_of_val(&self.values) - } -} - #[cfg(test)] mod tests { use arrow::array::{ diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs b/datafusion/physical-expr/src/aggregate/count_distinct/native.rs new file mode 100644 index 000000000000..a44e8b772e5a --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/count_distinct/native.rs @@ -0,0 +1,215 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +//! Specialized implementation of `COUNT DISTINCT` for "Native" arrays such as +//! [`Int64Array`] and [`Float64Array`] +//! +//! [`Int64Array`]: arrow::array::Int64Array +//! [`Float64Array`]: arrow::array::Float64Array +use std::cmp::Eq; +use std::collections::HashSet; +use std::fmt::Debug; +use std::hash::Hash; +use std::sync::Arc; + +use ahash::RandomState; +use arrow::array::ArrayRef; +use arrow_array::types::ArrowPrimitiveType; +use arrow_array::PrimitiveArray; + +use datafusion_common::cast::{as_list_array, as_primitive_array}; +use datafusion_common::utils::array_into_list_array; +use datafusion_common::ScalarValue; +use datafusion_expr::Accumulator; + +use crate::aggregate::utils::Hashable; + +#[derive(Debug)] +pub(super) struct PrimitiveDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, + T::Native: Eq + Hash, +{ + values: HashSet, +} + +impl PrimitiveDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, + T::Native: Eq + Hash, +{ + pub(super) fn new() -> Self { + Self { + values: HashSet::default(), + } + } +} + +impl Accumulator for PrimitiveDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send + Debug, + T::Native: Eq + Hash, +{ + fn state(&mut self) -> datafusion_common::Result> { + let arr = Arc::new(PrimitiveArray::::from_iter_values( + self.values.iter().cloned(), + )) as ArrayRef; + let list = Arc::new(array_into_list_array(arr)); + Ok(vec![ScalarValue::List(list)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { + if values.is_empty() { + return Ok(()); + } + + let arr = as_primitive_array::(&values[0])?; + arr.iter().for_each(|value| { + if let Some(value) = value { + self.values.insert(value); + } + }); + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { + if states.is_empty() { + return Ok(()); + } + assert_eq!( + states.len(), + 1, + "count_distinct states must be single array" + ); + + let arr = 
as_list_array(&states[0])?; + arr.iter().try_for_each(|maybe_list| { + if let Some(list) = maybe_list { + let list = as_primitive_array::(&list)?; + self.values.extend(list.values()) + }; + Ok(()) + }) + } + + fn evaluate(&mut self) -> datafusion_common::Result { + Ok(ScalarValue::Int64(Some(self.values.len() as i64))) + } + + fn size(&self) -> usize { + let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) + / 7) + .next_power_of_two(); + + // Size of accumulator + // + size of entry * number of buckets + // + 1 byte for each bucket + // + fixed size of HashSet + std::mem::size_of_val(self) + + std::mem::size_of::() * estimated_buckets + + estimated_buckets + + std::mem::size_of_val(&self.values) + } +} + +#[derive(Debug)] +pub(super) struct FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, +{ + values: HashSet, RandomState>, +} + +impl FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, +{ + pub(super) fn new() -> Self { + Self { + values: HashSet::default(), + } + } +} + +impl Accumulator for FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send + Debug, +{ + fn state(&mut self) -> datafusion_common::Result> { + let arr = Arc::new(PrimitiveArray::::from_iter_values( + self.values.iter().map(|v| v.0), + )) as ArrayRef; + let list = Arc::new(array_into_list_array(arr)); + Ok(vec![ScalarValue::List(list)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { + if values.is_empty() { + return Ok(()); + } + + let arr = as_primitive_array::(&values[0])?; + arr.iter().for_each(|value| { + if let Some(value) = value { + self.values.insert(Hashable(value)); + } + }); + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { + if states.is_empty() { + return Ok(()); + } + assert_eq!( + states.len(), + 1, + "count_distinct states must be single array" + ); + + let arr = as_list_array(&states[0])?; + 
arr.iter().try_for_each(|maybe_list| { + if let Some(list) = maybe_list { + let list = as_primitive_array::(&list)?; + self.values + .extend(list.values().iter().map(|v| Hashable(*v))); + }; + Ok(()) + }) + } + + fn evaluate(&mut self) -> datafusion_common::Result { + Ok(ScalarValue::Int64(Some(self.values.len() as i64))) + } + + fn size(&self) -> usize { + let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) + / 7) + .next_power_of_two(); + + // Size of accumulator + // + size of entry * number of buckets + // + 1 byte for each bucket + // + fixed size of HashSet + std::mem::size_of_val(self) + + std::mem::size_of::() * estimated_buckets + + estimated_buckets + + std::mem::size_of_val(&self.values) + } +} diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs b/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs index d7a9ea5c373d..02d30c350623 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs @@ -15,7 +15,11 @@ // specific language governing permissions and limitations // under the License. -//! Specialized implementation of `COUNT DISTINCT` for `StringArray` and `LargeStringArray` +//! Specialized implementation of `COUNT DISTINCT` for [`StringArray`] +//! and [`LargeStringArray`] +//! +//! [`StringArray`]: arrow::array::StringArray +//! [`LargeStringArray`]: arrow::array::LargeStringArray use ahash::RandomState; use arrow_array::cast::AsArray;