diff --git a/src/function/as_geojson.rs b/src/function/as_geojson.rs new file mode 100644 index 0000000..80c1605 --- /dev/null +++ b/src/function/as_geojson.rs @@ -0,0 +1,145 @@ +use crate::geo::GeometryArray; +use crate::DFResult; +use arrow_array::cast::AsArray; +use arrow_array::{GenericBinaryArray, LargeStringArray, OffsetSizeTrait, StringArray}; +use arrow_schema::DataType; +use datafusion_common::{internal_datafusion_err, DataFusionError}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility}; +use geozero::ToJson; +use std::any::Any; +use std::sync::Arc; + +#[derive(Debug)] +pub struct AsGeoJsonUdf { + signature: Signature, + aliases: Vec, +} + +impl AsGeoJsonUdf { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Binary]), + TypeSignature::Exact(vec![DataType::LargeBinary]), + ], + Volatility::Immutable, + ), + aliases: vec!["st_asgeojson".to_string()], + } + } +} + +impl ScalarUDFImpl for AsGeoJsonUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ST_AsGeoJSON" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + match arg_types[0] { + DataType::Binary => Ok(DataType::Utf8), + DataType::LargeBinary => Ok(DataType::LargeUtf8), + _ => unreachable!(), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + let arr = args[0].clone().into_array(1)?; + match args[0].data_type() { + DataType::Binary => { + let wkb_arr = arr.as_binary::(); + + let mut json_vec = vec![]; + for i in 0..wkb_arr.geom_len() { + json_vec.push(to_geojson::(wkb_arr, i)?); + } + + Ok(ColumnarValue::Array(Arc::new(StringArray::from(json_vec)))) + } + DataType::LargeBinary => { + let wkb_arr = arr.as_binary::(); + + let mut json_vec = vec![]; + for i in 0..wkb_arr.geom_len() { + json_vec.push(to_geojson::(wkb_arr, i)?); + } + + Ok(ColumnarValue::Array(Arc::new(LargeStringArray::from( + json_vec, + )))) + } + _ => unreachable!(), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +fn to_geojson( + wkb_arr: &GenericBinaryArray, + geom_index: usize, +) -> DFResult> { + let geom = { + #[cfg(feature = "geos")] + { + wkb_arr.geos_value(geom_index)? + } + #[cfg(not(feature = "geos"))] + { + wkb_arr.geo_value(geom_index)? + } + }; + let json = match geom { + Some(geom) => Some( + geom.to_json() + .map_err(|_| internal_datafusion_err!("Failed to convert geometry to geo json"))?, + ), + None => None, + }; + Ok(json) +} + +impl Default for AsGeoJsonUdf { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use crate::function::{AsGeoJsonUdf, GeomFromTextUdf}; + use arrow::util::pretty::pretty_format_batches; + use datafusion::logical_expr::ScalarUDF; + use datafusion::prelude::SessionContext; + + #[tokio::test] + async fn as_geojson() { + let ctx = SessionContext::new(); + ctx.register_udf(ScalarUDF::from(GeomFromTextUdf::new())); + ctx.register_udf(ScalarUDF::from(AsGeoJsonUdf::new())); + let df = ctx + .sql("select ST_AsGeoJSON(ST_GeomFromText('POINT(-71.064544 42.28787)'))") + .await + .unwrap(); + assert_eq!( + pretty_format_batches(&df.collect().await.unwrap()) + .unwrap() + .to_string(), + "+-------------------------------------------------------------------+ +| ST_AsGeoJSON(ST_GeomFromText(Utf8(\"POINT(-71.064544 42.28787)\"))) | ++-------------------------------------------------------------------+ +| {\"type\": \"Point\", \"coordinates\": [-71.064544,42.28787]} | ++-------------------------------------------------------------------+" + ); + } +} diff --git a/src/function/mod.rs b/src/function/mod.rs index 38c25be..622b0db 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -1,5 +1,6 @@ #[cfg(feature = "geos")] mod as_ewkt; +mod as_geojson; mod as_text; mod box2d; #[cfg(feature = "geos")] @@ -23,6 +24,7 @@ mod translate; #[cfg(feature = "geos")] pub use as_ewkt::*; +pub use as_geojson::*; pub use as_text::*; #[cfg(feature = "geos")] pub use covered_by::*;