Skip to content

Commit

Permalink
fix case of dictionary column as path for JSON lookup (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt authored Nov 28, 2024
1 parent 4446d2b commit 9d72f5d
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 37 deletions.
71 changes: 38 additions & 33 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ use std::str::Utf8Error;
use std::sync::Arc;

use datafusion::arrow::array::{
Array, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray, StringArray,
StringViewArray, UInt64Array, UnionArray,
Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray,
StringArray, StringViewArray, UInt64Array, UnionArray,
};
use datafusion::arrow::compute::take;
use datafusion::arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType};
use datafusion::arrow::datatypes::{
ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType, Int64Type, UInt64Type,
};
use datafusion::arrow::downcast_dictionary_array;
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
Expand Down Expand Up @@ -72,6 +74,12 @@ pub enum JsonPath<'s> {
None,
}

impl<'a> From<&'a str> for JsonPath<'a> {
fn from(key: &'a str) -> Self {
JsonPath::Key(key)
}
}

impl From<u64> for JsonPath<'_> {
fn from(index: u64) -> Self {
JsonPath::Index(usize::try_from(index).unwrap())
Expand Down Expand Up @@ -145,32 +153,27 @@ fn invoke_array<C: FromIterator<Option<I>> + 'static, I>(
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
return_dict: bool,
) -> DataFusionResult<ArrayRef> {
if let Some(d) = needle_array.as_any_dictionary_opt() {
// this is the (very rare) case where the needle is a dictionary, it shouldn't affect what we return
invoke_array(json_array, d.values(), to_array, jiter_find, return_dict)
} else if let Some(str_path_array) = needle_array.as_any().downcast_ref::<StringArray>() {
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
zip_apply(json_array, paths, to_array, jiter_find, true, return_dict)
} else if let Some(str_path_array) = needle_array.as_any().downcast_ref::<LargeStringArray>() {
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
zip_apply(json_array, paths, to_array, jiter_find, true, return_dict)
} else if let Some(str_path_array) = needle_array.as_any().downcast_ref::<StringViewArray>() {
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
zip_apply(json_array, paths, to_array, jiter_find, true, return_dict)
} else if let Some(int_path_array) = needle_array.as_any().downcast_ref::<Int64Array>() {
let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
zip_apply(json_array, paths, to_array, jiter_find, false, return_dict)
} else if let Some(int_path_array) = needle_array.as_any().downcast_ref::<UInt64Array>() {
let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
zip_apply(json_array, paths, to_array, jiter_find, false, return_dict)
} else {
exec_err!("unexpected second argument type, expected string or int array")
}
downcast_dictionary_array!(
needle_array => match needle_array.values().data_type() {
DataType::Utf8 => zip_apply(json_array, needle_array.downcast_dict::<StringArray>().unwrap(), to_array, jiter_find, true, return_dict),
DataType::LargeUtf8 => zip_apply(json_array, needle_array.downcast_dict::<LargeStringArray>().unwrap(), to_array, jiter_find, true, return_dict),
DataType::Utf8View => zip_apply(json_array, needle_array.downcast_dict::<StringViewArray>().unwrap(), to_array, jiter_find, true, return_dict),
DataType::Int64 => zip_apply(json_array, needle_array.downcast_dict::<Int64Array>().unwrap(), to_array, jiter_find, false, return_dict),
DataType::UInt64 => zip_apply(json_array, needle_array.downcast_dict::<UInt64Array>().unwrap(), to_array, jiter_find, false, return_dict),
other => exec_err!("unexpected second argument type, expected string or int array, got {:?}", other),
},
DataType::Utf8 => zip_apply(json_array, needle_array.as_string::<i32>(), to_array, jiter_find, true, return_dict),
DataType::LargeUtf8 => zip_apply(json_array, needle_array.as_string::<i64>(), to_array, jiter_find, true, return_dict),
DataType::Utf8View => zip_apply(json_array, needle_array.as_string_view(), to_array, jiter_find, true, return_dict),
DataType::Int64 => zip_apply(json_array, needle_array.as_primitive::<Int64Type>(), to_array, jiter_find, false, return_dict),
DataType::UInt64 => zip_apply(json_array, needle_array.as_primitive::<UInt64Type>(), to_array, jiter_find, false, return_dict),
other => exec_err!("unexpected second argument type, expected string or int array, got {:?}", other)
)
}

fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Option<I>> + 'static, I>(
fn zip_apply<'a, P: Into<JsonPath<'a>>, C: FromIterator<Option<I>> + 'static, I>(
json_array: &ArrayRef,
path_array: P,
path_array: impl ArrayAccessor<Item = P>,
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
object_lookup: bool,
Expand All @@ -194,18 +197,20 @@ fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Optio
to_array(c)
}

fn zip_apply_iter<'a, 'j, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Option<I>> + 'static, I>(
#[allow(clippy::needless_pass_by_value)] // ArrayAccessor is implemented on references
fn zip_apply_iter<'a, 'j, P: Into<JsonPath<'a>>, C: FromIterator<Option<I>> + 'static, I>(
json_iter: impl Iterator<Item = Option<&'j str>>,
path_array: P,
path_array: impl ArrayAccessor<Item = P>,
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
) -> C {
json_iter
.zip(path_array)
.map(|(opt_json, opt_path)| {
if let Some(path) = opt_path {
jiter_find(opt_json, &[path]).ok()
} else {
.enumerate()
.map(|(i, opt_json)| {
if path_array.is_null(i) {
None
} else {
let path = path_array.value(i).into();
jiter_find(opt_json, &[path]).ok()
}
})
.collect::<C>()
Expand Down
12 changes: 8 additions & 4 deletions tests/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ async fn create_test_table(large_utf8: bool, dict_encoded: bool) -> Result<Sessi
Arc::new(Schema::new(vec![
Field::new("json_data", DataType::Utf8, false),
Field::new("str_key1", DataType::Utf8, false),
Field::new("str_key2", DataType::Utf8, false),
Field::new(
"str_key2",
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
false,
),
Field::new("int_key", DataType::Int64, false),
])),
vec![
Expand All @@ -109,12 +113,12 @@ async fn create_test_table(large_utf8: bool, dict_encoded: bool) -> Result<Sessi
.map(|(_, str_key1, _, _)| *str_key1)
.collect::<Vec<_>>(),
)),
Arc::new(StringArray::from(
Arc::new(
more_nested
.iter()
.map(|(_, _, str_key2, _)| *str_key2)
.collect::<Vec<_>>(),
)),
.collect::<DictionaryArray<Int32Type>>(),
),
Arc::new(Int64Array::from(
more_nested
.iter()
Expand Down

0 comments on commit 9d72f5d

Please sign in to comment.