Skip to content

Commit

Permalink
implement json_object_key (and alias json_keys) (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Oct 28, 2024
1 parent 8a758fa commit 2a7c5b2
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 0 deletions.
127 changes: 127 additions & 0 deletions src/json_object_keys.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
use std::any::Any;
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, ListArray, ListBuilder, StringBuilder};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use jiter::Peek;

use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
use crate::common_macros::make_udf_function;

make_udf_function!(
JsonObjectKeys,
json_object_keys,
json_data path,
r#"Get the keys of a JSON object as an array."#
);

#[derive(Debug)]
pub(super) struct JsonObjectKeys {
signature: Signature,
aliases: [String; 2],
}

impl Default for JsonObjectKeys {
fn default() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
aliases: ["json_object_keys".to_string(), "json_keys".to_string()],
}
}
}

impl ScalarUDFImpl for JsonObjectKeys {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
self.aliases[0].as_str()
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
return_type_check(
arg_types,
self.name(),
DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
invoke::<ListArrayWrapper, Vec<String>>(
args,
jiter_json_object_keys,
|w| Ok(Arc::new(w.0) as ArrayRef),
keys_to_scalar,
true,
)
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

/// Wrapper for a `ListArray` that allows us to implement `FromIterator<Option<Vec<String>>>` as required.
#[derive(Debug)]
struct ListArrayWrapper(ListArray);

impl FromIterator<Option<Vec<String>>> for ListArrayWrapper {
fn from_iter<I: IntoIterator<Item = Option<Vec<String>>>>(iter: I) -> Self {
let values_builder = StringBuilder::new();
let mut builder = ListBuilder::new(values_builder);
for opt_keys in iter {
if let Some(keys) = opt_keys {
for value in keys {
builder.values().append_value(value);
}
builder.append(true);
} else {
builder.append(false);
}
}
Self(builder.finish())
}
}

fn keys_to_scalar(opt_keys: Option<Vec<String>>) -> ScalarValue {
let values_builder = StringBuilder::new();
let mut builder = ListBuilder::new(values_builder);
if let Some(keys) = opt_keys {
for value in keys {
builder.values().append_value(value);
}
builder.append(true);
} else {
builder.append(false);
}
let array = builder.finish();
ScalarValue::List(Arc::new(array))
}

fn jiter_json_object_keys(opt_json: Option<&str>, path: &[JsonPath]) -> Result<Vec<String>, GetError> {
if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) {
match peek {
Peek::Object => {
let mut opt_key = jiter.known_object()?;

let mut keys = Vec::new();
while let Some(key) = opt_key {
keys.push(key.to_string());
jiter.next_skip()?;
opt_key = jiter.next_key()?;
}
Ok(keys)
}
_ => get_err!(),
}
} else {
get_err!()
}
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mod json_get_int;
mod json_get_json;
mod json_get_str;
mod json_length;
mod json_object_keys;
mod rewrite;

pub use common_union::{JsonUnionEncoder, JsonUnionValue};
Expand All @@ -31,6 +32,7 @@ pub mod functions {
pub use crate::json_get_json::json_get_json;
pub use crate::json_get_str::json_get_str;
pub use crate::json_length::json_length;
pub use crate::json_object_keys::json_object_keys;
}

pub mod udfs {
Expand All @@ -43,6 +45,7 @@ pub mod udfs {
pub use crate::json_get_json::json_get_json_udf;
pub use crate::json_get_str::json_get_str_udf;
pub use crate::json_length::json_length_udf;
pub use crate::json_object_keys::json_object_keys_udf;
}

/// Register all JSON UDFs, and [`rewrite::JsonFunctionRewriter`] with the provided [`FunctionRegistry`].
Expand All @@ -65,6 +68,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
json_get_str::json_get_str_udf(),
json_contains::json_contains_udf(),
json_length::json_length_udf(),
json_object_keys::json_object_keys_udf(),
];
functions.into_iter().try_for_each(|udf| {
let existing_udf = registry.register_udf(udf)?;
Expand Down
97 changes: 97 additions & 0 deletions tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1432,3 +1432,100 @@ async fn test_dict_filter_contains() {

assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_json_object_keys() {
let expected = [
"+----------------------------------+",
"| json_object_keys(test.json_data) |",
"+----------------------------------+",
"| [foo] |",
"| [foo] |",
"| [foo] |",
"| [foo] |",
"| [bar] |",
"| |",
"| |",
"+----------------------------------+",
];

let sql = "select json_object_keys(json_data) from test";
let batches = run_query(sql).await.unwrap();
assert_batches_eq!(expected, &batches);

let sql = "select json_object_keys(json_data) from test";
let batches = run_query_dict(sql).await.unwrap();
assert_batches_eq!(expected, &batches);

let sql = "select json_object_keys(json_data) from test";
let batches = run_query_large(sql).await.unwrap();
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_json_object_keys_many() {
let expected = [
"+-----------------------+",
"| v |",
"+-----------------------+",
"| [foo, bar, spam, ham] |",
"+-----------------------+",
];

let sql = r#"select json_object_keys('{"foo": 1, "bar": 2.2, "spam": true, "ham": []}') as v"#;
let batches = run_query(sql).await.unwrap();
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_json_object_keys_nested() {
let json = r#"'{"foo": [{"bar": {"spam": true, "ham": []}}]}'"#;

let sql = format!("select json_object_keys({json}) as v");
let batches = run_query(&sql).await.unwrap();
#[rustfmt::skip]
let expected = [
"+-------+",
"| v |",
"+-------+",
"| [foo] |",
"+-------+",
];
assert_batches_eq!(expected, &batches);

let sql = format!("select json_object_keys({json}, 'foo') as v");
let batches = run_query(&sql).await.unwrap();
#[rustfmt::skip]
let expected = [
"+---+",
"| v |",
"+---+",
"| |",
"+---+",
];
assert_batches_eq!(expected, &batches);

let sql = format!("select json_object_keys({json}, 'foo', 0) as v");
let batches = run_query(&sql).await.unwrap();
#[rustfmt::skip]
let expected = [
"+-------+",
"| v |",
"+-------+",
"| [bar] |",
"+-------+",
];
assert_batches_eq!(expected, &batches);

let sql = format!("select json_object_keys({json}, 'foo', 0, 'bar') as v");
let batches = run_query(&sql).await.unwrap();
#[rustfmt::skip]
let expected = [
"+-------------+",
"| v |",
"+-------------+",
"| [spam, ham] |",
"+-------------+",
];
assert_batches_eq!(expected, &batches);
}

0 comments on commit 2a7c5b2

Please sign in to comment.