diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index d00c718..2ffeba4 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -70,15 +70,15 @@ jobs:
       - id: cache-rust
        uses: Swatinem/rust-cache@v2

-      - run: cargo test --all-features
-#      - uses: taiki-e/install-action@cargo-llvm-cov
-#
-#      - run: cargo llvm-cov --all-features --codecov --output-path codecov.json
-#
-#      - uses: codecov/codecov-action@v3
-#        with:
-#          files: codecov.json
-#          env_vars: RUST_VERSION
+      - uses: taiki-e/install-action@cargo-llvm-cov
+
+      - run: cargo llvm-cov --all-features --codecov --output-path codecov.json
+
+      - uses: codecov/codecov-action@v3
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          files: codecov.json
+          env_vars: RUST_VERSION

  # https://github.com/marketplace/actions/alls-green#why used for branch protection checks
  check:
diff --git a/.gitignore b/.gitignore
index 96ef6c0..b471067 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 /target
 Cargo.lock
+.idea
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 08b141b..5a847f9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -20,7 +20,7 @@ repos:
         pass_filenames: false
       - id: clippy
         name: Clippy
-        entry: cargo clippy
+        entry: cargo clippy -- -D warnings
         types: [rust]
         language: system
         pass_filenames: false
diff --git a/Cargo.toml b/Cargo.toml
index a1f5e88..96b811a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "datafusion-functions-json"
-version = "0.41.0"
+version = "0.43.0"
 edition = "2021"
 description = "JSON functions for DataFusion"
 readme = "README.md"
@@ -8,24 +8,19 @@ license = "Apache-2.0"
 keywords = ["datafusion", "JSON", "SQL"]
 categories = ["database-implementations", "parsing"]
 repository = "https://github.com/datafusion-contrib/datafusion-functions-json/"
-rust-version = "1.76.0"
+rust-version = "1.79.0"

 [dependencies]
-arrow = "52.2"
-arrow-schema = "52.2"
-datafusion-common = "41"
-datafusion-expr = "41"
-datafusion-execution = "41"
+datafusion = "43"
 jiter = "0.5"
 paste = "1"
 log = "0.4"

 [dev-dependencies]
-codspeed-criterion-compat = "2.3"
+codspeed-criterion-compat = "2.6"
 criterion = "0.5.1"
-datafusion = "41"
 clap = "4"
-tokio = { version = "1.37", features = ["full"] }
+tokio = { version = "1.38", features = ["full"] }

 [lints.clippy]
 dbg_macro = "deny"
diff --git a/README.md b/README.md
index 569d609..0c2a984 100644
--- a/README.md
+++ b/README.md
@@ -12,9 +12,47 @@ To use these functions, you'll just need to call:
 ```rust
 datafusion_functions_json::register_all(&mut ctx)?;
 ```
-
 To register the below JSON functions in your `SessionContext`.
+
+# Examples
+
+```sql
+-- Create a table with a JSON column stored as a string
+CREATE TABLE test_table (id INT, json_col VARCHAR) AS VALUES
+(1, '{}'),
+(2, '{ "a": 1 }'),
+(3, '{ "a": 2 }'),
+(4, '{ "a": 1, "b": 2 }'),
+(5, '{ "a": 1, "b": 2, "c": 3 }');
+
+-- Check if each document contains the key 'b'
+SELECT id, json_contains(json_col, 'b') as json_contains FROM test_table;
+-- Results in
+-- +----+---------------+
+-- | id | json_contains |
+-- +----+---------------+
+-- | 1  | false         |
+-- | 2  | false         |
+-- | 3  | false         |
+-- | 4  | true          |
+-- | 5  | true          |
+-- +----+---------------+
+
+-- Get the value of the key 'a' from each document
+SELECT id, json_col->'a' as json_col_a FROM test_table;
+
+-- +----+------------+
+-- | id | json_col_a |
+-- +----+------------+
+-- | 1  | {null=}    |
+-- | 2  | {int=1}    |
+-- | 3  | {int=2}    |
+-- | 4  | {int=1}    |
+-- | 5  | {int=1}    |
+-- +----+------------+
+```
+
+
 ## Done
 * [x] `json_contains(json: str, *keys: str | int) -> bool` - true if a JSON string has a specific key (used for the `?` operator)
@@ -27,6 +65,11 @@
 * [x] `json_as_text(json: str, *keys: str | int) -> str` - Get any value from a JSON string by its "path", represented as a string (used for the `->>` operator)
 * [x] `json_length(json: str, *keys: str | int) -> int` - get the length of a JSON string or array

+- [x] `->` operator - alias for `json_get`
+- [x] `->>` operator - alias for `json_as_text`
+- [x] `?` operator - alias for `json_contains`
+
+### Notes
 Cast expressions with `json_get` are rewritten to the appropriate method, e.g.

 ```sql
diff --git a/benches/main.rs b/benches/main.rs
index c0b9ec0..c12c003 100644
--- a/benches/main.rs
+++ b/benches/main.rs
@@ -1,7 +1,7 @@
 use codspeed_criterion_compat::{criterion_group, criterion_main, Bencher, Criterion};
-use datafusion_common::ScalarValue;
-use datafusion_expr::ColumnarValue;
+use datafusion::common::ScalarValue;
+use datafusion::logical_expr::ColumnarValue;
 use datafusion_functions_json::udfs::{json_contains_udf, json_get_str_udf};

 fn bench_json_contains(b: &mut Bencher) {
@@ -14,7 +14,7 @@ fn bench_json_contains(b: &mut Bencher) {
         ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
     ];

-    b.iter(|| json_contains.invoke(args).unwrap());
+    b.iter(|| json_contains.invoke_batch(args, 1).unwrap());
 }

 fn bench_json_get_str(b: &mut Bencher) {
@@ -27,7 +27,7 @@ fn bench_json_get_str(b: &mut Bencher) {
         ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
     ];

-    b.iter(|| json_get_str.invoke(args).unwrap());
+    b.iter(|| json_get_str.invoke_batch(args, 1).unwrap());
 }

 fn criterion_benchmark(c: &mut Criterion) {
diff --git a/src/common.rs b/src/common.rs
index 9cbee92..f25d838 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -1,28 +1,70 @@
 use std::str::Utf8Error;
+use std::sync::Arc;

-use arrow::array::{Array, ArrayRef, Int64Array, LargeStringArray, StringArray, UInt64Array};
-use arrow_schema::DataType;
-use datafusion_common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
-use datafusion_expr::ColumnarValue;
+use datafusion::arrow::array::{
+    Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray,
+    StringArray, StringViewArray, UInt64Array, UnionArray,
+};
+use datafusion::arrow::compute::take;
+use datafusion::arrow::datatypes::{
+    ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType, Int64Type, UInt64Type,
+};
+use datafusion::arrow::downcast_dictionary_array;
+use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
+use datafusion::logical_expr::ColumnarValue;
 use jiter::{Jiter, JiterError, Peek};

-use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array};
+use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array, TYPE_ID_NULL};

-pub fn check_args(args: &[DataType], fn_name: &str) -> DataFusionResult<()> {
+/// General implementation of `ScalarUDFImpl::return_type`.
+///
+/// # Arguments
+///
+/// * `args` - The arguments to the function
+/// * `fn_name` - The name of the function
+/// * `value_type` - The general return type of the function, might be wrapped in a dictionary depending
+///   on the first argument
+pub fn return_type_check(args: &[DataType], fn_name: &str, value_type: DataType) -> DataFusionResult<DataType> {
     let Some(first) = args.first() else {
         return plan_err!("The '{fn_name}' function requires one or more arguments.");
     };
-    if !(matches!(first, DataType::Utf8 | DataType::LargeUtf8) || is_json_union(first)) {
+    let first_dict_key_type = dict_key_type(first);
+    if !(is_str(first) || is_json_union(first) || first_dict_key_type.is_some()) {
         return plan_err!("Unexpected argument type to '{fn_name}' at position 1, expected a string, got {first:?}.");
     }
-    args[1..].iter().enumerate().try_for_each(|(index, arg)| match arg {
-        DataType::Utf8 | DataType::LargeUtf8 | DataType::UInt64 | DataType::Int64 => Ok(()),
-        t => plan_err!(
-            "Unexpected argument type to '{fn_name}' at position {}, expected string or int, got {t:?}.",
-            index + 2
-        ),
-    })
+    args.iter().skip(1).enumerate().try_for_each(|(index, arg)| {
+        if is_str(arg) || is_int(arg) || dict_key_type(arg).is_some() {
+            Ok(())
+        } else {
+            plan_err!(
+                "Unexpected argument type to '{fn_name}' at position {}, expected string or int, got {arg:?}.",
+                index + 2
+            )
+        }
+    })?;
+    match first_dict_key_type {
+        Some(t) => Ok(DataType::Dictionary(Box::new(t), Box::new(value_type))),
+        None => Ok(value_type),
+    }
+}
+
+fn is_str(d: &DataType) -> bool {
+    matches!(d, DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View)
+}
+
+fn is_int(d: &DataType) -> bool {
+    // TODO we should support more types of int, but that's a longer task
+    matches!(d, DataType::UInt64 | DataType::Int64)
+}
+
+fn dict_key_type(d: &DataType) -> Option<DataType> {
+    if let DataType::Dictionary(key, value) = d {
+        if is_str(value) || is_json_union(value) {
+            return Some(*key.clone());
+        }
+    }
+    None
 }

 #[derive(Debug)]
@@ -32,6 +74,12 @@ pub enum JsonPath<'s> {
     None,
 }

+impl<'a> From<&'a str> for JsonPath<'a> {
+    fn from(key: &'a str) -> Self {
+        JsonPath::Key(key)
+    }
+}
+
 impl From<u64> for JsonPath<'_> {
     fn from(index: u64) -> Self {
         JsonPath::Index(usize::try_from(index).unwrap())
@@ -66,6 +114,7 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
     jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
     to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
     to_scalar: impl Fn(Option<I>) -> ScalarValue,
+    return_dict: bool,
 ) -> DataFusionResult<ColumnarValue> {
     let Some(first_arg) = args.first() else {
         // I think this can't happen, but I assumed the same about args[1] and I was wrong, so better to be safe
@@ -73,95 +122,168 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
     };
     match first_arg {
         ColumnarValue::Array(json_array) => {
-            let result_collect = match args.get(1) {
+            let array = match args.get(1) {
                 Some(ColumnarValue::Array(a)) => {
                     if args.len() > 2 {
                         // TODO perhaps we could support this by zipping the arrays, but it's not trivial, #23
                         exec_err!("More than 1 path element is not supported when querying JSON using an array.")
-                    } else if let Some(str_path_array) = a.as_any().downcast_ref::<StringArray>() {
-                        let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
-                        zip_apply(json_array, paths, jiter_find, true)
-                    } else if let Some(str_path_array) = a.as_any().downcast_ref::<LargeStringArray>() {
-                        let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
-                        zip_apply(json_array, paths, jiter_find, true)
-                    } else if let Some(int_path_array) = a.as_any().downcast_ref::<Int64Array>() {
-                        let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
-                        zip_apply(json_array, paths, jiter_find, false)
-                    } else if let Some(int_path_array) = a.as_any().downcast_ref::<UInt64Array>() {
-                        let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
-                        zip_apply(json_array, paths, jiter_find, false)
                     } else {
-                        exec_err!("unexpected second argument type, expected string or int array")
+                        invoke_array(json_array, a, to_array, jiter_find, return_dict)
                     }
                 }
-                Some(ColumnarValue::Scalar(_)) => scalar_apply(json_array, &JsonPath::extract_path(args), jiter_find),
-                None => scalar_apply(json_array, &[], jiter_find),
+                Some(ColumnarValue::Scalar(_)) => scalar_apply(
+                    json_array,
+                    &JsonPath::extract_path(args),
+                    to_array,
+                    jiter_find,
+                    return_dict,
+                ),
+                None => scalar_apply(json_array, &[], to_array, jiter_find, return_dict),
             };
-            to_array(result_collect?).map(ColumnarValue::from)
-        }
-        ColumnarValue::Scalar(ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s)) => {
-            let path = JsonPath::extract_path(args);
-            let v = jiter_find(s.as_ref().map(String::as_str), &path).ok();
-            Ok(ColumnarValue::Scalar(to_scalar(v)))
-        }
-        ColumnarValue::Scalar(ScalarValue::Union(type_id_value, union_fields, _)) => {
-            let opt_json = json_from_union_scalar(type_id_value, union_fields);
-            let v = jiter_find(opt_json, &JsonPath::extract_path(args)).ok();
-            Ok(ColumnarValue::Scalar(to_scalar(v)))
-        }
-        ColumnarValue::Scalar(_) => {
-            exec_err!("unexpected first argument type, expected string or JSON union")
+            array.map(ColumnarValue::from)
         }
+        ColumnarValue::Scalar(s) => invoke_scalar(s, args, jiter_find, to_scalar),
     }
 }

-fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Option<I>> + 'static, I>(
+fn invoke_array<C: FromIterator<Option<I>> + 'static, I>(
+    json_array: &ArrayRef,
+    needle_array: &ArrayRef,
+    to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
+    jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
+    return_dict: bool,
+) -> DataFusionResult<ArrayRef> {
+    downcast_dictionary_array!(
+        needle_array => match needle_array.values().data_type() {
+            DataType::Utf8 => zip_apply(json_array, needle_array.downcast_dict::<StringArray>().unwrap(), to_array, jiter_find, true, return_dict),
+            DataType::LargeUtf8 => zip_apply(json_array, needle_array.downcast_dict::<LargeStringArray>().unwrap(), to_array, jiter_find, true, return_dict),
+            DataType::Utf8View => zip_apply(json_array, needle_array.downcast_dict::<StringViewArray>().unwrap(), to_array, jiter_find, true, return_dict),
+            DataType::Int64 => zip_apply(json_array, needle_array.downcast_dict::<Int64Array>().unwrap(), to_array, jiter_find, false, return_dict),
+            DataType::UInt64 => zip_apply(json_array, needle_array.downcast_dict::<UInt64Array>().unwrap(), to_array, jiter_find, false, return_dict),
+            other => exec_err!("unexpected second argument type, expected string or int array, got {:?}", other),
+        },
+        DataType::Utf8 => zip_apply(json_array, needle_array.as_string::<i32>(), to_array, jiter_find, true, return_dict),
+        DataType::LargeUtf8 => zip_apply(json_array, needle_array.as_string::<i64>(), to_array, jiter_find, true, return_dict),
+        DataType::Utf8View => zip_apply(json_array, needle_array.as_string_view(), to_array, jiter_find, true, return_dict),
+        DataType::Int64 => zip_apply(json_array, needle_array.as_primitive::<Int64Type>(), to_array, jiter_find, false, return_dict),
+        DataType::UInt64 => zip_apply(json_array, needle_array.as_primitive::<UInt64Type>(), to_array, jiter_find, false, return_dict),
+        other => exec_err!("unexpected second argument type, expected string or int array, got {:?}", other)
+    )
+}
+
+fn zip_apply<'a, P: Into<JsonPath<'a>>, C: FromIterator<Option<I>> + 'static, I>(
     json_array: &ArrayRef,
-    path_array: P,
+    path_array: impl ArrayAccessor<Item = P>,
+    to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
     jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
     object_lookup: bool,
-) -> DataFusionResult<C> {
-    if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
-        Ok(zip_apply_iter(string_array.iter(), path_array, jiter_find))
-    } else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() {
-        Ok(zip_apply_iter(large_string_array.iter(), path_array, jiter_find))
-    } else if let Some(string_array) = nested_json_array(json_array, object_lookup) {
-        Ok(zip_apply_iter(string_array.iter(), path_array, jiter_find))
-    } else {
-        exec_err!("unexpected json array type {:?}", json_array.data_type())
-    }
+    return_dict: bool,
+) -> DataFusionResult<ArrayRef> {
+    let c = downcast_dictionary_array!(
+        json_array => {
+            let values = zip_apply(json_array.values(), path_array, to_array, jiter_find, object_lookup, false)?;
+            return post_process_dict(json_array, values, return_dict);
+        }
+        DataType::Utf8 => zip_apply_iter(json_array.as_string::<i32>().iter(), path_array, jiter_find),
+        DataType::LargeUtf8 => zip_apply_iter(json_array.as_string::<i64>().iter(), path_array, jiter_find),
+        DataType::Utf8View => zip_apply_iter(json_array.as_string_view().iter(), path_array, jiter_find),
+        other => if let Some(string_array) = nested_json_array(json_array, object_lookup) {
+            zip_apply_iter(string_array.iter(), path_array, jiter_find)
+        } else {
+            return exec_err!("unexpected json array type {:?}", other);
+        }
    );

    to_array(c)
 }

-fn zip_apply_iter<'a, 'j, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Option<I>> + 'static, I>(
+#[allow(clippy::needless_pass_by_value)] // ArrayAccessor is implemented on references
+fn zip_apply_iter<'a, 'j, P: Into<JsonPath<'a>>, C: FromIterator<Option<I>> + 'static, I>(
     json_iter: impl Iterator<Item = Option<&'j str>>,
-    path_array: P,
+    path_array: impl ArrayAccessor<Item = P>,
     jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
 ) -> C {
     json_iter
-        .zip(path_array)
-        .map(|(opt_json, opt_path)| {
-            if let Some(path) = opt_path {
-                jiter_find(opt_json, &[path]).ok()
-            } else {
+        .enumerate()
+        .map(|(i, opt_json)| {
+            if path_array.is_null(i) {
                 None
+            } else {
+                let path = path_array.value(i).into();
+                jiter_find(opt_json, &[path]).ok()
             }
         })
         .collect::<C>()
 }

+fn invoke_scalar<I>(
+    scalar: &ScalarValue,
+    args: &[ColumnarValue],
+    jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
+    to_scalar: impl Fn(Option<I>) -> ScalarValue,
+) -> DataFusionResult<ColumnarValue> {
+    match scalar {
+        ScalarValue::Dictionary(_, b) => invoke_scalar(b.as_ref(), args, jiter_find, to_scalar),
+        ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) | ScalarValue::LargeUtf8(s) => {
+            let path = JsonPath::extract_path(args);
+            let v = jiter_find(s.as_ref().map(String::as_str), &path).ok();
+            Ok(ColumnarValue::Scalar(to_scalar(v)))
+        }
+        ScalarValue::Union(type_id_value, union_fields, _) => {
+            let opt_json = json_from_union_scalar(type_id_value, union_fields);
+            let v = jiter_find(opt_json, &JsonPath::extract_path(args)).ok();
+            Ok(ColumnarValue::Scalar(to_scalar(v)))
+        }
+        _ => {
+            exec_err!("unexpected first argument type, expected string or JSON union")
+        }
+    }
+}
+
 fn scalar_apply<C: FromIterator<Option<I>>, I>(
     json_array: &ArrayRef,
     path: &[JsonPath],
+    to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
     jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
-) -> DataFusionResult<C> {
-    if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
-        Ok(scalar_apply_iter(string_array.iter(), path, jiter_find))
-    } else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() {
-        Ok(scalar_apply_iter(large_string_array.iter(), path, jiter_find))
-    } else if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) {
-        Ok(scalar_apply_iter(string_array.iter(), path, jiter_find))
+    return_dict: bool,
+) -> DataFusionResult<ArrayRef> {
+    let c = downcast_dictionary_array!(
+        json_array => {
+            let values = scalar_apply(json_array.values(), path, to_array, jiter_find, false)?;
+            return post_process_dict(json_array, values, return_dict);
+        }
+        DataType::Utf8 => scalar_apply_iter(json_array.as_string::<i32>().iter(), path, jiter_find),
+        DataType::LargeUtf8 => scalar_apply_iter(json_array.as_string::<i64>().iter(), path, jiter_find),
+        DataType::Utf8View => scalar_apply_iter(json_array.as_string_view().iter(), path, jiter_find),
+        other => if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) {
+            scalar_apply_iter(string_array.iter(), path, jiter_find)
+        } else {
+            return exec_err!("unexpected json array type {:?}", other);
+        }
+    );
+    to_array(c)
+}
+
+/// Take a dictionary array of JSON data and an array of result values and combine them.
+fn post_process_dict<T: ArrowDictionaryKeyType>(
+    dict_array: &DictionaryArray<T>,
+    result_values: ArrayRef,
+    return_dict: bool,
+) -> DataFusionResult<ArrayRef> {
+    if return_dict {
+        if is_json_union(result_values.data_type()) {
+            // JSON union: post-process the array to set keys to null where the union member is null
+            let type_ids = result_values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
+            Ok(Arc::new(DictionaryArray::new(
+                mask_dictionary_keys(dict_array.keys(), type_ids),
+                result_values,
+            )))
+        } else {
+            Ok(Arc::new(dict_array.with_values(result_values)))
+        }
     } else {
-        exec_err!("unexpected json array type {:?}", json_array.data_type())
+        // this is what cast would do under the hood to unpack a dictionary into an array of its values
+        Ok(take(&result_values, dict_array.keys(), None)?)
     }
 }
@@ -235,3 +357,23 @@ impl From<Utf8Error> for GetError {
         GetError
     }
 }
+
+/// Set keys to null where the union member is null.
+///
+/// This is a workaround to an upstream arrow-rs limitation
+/// - i.e. that dictionary null is most reliably done if the keys are null.
+///
+/// That said, doing this might also be an optimization for cases like null-checking without needing
+/// to check the value union array.
+fn mask_dictionary_keys<K: ArrowPrimitiveType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
+    let mut null_mask = vec![true; keys.len()];
+    for (i, k) in keys.iter().enumerate() {
+        match k {
+            // if the key is non-null and value is non-null, don't mask it out
+            Some(k) if type_ids[k.as_usize()] != TYPE_ID_NULL => {}
+            // i.e. key is null or value is null here
+            _ => null_mask[i] = false,
+        }
+    }
+    PrimitiveArray::new(keys.values().clone(), Some(null_mask.into()))
+}
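For reviewers, the dictionary rule implemented by `return_type_check` above can be summarised in a small standalone sketch. The `expected_return` helper below is illustrative only (it is not part of this patch); it mirrors the behaviour of the patched function for dictionary-encoded inputs:

```rust
use datafusion::arrow::datatypes::DataType;

// Illustrative only: mirrors the dictionary rule of `return_type_check`.
// A dictionary-encoded first argument keeps its key type, while the
// function's value type becomes the dictionary's value type.
fn expected_return(first_arg: &DataType, value_type: DataType) -> DataType {
    match first_arg {
        DataType::Dictionary(key, _) => DataType::Dictionary(key.clone(), Box::new(value_type)),
        _ => value_type,
    }
}

fn main() {
    let dict_str = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    // e.g. json_length over a dictionary-encoded column returns a dictionary of UInt64
    assert_eq!(
        expected_return(&dict_str, DataType::UInt64),
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt64))
    );
    // a plain string input returns the value type unchanged
    assert_eq!(expected_return(&DataType::Utf8, DataType::UInt64), DataType::UInt64);
}
```

The `return_dict` flag threaded through `invoke`/`scalar_apply` is what lets functions such as `json_contains` opt out of this wrapping (they pass `false`), while `json_get` and friends preserve the dictionary encoding.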
diff --git a/src/common_macros.rs b/src/common_macros.rs
index f3aa3b1..a2c6cd0 100644
--- a/src/common_macros.rs
+++ b/src/common_macros.rs
@@ -18,8 +18,8 @@ macro_rules! make_udf_function {
     ($udf_impl:ty, $expr_fn_name:ident, $($arg:ident)*, $doc:expr) => {
         paste::paste! {
             #[doc = $doc]
-            #[must_use] pub fn $expr_fn_name($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr {
-                datafusion_expr::Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction::new_udf(
+            #[must_use] pub fn $expr_fn_name($($arg: datafusion::logical_expr::Expr),*) -> datafusion::logical_expr::Expr {
+                datafusion::logical_expr::Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
                     [< $expr_fn_name _udf >](),
                     vec![$($arg),*],
                 ))
@@ -27,16 +27,16 @@ macro_rules! make_udf_function {

         /// Singleton instance of [`$udf_impl`], ensures the UDF is only created once
         /// named for example `STATIC_JSON_OBJ_CONTAINS`
-        static [< STATIC_ $expr_fn_name:upper >]: std::sync::OnceLock<std::sync::Arc<datafusion_expr::ScalarUDF>> =
+        static [< STATIC_ $expr_fn_name:upper >]: std::sync::OnceLock<std::sync::Arc<datafusion::logical_expr::ScalarUDF>> =
             std::sync::OnceLock::new();

         /// ScalarFunction that returns a [`ScalarUDF`] for [`$udf_impl`]
         ///
-        /// [`ScalarUDF`]: datafusion_expr::ScalarUDF
-        pub fn [< $expr_fn_name _udf >]() -> std::sync::Arc<datafusion_expr::ScalarUDF> {
+        /// [`ScalarUDF`]: datafusion::logical_expr::ScalarUDF
+        pub fn [< $expr_fn_name _udf >]() -> std::sync::Arc<datafusion::logical_expr::ScalarUDF> {
             [< STATIC_ $expr_fn_name:upper >]
                 .get_or_init(|| {
-                    std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl(
+                    std::sync::Arc::new(datafusion::logical_expr::ScalarUDF::new_from_impl(
                         <$udf_impl>::default(),
                     ))
                 })
diff --git a/src/common_union.rs b/src/common_union.rs
index a7278e4..74820ff 100644
--- a/src/common_union.rs
+++ b/src/common_union.rs
@@ -1,14 +1,15 @@
 use std::sync::{Arc, OnceLock};

-use arrow::array::{
-    Array, ArrayRef, BooleanArray, Float64Array, Int64Array, ListArray, ListBuilder, NullArray, StringArray,
+use datafusion::arrow::array::{
+    Array, ArrayRef, AsArray, BooleanArray, Float64Array, Int64Array, ListArray, ListBuilder, NullArray, StringArray,
     StringBuilder, UnionArray,
 };
-use arrow::buffer::Buffer;
-use arrow_schema::{DataType, Field, UnionFields, UnionMode};
-use datafusion_common::ScalarValue;
+use datafusion::arrow::buffer::{Buffer, ScalarBuffer};
+use datafusion::arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
+use datafusion::arrow::error::ArrowError;
+use datafusion::common::ScalarValue;

-pub(crate) fn is_json_union(data_type: &DataType) -> bool {
+pub fn is_json_union(data_type: &DataType) -> bool {
     match data_type {
         DataType::Union(fields, UnionMode::Sparse) => fields == &union_fields(),
         _ => false,
@@ -65,7 +66,7 @@ impl JsonUnion {
             strings: vec![None; length],
             arrays: vec![None; length],
             objects: vec![None; length],
-            type_ids: vec![0; length],
+            type_ids: vec![TYPE_ID_NULL; length],
             index: 0,
             length,
         }
@@ -97,7 +98,7 @@ impl JsonUnion {
 }

 impl TryFrom<JsonUnion> for UnionArray {
-    type Error = arrow::error::ArrowError;
+    type Error = ArrowError;

     fn try_from(value: JsonUnion) -> Result<Self, Self::Error> {
         let children: Vec<Arc<dyn Array>> = vec![
@@ -120,7 +121,7 @@ impl TryFrom<JsonUnion> for UnionArray {
 }

 impl TryFrom<JsonUnion> for ListArray {
-    type Error = arrow::error::ArrowError;
+    type Error = ArrowError;

     fn try_from(value: JsonUnion) -> Result<Self, Self::Error> {
         let string_builder = StringBuilder::new();
@@ -153,7 +154,7 @@ pub(crate) enum JsonUnionField {
     Object(String),
 }

-const TYPE_ID_NULL: i8 = 0;
+pub(crate) const TYPE_ID_NULL: i8 = 0;
 const TYPE_ID_BOOL: i8 = 1;
 const TYPE_ID_INT: i8 = 2;
 const TYPE_ID_FLOAT: i8 = 3;
@@ -266,3 +267,109 @@ impl FromIterator<Option<JsonUnionField>> for JsonUnion {
         union
     }
 }
+
+pub struct JsonUnionEncoder {
+    boolean: BooleanArray,
+    int: Int64Array,
+    float: Float64Array,
+    string: StringArray,
+    array: StringArray,
+    object: StringArray,
+    type_ids: ScalarBuffer<i8>,
+}
+
+impl JsonUnionEncoder {
+    #[must_use]
+    pub fn from_union(union: UnionArray) -> Option<Self> {
+        if is_json_union(union.data_type()) {
+            let (_, type_ids, _, c) = union.into_parts();
+            Some(Self {
+                boolean: c[1].as_boolean().clone(),
+                int: c[2].as_primitive().clone(),
+                float: c[3].as_primitive().clone(),
+                string: c[4].as_string().clone(),
+                array: c[5].as_string().clone(),
+                object: c[6].as_string().clone(),
+                type_ids,
+            })
+        } else {
+            None
+        }
+    }
+
+    #[must_use]
+    #[allow(clippy::len_without_is_empty)]
+    pub fn len(&self) -> usize {
+        self.type_ids.len()
+    }
+
+    /// Get the encodable value for a given index
+    ///
+    /// # Panics
+    ///
+    /// Panics if the idx is outside the union values or an invalid type id exists in the union.
+    #[must_use]
+    pub fn get_value(&self, idx: usize) -> JsonUnionValue {
+        let type_id = self.type_ids[idx];
+        match type_id {
+            TYPE_ID_NULL => JsonUnionValue::JsonNull,
+            TYPE_ID_BOOL => JsonUnionValue::Bool(self.boolean.value(idx)),
+            TYPE_ID_INT => JsonUnionValue::Int(self.int.value(idx)),
+            TYPE_ID_FLOAT => JsonUnionValue::Float(self.float.value(idx)),
+            TYPE_ID_STR => JsonUnionValue::Str(self.string.value(idx)),
+            TYPE_ID_ARRAY => JsonUnionValue::Array(self.array.value(idx)),
+            TYPE_ID_OBJECT => JsonUnionValue::Object(self.object.value(idx)),
+            _ => panic!("Invalid type_id: {type_id}, not a valid JSON type"),
+        }
+    }
+}
+
+#[derive(Debug, PartialEq)]
+pub enum JsonUnionValue<'a> {
+    JsonNull,
+    Bool(bool),
+    Int(i64),
+    Float(f64),
+    Str(&'a str),
+    Array(&'a str),
+    Object(&'a str),
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_json_union() {
+        let json_union = JsonUnion::from_iter(vec![
+            Some(JsonUnionField::JsonNull),
+            Some(JsonUnionField::Bool(true)),
+            Some(JsonUnionField::Bool(false)),
+            Some(JsonUnionField::Int(42)),
+            Some(JsonUnionField::Float(42.0)),
+            Some(JsonUnionField::Str("foo".to_string())),
+            Some(JsonUnionField::Array(vec!["[42]".to_string()])),
+            Some(JsonUnionField::Object(r#"{"foo": 42}"#.to_string())),
+            None,
+        ]);
+
+        let union_array = UnionArray::try_from(json_union).unwrap();
+        let encoder = JsonUnionEncoder::from_union(union_array).unwrap();
+
+        let values_after: Vec<_> = (0..encoder.len()).map(|idx| encoder.get_value(idx)).collect();
+        assert_eq!(
+            values_after,
+            vec![
+                JsonUnionValue::JsonNull,
+                JsonUnionValue::Bool(true),
+                JsonUnionValue::Bool(false),
+                JsonUnionValue::Int(42),
+                JsonUnionValue::Float(42.0),
+                JsonUnionValue::Str("foo"),
+                JsonUnionValue::Array("[42]"),
+                JsonUnionValue::Object(r#"{"foo": 42}"#),
+                JsonUnionValue::JsonNull,
+            ]
+        );
+    }
+}
diff --git a/src/json_as_text.rs b/src/json_as_text.rs
index 9f02470..566ed82 100644
--- a/src/json_as_text.rs
+++ b/src/json_as_text.rs
@@ -1,14 +1,15 @@
 use std::any::Any;
 use std::sync::Arc;

-use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
-use crate::common_macros::make_udf_function;
-use arrow::array::{ArrayRef, StringArray};
-use arrow_schema::DataType;
-use datafusion_common::{Result as DataFusionResult, ScalarValue};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion::arrow::array::{ArrayRef, StringArray};
+use datafusion::arrow::datatypes::DataType;
+use datafusion::common::{Result as DataFusionResult, ScalarValue};
+use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
 use jiter::Peek;

+use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
+use crate::common_macros::make_udf_function;
+
 make_udf_function!(
     JsonAsText,
     json_as_text,
@@ -45,7 +46,7 @@ impl ScalarUDFImpl for JsonAsText {
     }

     fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
-        check_args(arg_types, self.name()).map(|()| DataType::Utf8)
+        return_type_check(arg_types, self.name(), DataType::Utf8)
     }

     fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
@@ -54,6 +55,7 @@ impl ScalarUDFImpl for JsonAsText {
             jiter_json_as_text,
             |c| Ok(Arc::new(c) as ArrayRef),
             ScalarValue::Utf8,
+            true,
         )
     }

diff --git a/src/json_contains.rs b/src/json_contains.rs
index 821c04a..bad2ead 100644
--- a/src/json_contains.rs
+++ b/src/json_contains.rs
@@ -1,12 +1,12 @@
 use std::any::Any;
 use std::sync::Arc;

-use arrow_schema::DataType;
-use datafusion_common::arrow::array::{ArrayRef, BooleanArray};
-use datafusion_common::{plan_err, Result, ScalarValue};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion::arrow::datatypes::DataType;
+use datafusion::common::arrow::array::{ArrayRef, BooleanArray};
+use datafusion::common::{plan_err, Result, ScalarValue};
+use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

-use crate::common::{check_args, invoke, jiter_json_find, GetError, JsonPath};
+use crate::common::{invoke, jiter_json_find, return_type_check, GetError, JsonPath};
 use crate::common_macros::make_udf_function;

 make_udf_function!(
@@ -48,7 +48,7 @@ impl ScalarUDFImpl for JsonContains {
         if arg_types.len() < 2 {
             plan_err!("The 'json_contains' function requires two or more arguments.")
         } else {
-            check_args(arg_types, self.name()).map(|()| DataType::Boolean)
+            return_type_check(arg_types, self.name(), DataType::Boolean).map(|_| DataType::Boolean)
         }
     }

@@ -58,6 +58,7 @@ impl ScalarUDFImpl for JsonContains {
             jiter_json_contains,
             |c| Ok(Arc::new(c) as ArrayRef),
             ScalarValue::Boolean,
+            false,
         )
     }

diff --git a/src/json_get.rs b/src/json_get.rs
index 95db10b..bb17446 100644
--- a/src/json_get.rs
+++ b/src/json_get.rs
@@ -1,14 +1,14 @@
 use std::any::Any;
 use std::sync::Arc;

-use arrow::array::UnionArray;
-use arrow_schema::DataType;
-use datafusion_common::arrow::array::ArrayRef;
-use datafusion_common::Result as DataFusionResult;
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion::arrow::array::ArrayRef;
+use datafusion::arrow::array::UnionArray;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::common::Result as DataFusionResult;
+use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
 use jiter::{Jiter, NumberAny, NumberInt, Peek};

-use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
+use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
 use crate::common_macros::make_udf_function;
 use crate::common_union::{JsonUnion, JsonUnionField};

@@ -50,7 +50,7 @@ impl ScalarUDFImpl for JsonGet {
     }

     fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
-        check_args(arg_types, self.name()).map(|()| JsonUnion::data_type())
+        return_type_check(arg_types, self.name(), JsonUnion::data_type())
     }

     fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
@@ -58,7 +58,7 @@ impl ScalarUDFImpl for JsonGet {
             let array: UnionArray = c.try_into()?;
             Ok(Arc::new(array) as ArrayRef)
         };
-        invoke::<JsonUnion, JsonUnionField>(args, jiter_json_get_union, to_array, JsonUnionField::scalar_value)
+        invoke::<JsonUnion, JsonUnionField>(args, jiter_json_get_union, to_array, JsonUnionField::scalar_value, true)
     }

     fn aliases(&self) -> &[String] {
diff --git a/src/json_get_array.rs b/src/json_get_array.rs
index 063d228..1f9e1b7 100644
--- a/src/json_get_array.rs
+++ b/src/json_get_array.rs
@@ -1,14 +1,14 @@
 use std::any::Any;
 use std::sync::Arc;

-use arrow::array::ListArray;
-use arrow_schema::{DataType, Field};
-use datafusion_common::arrow::array::ArrayRef;
-use datafusion_common::{Result as DataFusionResult, ScalarValue};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion::arrow::array::{ArrayRef, ListArray};
+use datafusion::arrow::datatypes::{DataType, Field};
+use datafusion::error::Result as DatafusionResult;
+use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion::scalar::ScalarValue;
 use jiter::Peek;

-use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
+use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
 use crate::common_macros::make_udf_function;
 use crate::common_union::{JsonArrayField, JsonUnion};

@@ -47,19 +47,27 @@ impl ScalarUDFImpl for JsonGetArray {
         &self.signature
     }

-    fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
-        check_args(arg_types, self.name()).map(|()| DataType::List(Field::new("item", DataType::Utf8, true).into()))
+    fn return_type(&self, arg_types: &[DataType]) -> DatafusionResult<DataType> {
+        return_type_check(
+            arg_types,
+            self.name(),
+            DataType::List(Field::new("item", DataType::Utf8, true).into()),
+        )
     }

-    fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
+    fn invoke(&self, args: &[ColumnarValue]) -> DatafusionResult<ColumnarValue> {
         let to_array = |c: JsonUnion| {
             let array: ListArray = c.try_into()?;
             Ok(Arc::new(array) as ArrayRef)
         };
-        invoke::<JsonUnion, JsonArrayField>(args, jiter_json_get_array, to_array, |i| {
-            i.map_or_else(|| ScalarValue::Null, Into::into)
-        })
+        invoke::<JsonUnion, JsonArrayField>(
+            args,
+            jiter_json_get_array,
+            to_array,
+            |i| i.map_or_else(|| ScalarValue::Null, Into::into),
+            true,
+        )
     }

     fn aliases(&self) -> &[String] {
diff --git a/src/json_get_bool.rs b/src/json_get_bool.rs
index 92a4ac9..ff87cd8 100644
--- a/src/json_get_bool.rs
+++ b/src/json_get_bool.rs
@@ -1,13 +1,13 @@
 use std::any::Any;
 use std::sync::Arc;

-use arrow::array::{ArrayRef, BooleanArray};
-use arrow_schema::DataType;
-use datafusion_common::{Result as DataFusionResult, ScalarValue};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion::arrow::array::{ArrayRef, BooleanArray};
+use datafusion::arrow::datatypes::DataType;
+use datafusion::common::{Result as DataFusionResult, ScalarValue};
+use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
 use jiter::Peek;

-use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
+use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
 use crate::common_macros::make_udf_function;

 make_udf_function!(
@@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetBool {
     }

     fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
-        check_args(arg_types, self.name()).map(|()| DataType::Boolean)
+        return_type_check(arg_types, self.name(), DataType::Boolean).map(|_| DataType::Boolean)
     }

     fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
@@ -55,6 +55,7 @@ impl ScalarUDFImpl for JsonGetBool {
             jiter_json_get_bool,
             |c| Ok(Arc::new(c) as ArrayRef),
             ScalarValue::Boolean,
+            false,
         )
     }

diff --git a/src/json_get_float.rs b/src/json_get_float.rs
index bed8c67..b099bfd 100644
--- a/src/json_get_float.rs
+++ b/src/json_get_float.rs
@@ -1,13 +1,13 @@
 use std::any::Any;
 use std::sync::Arc;

-use arrow::array::{ArrayRef, Float64Array};
-use arrow_schema::DataType;
-use datafusion_common::{Result as DataFusionResult, ScalarValue};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion::arrow::array::{ArrayRef, Float64Array};
+use datafusion::arrow::datatypes::DataType;
+use datafusion::common::{Result as DataFusionResult, ScalarValue};
+use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
 use jiter::{NumberAny, Peek};

-use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
+use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
 use crate::common_macros::make_udf_function;

 make_udf_function!(
@@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetFloat {
     }

     fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
-        check_args(arg_types, self.name()).map(|()| DataType::Float64)
+        return_type_check(arg_types, self.name(), DataType::Float64)
     }

     fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
@@ -55,6 +55,7 @@ impl ScalarUDFImpl for JsonGetFloat {
             jiter_json_get_float,
             |c| Ok(Arc::new(c) as ArrayRef),
             ScalarValue::Float64,
+            true,
         )
     }

diff --git a/src/json_get_int.rs b/src/json_get_int.rs
index 4f80256..eb37f2b 100644
--- a/src/json_get_int.rs
+++ b/src/json_get_int.rs
@@ -1,13 +1,13 @@
 use std::any::Any;
 use std::sync::Arc;

-use arrow::array::{ArrayRef, Int64Array};
-use arrow_schema::DataType;
-use datafusion_common::{Result as DataFusionResult, ScalarValue};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion::arrow::array::{ArrayRef, Int64Array};
+use datafusion::arrow::datatypes::DataType;
+use datafusion::common::{Result as DataFusionResult, ScalarValue};
+use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
 use jiter::{NumberInt, Peek};

-use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
+use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
 use crate::common_macros::make_udf_function;

 make_udf_function!(
@@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetInt {
     }

     fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
-        check_args(arg_types, self.name()).map(|()| DataType::Int64)
+        return_type_check(arg_types, self.name(), DataType::Int64)
     }

     fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
@@ -55,6 +55,7 @@ impl ScalarUDFImpl for JsonGetInt {
             jiter_json_get_int,
             |c| Ok(Arc::new(c) as ArrayRef),
             ScalarValue::Int64,
+            true,
         )
     }

diff --git a/src/json_get_json.rs b/src/json_get_json.rs
index 002702b..86c62e6 100644
--- a/src/json_get_json.rs
+++ b/src/json_get_json.rs
@@ -1,12 +1,12 @@
 use std::any::Any;
 use std::sync::Arc;

-use arrow::array::{ArrayRef, StringArray};
-use arrow_schema::DataType;
-use datafusion_common::{Result as DataFusionResult, ScalarValue};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion::arrow::array::{ArrayRef, StringArray};
+use datafusion::arrow::datatypes::DataType;
+use datafusion::common::{Result as DataFusionResult, ScalarValue};
+use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

-use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
+use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
 use crate::common_macros::make_udf_function;

 make_udf_function!(
@@ -45,7 +45,7 @@ impl ScalarUDFImpl for JsonGetJson {
     }

     fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
-        check_args(arg_types, self.name()).map(|()| DataType::Utf8)
+        return_type_check(arg_types, self.name(), DataType::Utf8)
     }

     fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
@@ -54,6 +54,7 @@ impl ScalarUDFImpl for JsonGetJson {
             jiter_json_get_json,
             |c| Ok(Arc::new(c) as ArrayRef),
             ScalarValue::Utf8,
+            true,
         )
     }

diff --git a/src/json_get_str.rs b/src/json_get_str.rs
index a6f4ad5..753d879 100644
--- a/src/json_get_str.rs
+++ b/src/json_get_str.rs
@@ -1,13 +1,13 @@
 use std::any::Any;
 use std::sync::Arc;

-use arrow::array::{ArrayRef, StringArray};
-use arrow_schema::DataType;
-use datafusion_common::{Result as DataFusionResult, ScalarValue};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion::arrow::array::{ArrayRef, StringArray};
+use datafusion::arrow::datatypes::DataType;
+use datafusion::common::{Result as DataFusionResult, ScalarValue};
+use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
 use jiter::Peek;

-use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
+use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
 use crate::common_macros::make_udf_function;

 make_udf_function!(
@@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetStr {
     }

     fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
-        check_args(arg_types, self.name()).map(|()| DataType::Utf8)
+        return_type_check(arg_types, self.name(), DataType::Utf8)
     }

     fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
@@ -55,6 +55,7 @@ impl ScalarUDFImpl for JsonGetStr {
             jiter_json_get_str,
             |c| Ok(Arc::new(c) as ArrayRef),
             ScalarValue::Utf8,
+            true,
         )
     }

diff --git a/src/json_length.rs b/src/json_length.rs
index b1bb900..5cc934a 100644
--- a/src/json_length.rs
+++ b/src/json_length.rs
@@ -1,13 +1,13 @@
 use std::any::Any;
 use std::sync::Arc;

-use arrow::array::{ArrayRef, UInt64Array};
-use arrow_schema::DataType;
-use datafusion_common::{Result as DataFusionResult, ScalarValue};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion::arrow::array::{ArrayRef, UInt64Array};
+use datafusion::arrow::datatypes::DataType;
+use datafusion::common::{Result as DataFusionResult, ScalarValue};
+use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
 use jiter::Peek;

-use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
+use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
 use crate::common_macros::make_udf_function;

 make_udf_function!(
@@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonLength {
     }

     fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
-        check_args(arg_types, self.name()).map(|()| DataType::UInt64)
+        return_type_check(arg_types, self.name(), DataType::UInt64)
     }

     fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
@@ -55,6 +55,7 @@ impl ScalarUDFImpl for JsonLength {
             jiter_json_length,
             |c| Ok(Arc::new(c) as ArrayRef),
             ScalarValue::UInt64,
+            true,
         )
     }

diff --git a/src/json_object_keys.rs b/src/json_object_keys.rs
new file mode 100644
index 0000000..e89c0bd
--- /dev/null
+++ b/src/json_object_keys.rs
@@ -0,0 +1,127 @@
+use std::any::Any;
+use std::sync::Arc;
+
+use datafusion::arrow::array::{ArrayRef, ListArray, ListBuilder, StringBuilder};
+use datafusion::arrow::datatypes::{DataType, Field};
+use datafusion::common::{Result as DataFusionResult, ScalarValue};
+use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use jiter::Peek;
+
+use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
+use crate::common_macros::make_udf_function;
+
+make_udf_function!(
+    JsonObjectKeys,
+    json_object_keys,
+    json_data path,
+    r#"Get the keys of a JSON object as an array."#
+);
+
+#[derive(Debug)]
+pub(super) struct JsonObjectKeys {
+    signature: Signature,
+    aliases: [String; 2],
+}
+
+impl Default for JsonObjectKeys {
+    fn default() -> Self {
+        Self {
+            signature: Signature::variadic_any(Volatility::Immutable),
+            aliases: ["json_object_keys".to_string(), "json_keys".to_string()],
+        }
+    }
+}
+
+impl ScalarUDFImpl for JsonObjectKeys {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        self.aliases[0].as_str()
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
+        return_type_check(
+            arg_types,
+            self.name(),
+            DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
+        )
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
+        invoke::<ListArrayWrapper, Vec<String>>(
+            args,
+            jiter_json_object_keys,
+            |w| Ok(Arc::new(w.0) as ArrayRef),
+            keys_to_scalar,
+            true,
+        )
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+/// Wrapper for a `ListArray` that allows us to implement `FromIterator<Option<Vec<String>>>` as required.
+#[derive(Debug)]
+struct ListArrayWrapper(ListArray);
+
+impl FromIterator<Option<Vec<String>>> for ListArrayWrapper {
+    fn from_iter<I: IntoIterator<Item = Option<Vec<String>>>>(iter: I) -> Self {
+        let values_builder = StringBuilder::new();
+        let mut builder = ListBuilder::new(values_builder);
+        for opt_keys in iter {
+            if let Some(keys) = opt_keys {
+                for value in keys {
+                    builder.values().append_value(value);
+                }
+                builder.append(true);
+            } else {
+                builder.append(false);
+            }
+        }
+        Self(builder.finish())
+    }
+}
+
+fn keys_to_scalar(opt_keys: Option<Vec<String>>) -> ScalarValue {
+    let values_builder = StringBuilder::new();
+    let mut builder = ListBuilder::new(values_builder);
+    if let Some(keys) = opt_keys {
+        for value in keys {
+            builder.values().append_value(value);
+        }
+        builder.append(true);
+    } else {
+        builder.append(false);
+    }
+    let array = builder.finish();
+    ScalarValue::List(Arc::new(array))
+}
+
+fn jiter_json_object_keys(opt_json: Option<&str>, path: &[JsonPath]) -> Result<Vec<String>, GetError> {
+    if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) {
+        match peek {
+            Peek::Object => {
+                let mut opt_key = jiter.known_object()?;
+
+                let mut keys = Vec::new();
+                while let Some(key) = opt_key {
+                    keys.push(key.to_string());
+                    jiter.next_skip()?;
+                    opt_key = jiter.next_key()?;
+                }
+                Ok(keys)
+            }
+            _ => get_err!(),
+        }
+    } else {
+        get_err!()
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 8f55540..118796c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,9 +1,10 @@
-use datafusion_common::Result;
-use datafusion_execution::FunctionRegistry;
-use datafusion_expr::ScalarUDF;
 use log::debug;
 use std::sync::Arc;

+use datafusion::common::Result;
+use datafusion::execution::FunctionRegistry;
+use datafusion::logical_expr::ScalarUDF;
+
 mod common;
 mod common_macros;
 mod common_union;
@@ -17,8 +18,11 @@ mod json_get_int;
 mod json_get_json;
 mod json_get_str;
 mod json_length;
+mod json_object_keys;
 mod rewrite;

+pub use common_union::{JsonUnionEncoder, JsonUnionValue};
+
 pub mod functions {
     pub use crate::json_as_text::json_as_text;
     pub use crate::json_contains::json_contains;
@@ -30,6 +34,7 @@ pub mod functions {
     pub use crate::json_get_json::json_get_json;
     pub use crate::json_get_str::json_get_str;
     pub use crate::json_length::json_length;
+    pub use crate::json_object_keys::json_object_keys;
 }

 pub mod udfs {
@@ -43,6 +48,7 @@ pub mod udfs {
     pub use crate::json_get_json::json_get_json_udf;
     pub use crate::json_get_str::json_get_str_udf;
     pub use crate::json_length::json_length_udf;
+    pub use crate::json_object_keys::json_object_keys_udf;
 }

 /// Register all JSON UDFs, and [`rewrite::JsonFunctionRewriter`] with the provided [`FunctionRegistry`].
@@ -66,6 +72,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
         json_get_str::json_get_str_udf(),
         json_contains::json_contains_udf(),
         json_length::json_length_udf(),
+        json_object_keys::json_object_keys_udf(),
     ];
     functions.into_iter().try_for_each(|udf| {
         let existing_udf = registry.register_udf(udf)?;
diff --git a/src/rewrite.rs b/src/rewrite.rs
index 814bd62..976db7e 100644
--- a/src/rewrite.rs
+++ b/src/rewrite.rs
@@ -1,13 +1,14 @@
-use arrow::datatypes::DataType;
-use datafusion_common::config::ConfigOptions;
-use datafusion_common::tree_node::Transformed;
-use datafusion_common::DFSchema;
-use datafusion_common::Result;
-use datafusion_expr::expr::{Alias, Cast, Expr, ScalarFunction};
-use datafusion_expr::expr_rewriter::FunctionRewrite;
-use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr};
-use datafusion_expr::sqlparser::ast::BinaryOperator;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::common::config::ConfigOptions;
+use datafusion::common::tree_node::Transformed;
+use datafusion::common::DFSchema;
+use datafusion::common::Result;
+use datafusion::logical_expr::expr::{Alias, Cast, Expr, ScalarFunction};
+use datafusion::logical_expr::expr_rewriter::FunctionRewrite;
+use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr};
+use datafusion::logical_expr::sqlparser::ast::BinaryOperator;

+#[derive(Debug)]
 pub(crate) struct JsonFunctionRewriter;

 impl FunctionRewrite for JsonFunctionRewriter {
@@ -51,14 +52,23 @@ fn optimise_json_get_cast(cast: &Cast) -> Option<Transformed<Expr>> {

 fn unnest_json_calls(func: &ScalarFunction) -> Option<Transformed<Expr>> {
     if !matches!(
         func.func.name(),
-        "json_get" | "json_get_bool" | "json_get_float" | "json_get_int" | "json_get_json" | "json_get_str"
+        "json_get"
+            | "json_get_bool"
+            | "json_get_float"
+            | "json_get_int"
+            | "json_get_json"
+            | "json_get_str"
+            | "json_as_text"
     ) {
         return None;
     }
     let mut outer_args_iter = func.args.iter();
     let first_arg = outer_args_iter.next()?;
     let inner_func = extract_scalar_function(first_arg)?;
-    if inner_func.func.name() != "json_get" {
+
+    // both json_get and json_as_text would produce new JSON to be processed by the outer
+    // function so can be inlined
+    if !matches!(inner_func.func.name(), "json_get" | "json_as_text") {
         return None;
     }

diff --git a/tests/main.rs b/tests/main.rs
index 5e70cc7..250aae3 100644
--- a/tests/main.rs
+++ b/tests/main.rs
@@ -1,11 +1,16 @@
-use arrow_schema::DataType;
+use std::sync::Arc;
+
+use datafusion::arrow::array::{ArrayRef, RecordBatch};
+use datafusion::arrow::datatypes::{Field, Int8Type, Schema};
+use datafusion::arrow::{array::StringDictionaryBuilder, datatypes::DataType};
 use datafusion::assert_batches_eq;
-use datafusion_common::ScalarValue;
+use datafusion::common::ScalarValue;
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::prelude::SessionContext;
+use datafusion_functions_json::udfs::json_get_str_udf;
+use utils::{create_context, display_val, logical_plan, run_query, run_query_dict, run_query_large, run_query_params};

 mod utils;
-use datafusion_expr::ColumnarValue;
-use datafusion_functions_json::udfs::json_get_str_udf;
-use utils::{display_val, logical_plan, run_query, run_query_large, run_query_params};

 #[tokio::test]
 async fn test_json_contains() {
@@ -483,7 +488,7 @@ async fn test_json_get_array() {

     let expected = [
         "+------------------+----------------------------------------------------+",
-        "| name             | unnest(json_get_array(test.json_data,Utf8(\"foo\"))) |",
+        "| name             | UNNEST(json_get_array(test.json_data,Utf8(\"foo\"))) |",
         "+------------------+----------------------------------------------------+",
         "| object_foo_array | 1                                                  |",
         "| object_foo_array | true                                               |",
@@ -516,7 +521,7 @@ fn test_json_get_utf8() {
         ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
     ];

-    let ColumnarValue::Scalar(sv) = json_get_str.invoke(args).unwrap() else {
+    let ColumnarValue::Scalar(sv) = json_get_str.invoke_batch(args, 1).unwrap() else {
         panic!("expected scalar")
     };

@@ -534,7 +539,7 @@ fn test_json_get_large_utf8() {
         ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("aa".to_string()))),
     ];

-    let ColumnarValue::Scalar(sv) = json_get_str.invoke(args).unwrap() else {
+    let ColumnarValue::Scalar(sv) = json_get_str.invoke_batch(args, 1).unwrap() else {
         panic!("expected scalar")
     };

@@ -896,6 +901,38 @@ async fn test_plan_arrow_double_nested() {
     assert_eq!(lines, expected);
 }

+#[tokio::test]
+async fn test_double_arrow_double_nested() {
+    let batches = run_query("select name, json_data->>'foo'->>0 from test").await.unwrap();
+
+    let expected = [
+        "+------------------+---------------------------------------------+",
+        "| name             | test.json_data ->> Utf8(\"foo\") ->> Int64(0) |",
+        "+------------------+---------------------------------------------+",
+        "| object_foo       |                                             |",
+        "| object_foo_array | 1                                           |",
+        "| object_foo_obj   |                                             |",
+        "| object_foo_null  |                                             |",
+        "| object_bar       |                                             |",
+        "| list_foo         |                                             |",
+        "| invalid_json     |                                             |",
+        "+------------------+---------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &batches);
+}
+
+#[tokio::test]
+async fn test_plan_double_arrow_double_nested() {
+    let lines = logical_plan(r"explain select json_data->>'foo'->>0 from test").await;
+
+    let expected = [
+        "Projection: json_as_text(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data ->> Utf8(\"foo\") ->> Int64(0)",
+        "  TableScan: test projection=[json_data]",
+    ];
+
+    assert_eq!(lines, expected);
+}
+
 #[tokio::test]
 async fn test_arrow_double_nested_cast() {
     let batches = run_query("select name, (json_data->'foo'->0)::int from test")
@@ -930,6 +967,41 @@ async fn test_plan_arrow_double_nested_cast() {
     assert_eq!(lines, expected);
 }

+#[tokio::test]
+async fn test_double_arrow_double_nested_cast() {
+    let batches = run_query("select name, (json_data->>'foo'->>0)::int from test")
+        .await
+        .unwrap();
+
+    let expected = [
+        "+------------------+---------------------------------------------+",
+        "| name             | test.json_data ->> Utf8(\"foo\") ->> Int64(0) |",
+        "+------------------+---------------------------------------------+",
+        "| object_foo       |                                             |",
+        "| object_foo_array | 1                                           |",
+        "| object_foo_obj   |                                             |",
+        "| object_foo_null  |                                             |",
+        "| object_bar       |                                             |",
+        "| list_foo         |                                             |",
+        "| invalid_json     |                                             |",
+        "+------------------+---------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &batches);
+}
+
+#[tokio::test]
+async fn test_plan_double_arrow_double_nested_cast() {
+    let lines = logical_plan(r"explain select (json_data->>'foo'->>0)::int from test").await;
+
+    // NB: json_as_text(..)::int is NOT the same as `json_get_int(..)`, hence the cast is not rewritten
+    let expected = [
+        "Projection: CAST(json_as_text(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data ->> Utf8(\"foo\") ->> Int64(0) AS Int32)",
+        "  TableScan: test projection=[json_data]",
+    ];
+
+    assert_eq!(lines, expected);
+}
+
 #[tokio::test]
 async fn test_arrow_nested_columns() {
     let expected = [
@@ -965,10 +1037,18 @@ async fn test_arrow_nested_double_columns() {
 }

 #[tokio::test]
-async fn test_lexical_precedence_wrong() {
+async fn test_lexical_precedence_correct() {
+    #[rustfmt::skip]
+    let expected = [
+        "+------+",
+        "| v    |",
+        "+------+",
+        "| true |",
+        "+------+",
+    ];
     let sql = r#"select '{"a": "b"}'->>'a'='b' as v"#;
-    let err = run_query(sql).await.unwrap_err();
-    assert_eq!(err.to_string(), "Error during planning: Unexpected argument type to 'json_as_text' at position 2, expected string or int, got Boolean.");
+    let batches = run_query(sql).await.unwrap();
+    assert_batches_eq!(expected, &batches);
 }

 #[tokio::test]
@@ -1092,6 +1172,28 @@ async fn test_arrow_union_is_null() {
     assert_batches_eq!(expected, &batches);
 }

+#[tokio::test]
+async fn test_arrow_union_is_null_dict_encoded() {
+    let batches = run_query_dict("select name, (json_data->'foo') is null from test")
+        .await
+        .unwrap();
+
+    let expected = [
+        "+------------------+---------------------------------------+",
+        "| name             | test.json_data -> Utf8(\"foo\") IS NULL |",
+        "+------------------+---------------------------------------+",
+        "| object_foo       | false                                 |",
+        "| object_foo_array | false                                 |",
+        "| object_foo_obj   | false                                 |",
+        "| object_foo_null  | true                                  |",
+        "| object_bar       | true                                  |",
+        "| list_foo         | true                                  |",
+        "| invalid_json     | true                                  |",
+        "+------------------+---------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &batches);
+}
+
 #[tokio::test]
 async fn test_arrow_union_is_not_null() {
     let batches = run_query("select name, (json_data->'foo') is not null from test")
@@ -1114,6 +1216,28 @@ async fn test_arrow_union_is_not_null() {
     assert_batches_eq!(expected, &batches);
 }

+#[tokio::test]
+async fn test_arrow_union_is_not_null_dict_encoded() {
+    let batches = run_query_dict("select name, (json_data->'foo') is not null from test")
+        .await
+        .unwrap();
+
+    let expected = [
+        "+------------------+-------------------------------------------+",
+        "| name             | test.json_data -> Utf8(\"foo\") IS NOT NULL |",
+        "+------------------+-------------------------------------------+",
+        "| object_foo       | true                                      |",
+        "| object_foo_array | true                                      |",
+        "| object_foo_obj   | true                                      |",
+        "| object_foo_null  | false                                     |",
+        "| object_bar       | false                                     |",
+        "| list_foo         | false                                     |",
+        "| invalid_json     | false                                     |",
+        "+------------------+-------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &batches);
+}
+
 #[tokio::test]
 async fn test_arrow_scalar_union_is_null() {
     let batches = run_query(
@@ -1158,3 +1282,338 @@ async fn test_arrow_cast_numeric() {
     let batches = run_query(sql).await.unwrap();
     assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string()));
 }
+
+#[tokio::test]
+async fn test_dict_haystack() {
+    let sql = "select json_get(json_data, 'foo') v from dicts";
+    let expected = [
+        "+-----------------------+",
+        "| v                     |",
+        "+-----------------------+",
+        "| {object={\"bar\": [0]}} |",
+        "|                       |",
+        "|                       |",
+        "|                       |",
"+-----------------------+", + ]; + + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_dict_haystack_filter() { + let sql = "select json_data v from dicts where json_get(json_data, 'foo') is not null"; + let expected = [ + "+-------------------------+", + "| v |", + "+-------------------------+", + "| {\"foo\": {\"bar\": [0]}} |", + "+-------------------------+", + ]; + + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_dict_haystack_needle() { + let sql = "select json_get(json_get(json_data, str_key1), str_key2) v from dicts"; + let expected = [ + "+-------------+", + "| v |", + "+-------------+", + "| {array=[0]} |", + "| |", + "| |", + "| |", + "+-------------+", + ]; + + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_dict_length() { + let sql = "select json_length(json_data) v from dicts"; + #[rustfmt::skip] + let expected = [ + "+---+", + "| v |", + "+---+", + "| 1 |", + "| 1 |", + "| 2 |", + "| 2 |", + "+---+", + ]; + + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_dict_contains() { + let sql = "select json_contains(json_data, str_key2) v from dicts"; + let expected = [ + "+-------+", + "| v |", + "+-------+", + "| false |", + "| false |", + "| true |", + "| true |", + "+-------+", + ]; + + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_dict_contains_where() { + let sql = "select str_key2 from dicts where json_contains(json_data, str_key2)"; + let expected = [ + "+----------+", + "| str_key2 |", + "+----------+", + "| spam |", + "| snap |", + "+----------+", + ]; + + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_dict_get_int() { + let sql = "select json_get_int(json_data, str_key2) v from dicts"; + #[rustfmt::skip] + let expected = [ + "+---+", + "| v |", + "+---+", + "| |", + "| |", + "| 1 |", + "| 2 |", + "+---+", + ]; + + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +async fn build_dict_schema() -> SessionContext { + let mut builder = StringDictionaryBuilder::::new(); + builder.append(r#"{"foo": "bar"}"#).unwrap(); + builder.append(r#"{"baz": "fizz"}"#).unwrap(); + builder.append("nah").unwrap(); + builder.append(r#"{"baz": "abcd"}"#).unwrap(); + builder.append_null(); + builder.append(r#"{"baz": "fizz"}"#).unwrap(); + builder.append(r#"{"baz": "fizz"}"#).unwrap(); + builder.append(r#"{"baz": "fizz"}"#).unwrap(); + builder.append(r#"{"baz": "fizz"}"#).unwrap(); + builder.append_null(); + + let dict = builder.finish(); + + assert_eq!(dict.len(), 10); + assert_eq!(dict.values().len(), 4); + + let array = Arc::new(dict) as ArrayRef; + + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), + true, + )])); + + let data = RecordBatch::try_new(schema.clone(), vec![array]).unwrap(); + + let ctx = create_context().await.unwrap(); + ctx.register_batch("data", data).unwrap(); + ctx +} + +#[tokio::test] +async fn test_dict_filter() { + let ctx = build_dict_schema().await; + + let sql = "select json_get(x, 'baz') v from data"; + let expected = [ + "+------------+", + "| v |", + "+------------+", + "| |", + "| {str=fizz} 
|", + "| |", + "| {str=abcd} |", + "| |", + "| {str=fizz} |", + "| {str=fizz} |", + "| {str=fizz} |", + "| {str=fizz} |", + "| |", + "+------------+", + ]; + + let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap(); + + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_dict_filter_is_not_null() { + let ctx = build_dict_schema().await; + let sql = "select x from data where json_get(x, 'baz') is not null"; + let expected = [ + "+-----------------+", + "| x |", + "+-----------------+", + "| {\"baz\": \"fizz\"} |", + "| {\"baz\": \"abcd\"} |", + "| {\"baz\": \"fizz\"} |", + "| {\"baz\": \"fizz\"} |", + "| {\"baz\": \"fizz\"} |", + "| {\"baz\": \"fizz\"} |", + "+-----------------+", + ]; + + let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap(); + + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_dict_filter_contains() { + let ctx = build_dict_schema().await; + let sql = "select x from data where json_contains(x, 'baz')"; + let expected = [ + "+-----------------+", + "| x |", + "+-----------------+", + "| {\"baz\": \"fizz\"} |", + "| {\"baz\": \"abcd\"} |", + "| {\"baz\": \"fizz\"} |", + "| {\"baz\": \"fizz\"} |", + "| {\"baz\": \"fizz\"} |", + "| {\"baz\": \"fizz\"} |", + "+-----------------+", + ]; + + let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap(); + + assert_batches_eq!(expected, &batches); + + // test with a boolean OR as well + let batches = ctx + .sql(&format!("{sql} or false")) + .await + .unwrap() + .collect() + .await + .unwrap(); + + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_json_object_keys() { + let expected = [ + "+----------------------------------+", + "| json_object_keys(test.json_data) |", + "+----------------------------------+", + "| [foo] |", + "| [foo] |", + "| [foo] |", + "| [foo] |", + "| [bar] |", + "| |", + "| |", + "+----------------------------------+", + ]; + + let sql = "select json_object_keys(json_data) from test"; + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); + + let sql = "select json_object_keys(json_data) from test"; + let batches = run_query_dict(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); + + let sql = "select json_object_keys(json_data) from test"; + let batches = run_query_large(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_json_object_keys_many() { + let expected = [ + "+-----------------------+", + "| v |", + "+-----------------------+", + "| [foo, bar, spam, ham] |", + "+-----------------------+", + ]; + + let sql = r#"select json_object_keys('{"foo": 1, "bar": 2.2, "spam": true, "ham": []}') as v"#; + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_json_object_keys_nested() { + let json = r#"'{"foo": [{"bar": {"spam": true, "ham": []}}]}'"#; + + let sql = format!("select json_object_keys({json}) as v"); + let batches = run_query(&sql).await.unwrap(); + #[rustfmt::skip] + let expected = [ + "+-------+", + "| v |", + "+-------+", + "| [foo] |", + "+-------+", + ]; + assert_batches_eq!(expected, &batches); + + let sql = format!("select json_object_keys({json}, 'foo') as v"); + let batches = run_query(&sql).await.unwrap(); + #[rustfmt::skip] + let expected = [ + "+---+", + "| v |", + "+---+", + "| |", + "+---+", + ]; + assert_batches_eq!(expected, &batches); + + let sql = format!("select json_object_keys({json}, 'foo', 0) as v"); + let 
+    #[rustfmt::skip]
+    let expected = [
+        "+-------+",
+        "| v     |",
+        "+-------+",
+        "| [bar] |",
+        "+-------+",
+    ];
+    assert_batches_eq!(expected, &batches);
+
+    let sql = format!("select json_object_keys({json}, 'foo', 0, 'bar') as v");
+    let batches = run_query(&sql).await.unwrap();
+    #[rustfmt::skip]
+    let expected = [
+        "+-------------+",
+        "| v           |",
+        "+-------------+",
+        "| [spam, ham] |",
+        "+-------------+",
+    ];
+    assert_batches_eq!(expected, &batches);
+}
diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs
index 4f98056..c220bc7 100644
--- a/tests/utils/mod.rs
+++ b/tests/utils/mod.rs
@@ -1,21 +1,27 @@
 #![allow(dead_code)]
 use std::sync::Arc;
 
-use arrow::array::{ArrayRef, Int64Array};
-use arrow::datatypes::{DataType, Field, Schema};
-use arrow::util::display::{ArrayFormatter, FormatOptions};
-use arrow::{array::LargeStringArray, array::StringArray, record_batch::RecordBatch};
-
+use datafusion::arrow::array::{
+    ArrayRef, DictionaryArray, Int32Array, Int64Array, StringViewArray, UInt32Array, UInt64Array, UInt8Array,
+};
+use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Int64Type, Schema, UInt32Type, UInt8Type};
+use datafusion::arrow::util::display::{ArrayFormatter, FormatOptions};
+use datafusion::arrow::{array::LargeStringArray, array::StringArray, record_batch::RecordBatch};
+use datafusion::common::ParamValues;
 use datafusion::error::Result;
 use datafusion::execution::context::SessionContext;
-use datafusion_common::ParamValues;
-use datafusion_execution::config::SessionConfig;
+use datafusion::prelude::SessionConfig;
 use datafusion_functions_json::register_all;
 
-fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
+pub async fn create_context() -> Result<SessionContext> {
     let config = SessionConfig::new().set_str("datafusion.sql_parser.dialect", "postgres");
     let mut ctx = SessionContext::new_with_config(config);
     register_all(&mut ctx)?;
+    Ok(ctx)
+}
+
+async fn create_test_table(large_utf8: bool, dict_encoded: bool) -> Result<SessionContext> {
+    let ctx = create_context().await?;
 
     let test_data = [
         ("object_foo", r#" {"foo": "abc"} "#),
@@ -30,11 +36,20 @@ fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
         ("invalid_json", "is not json"),
     ];
     let json_values = test_data.iter().map(|(_, json)| *json).collect::<Vec<&str>>();
-    let (json_data_type, json_array): (DataType, ArrayRef) = if large_utf8 {
+    let (mut json_data_type, mut json_array): (DataType, ArrayRef) = if large_utf8 {
         (DataType::LargeUtf8, Arc::new(LargeStringArray::from(json_values)))
     } else {
         (DataType::Utf8, Arc::new(StringArray::from(json_values)))
     };
+
+    if dict_encoded {
+        json_data_type = DataType::Dictionary(DataType::Int32.into(), json_data_type.into());
+        json_array = Arc::new(DictionaryArray::<Int32Type>::new(
+            Int32Array::from_iter_values(0..(json_array.len() as i32)),
+            json_array,
+        ));
+    }
+
     let test_batch = RecordBatch::try_new(
         Arc::new(Schema::new(vec![
             Field::new("name", DataType::Utf8, false),
@@ -84,7 +99,11 @@ fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
         Arc::new(Schema::new(vec![
             Field::new("json_data", DataType::Utf8, false),
             Field::new("str_key1", DataType::Utf8, false),
-            Field::new("str_key2", DataType::Utf8, false),
+            Field::new(
+                "str_key2",
+                DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+                false,
+            ),
             Field::new("int_key", DataType::Int64, false),
         ])),
         vec![
@@ -97,12 +116,12 @@ fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
                     .map(|(_, str_key1, _, _)| *str_key1)
                     .collect::<Vec<&str>>(),
             )),
-            Arc::new(StringArray::from(
+            Arc::new(
                 more_nested
                     .iter()
                     .map(|(_, _, str_key2, _)| *str_key2)
-                    .collect::<Vec<&str>>(),
-            )),
+                    .collect::<DictionaryArray<Int32Type>>(),
+            ),
             Arc::new(Int64Array::from(
                 more_nested
                     .iter()
@@ -113,16 +132,85 @@ fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
     )?;
     ctx.register_batch("more_nested", more_nested_batch)?;
 
+    // Key and value types deliberately vary across columns (UInt32/UInt8/Int64
+    // keys; Utf8, LargeUtf8, Utf8View and UInt64 values) so the UDFs are
+    // exercised against every dictionary layout.
+    let dict_data = [
+        (r#" {"foo": {"bar": [0]}} "#, "foo", "bar", 0),
+        (r#" {"bar": "snap"} "#, "foo", "spam", 0),
+        (r#" {"spam": 1, "snap": 2} "#, "foo", "spam", 0),
+        (r#" {"spam": 1, "snap": 2} "#, "foo", "snap", 0),
+    ];
+    let dict_batch = RecordBatch::try_new(
+        Arc::new(Schema::new(vec![
+            Field::new(
+                "json_data",
+                DataType::Dictionary(DataType::UInt32.into(), DataType::Utf8.into()),
+                false,
+            ),
+            Field::new(
+                "str_key1",
+                DataType::Dictionary(DataType::UInt8.into(), DataType::LargeUtf8.into()),
+                false,
+            ),
+            Field::new(
+                "str_key2",
+                DataType::Dictionary(DataType::UInt8.into(), DataType::Utf8View.into()),
+                false,
+            ),
+            Field::new(
+                "int_key",
+                DataType::Dictionary(DataType::Int64.into(), DataType::UInt64.into()),
+                false,
+            ),
+        ])),
+        vec![
+            Arc::new(DictionaryArray::<UInt32Type>::new(
+                UInt32Array::from_iter_values(dict_data.iter().enumerate().map(|(id, _)| id as u32)),
+                Arc::new(StringArray::from(
+                    dict_data.iter().map(|(json, _, _, _)| *json).collect::<Vec<&str>>(),
+                )),
+            )),
+            Arc::new(DictionaryArray::<UInt8Type>::new(
+                UInt8Array::from_iter_values(dict_data.iter().enumerate().map(|(id, _)| id as u8)),
+                Arc::new(LargeStringArray::from(
+                    dict_data
+                        .iter()
+                        .map(|(_, str_key1, _, _)| *str_key1)
+                        .collect::<Vec<&str>>(),
+                )),
+            )),
+            Arc::new(DictionaryArray::<UInt8Type>::new(
+                UInt8Array::from_iter_values(dict_data.iter().enumerate().map(|(id, _)| id as u8)),
+                Arc::new(StringViewArray::from(
+                    dict_data
+                        .iter()
+                        .map(|(_, _, str_key2, _)| *str_key2)
+                        .collect::<Vec<&str>>(),
+                )),
+            )),
+            Arc::new(DictionaryArray::<Int64Type>::new(
+                Int64Array::from_iter_values(dict_data.iter().enumerate().map(|(id, _)| id as i64)),
+                Arc::new(UInt64Array::from_iter_values(
+                    dict_data.iter().map(|(_, _, _, int_key)| *int_key as u64),
+                )),
+            )),
+        ],
+    )?;
+    ctx.register_batch("dicts", dict_batch)?;
+
     Ok(ctx)
 }
 
 pub async fn run_query(sql: &str) -> Result<Vec<RecordBatch>> {
-    let ctx = create_test_table(false)?;
+    let ctx = create_test_table(false, false).await?;
     ctx.sql(sql).await?.collect().await
 }
 
 pub async fn run_query_large(sql: &str) -> Result<Vec<RecordBatch>> {
-    let ctx = create_test_table(true)?;
+    let ctx = create_test_table(true, false).await?;
+    ctx.sql(sql).await?.collect().await
+}
+
+pub async fn run_query_dict(sql: &str) -> Result<Vec<RecordBatch>> {
+    let ctx = create_test_table(false, true).await?;
     ctx.sql(sql).await?.collect().await
 }
 
@@ -131,7 +219,7 @@ pub async fn run_query_params(
     large_utf8: bool,
     query_values: impl Into<ParamValues>,
 ) -> Result<Vec<RecordBatch>> {
-    let ctx = create_test_table(large_utf8)?;
+    let ctx = create_test_table(large_utf8, false).await?;
     ctx.sql(sql).await?.with_param_values(query_values)?.collect().await
 }
 
@@ -152,5 +240,5 @@ pub async fn logical_plan(sql: &str) -> Vec<String> {
     let batches = run_query(sql).await.unwrap();
     let plan_col = batches[0].column(1).as_any().downcast_ref::<StringArray>().unwrap();
     let logical_plan = plan_col.value(0);
-    logical_plan.split('\n').map(std::string::ToString::to_string).collect()
+    logical_plan.split('\n').map(ToString::to_string).collect()
 }
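
For context when reviewing, here is a minimal end-to-end sketch of what the dictionary-encoded path exercised by these tests enables. This program is illustrative and not part of the diff: the table name `t`, the column layout, and the key choice are hypothetical, and it assumes `datafusion = "43"`, `tokio` (with the `macros` and `rt-multi-thread` features), and this crate as dependencies.

```rust
use std::sync::Arc;

use datafusion::arrow::array::{DictionaryArray, Int32Array, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion_functions_json::register_all;

#[tokio::main]
async fn main() -> Result<()> {
    let mut ctx = SessionContext::new();
    register_all(&mut ctx)?;

    // Three rows backed by two dictionary values: keys [0, 1, 0] index into
    // the deduplicated JSON strings below.
    let values = StringArray::from(vec![r#"{"foo": 1}"#, r#"{"foo": 2}"#]);
    let keys = Int32Array::from(vec![0, 1, 0]);
    let array = DictionaryArray::<Int32Type>::new(keys, Arc::new(values));

    let schema = Arc::new(Schema::new(vec![Field::new(
        "json_data",
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
        false,
    )]));
    ctx.register_batch("t", RecordBatch::try_new(schema, vec![Arc::new(array)])?)?;

    // The UDFs unwrap the dictionary encoding before walking the JSON, so
    // this behaves the same as it would on a plain Utf8 column.
    ctx.sql("select json_get_int(json_data, 'foo') as v from t")
        .await?
        .show()
        .await?;
    Ok(())
}
```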