From be4df2e18a2d75776419270c516d862cd7c72b31 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?=
Date: Wed, 27 Mar 2024 10:47:04 +0800
Subject: [PATCH] Add ST_Split

---
 src/function/mod.rs   |   4 +
 src/function/split.rs | 183 ++++++++++++++++++++++++++++++++++++++++++
 src/geo/builder.rs    |  18 +++++
 3 files changed, 205 insertions(+)
 create mode 100644 src/function/split.rs

diff --git a/src/function/mod.rs b/src/function/mod.rs
index 069494c..38c25be 100644
--- a/src/function/mod.rs
+++ b/src/function/mod.rs
@@ -16,6 +16,8 @@ mod intersects;
 #[cfg(feature = "geos")]
 mod make_envelope;
 #[cfg(feature = "geos")]
+mod split;
+#[cfg(feature = "geos")]
 mod srid;
 mod translate;
 
@@ -34,5 +36,7 @@ pub use intersects::*;
 #[cfg(feature = "geos")]
 pub use make_envelope::*;
 #[cfg(feature = "geos")]
+pub use split::*;
+#[cfg(feature = "geos")]
 pub use srid::*;
 pub use translate::*;
diff --git a/src/function/split.rs b/src/function/split.rs
new file mode 100644
index 0000000..91032b2
--- /dev/null
+++ b/src/function/split.rs
@@ -0,0 +1,183 @@
+use crate::geo::{GeometryArray, GeometryArrayBuilder};
+use crate::DFResult;
+use arrow_array::cast::AsArray;
+use arrow_array::{GenericBinaryArray, OffsetSizeTrait};
+use arrow_schema::DataType;
+use datafusion_common::{internal_datafusion_err, internal_err};
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use geos::Geom;
+use rayon::iter::{IntoParallelIterator, ParallelIterator};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Scalar UDF `ST_Split` (alias `st_split`) taking two binary geometry arguments.
+///
+/// For every row the boundary of the first geometry is unioned with the second
+/// geometry and the combined linework is polygonized.
+#[derive(Debug)]
+pub struct SplitUdf {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl SplitUdf {
+    /// Creates the UDF with a two-argument `Binary`/`LargeBinary` signature.
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::uniform(
+                2,
+                vec![DataType::Binary, DataType::LargeBinary],
+                Volatility::Immutable,
+            ),
+            aliases: vec!["st_split".to_string()],
+        }
+    }
+}
+
+impl ScalarUDFImpl for SplitUdf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "ST_Split"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// The result has the same data type as the first argument.
+    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
+        arg_types
+            .first()
+            .cloned()
+            .ok_or_else(|| internal_datafusion_err!("ST_Split expects 2 arguments, got 0"))
+    }
+
+    /// Evaluates `ST_Split` for every row of the input.
+    ///
+    /// A scalar argument is expanded to an array matching the other argument's
+    /// length (or to length 1 when both are scalars), so the result is always
+    /// returned as an array.
+    fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
+        // Matching on the slice rejects a wrong argument count with an error
+        // instead of panicking on an out-of-bounds index.
+        let (arr0, arr1) = match args {
+            [ColumnarValue::Array(arr0), ColumnarValue::Array(arr1)] => {
+                (Arc::clone(arr0), Arc::clone(arr1))
+            }
+            [ColumnarValue::Array(arr0), ColumnarValue::Scalar(scalar)] => {
+                (Arc::clone(arr0), scalar.to_array_of_size(arr0.len())?)
+            }
+            [ColumnarValue::Scalar(scalar), ColumnarValue::Array(arr1)] => {
+                (scalar.to_array_of_size(arr1.len())?, Arc::clone(arr1))
+            }
+            [ColumnarValue::Scalar(scalar0), ColumnarValue::Scalar(scalar1)] => {
+                (scalar0.to_array_of_size(1)?, scalar1.to_array_of_size(1)?)
+            }
+            _ => return internal_err!("ST_Split expects 2 arguments, got {}", args.len()),
+        };
+        if arr0.len() != arr1.len() {
+            return internal_err!("Two arrays length is not same");
+        }
+
+        match (arr0.data_type(), arr1.data_type()) {
+            (DataType::Binary, DataType::Binary) => {
+                split::<i32, i32>(arr0.as_binary::<i32>(), arr1.as_binary::<i32>())
+            }
+            (DataType::LargeBinary, DataType::Binary) => {
+                split::<i64, i32>(arr0.as_binary::<i64>(), arr1.as_binary::<i32>())
+            }
+            (DataType::Binary, DataType::LargeBinary) => {
+                split::<i32, i64>(arr0.as_binary::<i32>(), arr1.as_binary::<i64>())
+            }
+            (DataType::LargeBinary, DataType::LargeBinary) => {
+                split::<i64, i64>(arr0.as_binary::<i64>(), arr1.as_binary::<i64>())
+            }
+            (type0, type1) => internal_err!(
+                "ST_Split does not support argument types {:?} and {:?}",
+                type0,
+                type1
+            ),
+        }
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+impl Default for SplitUdf {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Splits each geometry in `arr0` by the geometry at the same index in `arr1`.
+///
+/// The boundary of the first geometry is unioned with the second one and the
+/// union is passed to `polygonize_full`; only the first element of the tuple it
+/// returns is kept. A row yields `None` when `geos_value` returns `None` for
+/// either input.
+///
+/// # Errors
+///
+/// Propagates any error from `geos_value` and returns an internal error when a
+/// GEOS operation (`boundary`, `union`, `polygonize_full`) fails.
+fn split<O: OffsetSizeTrait, F: OffsetSizeTrait>(
+    arr0: &GenericBinaryArray<O>,
+    arr1: &GenericBinaryArray<F>,
+) -> DFResult<ColumnarValue> {
+    let geom_vec = (0..arr0.geom_len())
+        .into_par_iter()
+        .map(
+            |geom_index| match (arr0.geos_value(geom_index)?, arr1.geos_value(geom_index)?) {
+                (Some(geom0), Some(geom1)) => {
+                    let boundary = geom0.boundary().map_err(|e| {
+                        internal_datafusion_err!("Failed to do boundary, error: {}", e)
+                    })?;
+                    let union = boundary.union(&geom1).map_err(|e| {
+                        internal_datafusion_err!("Failed to do union, error: {}", e)
+                    })?;
+                    let (result, ..) = union.polygonize_full().map_err(|e| {
+                        internal_datafusion_err!("Failed to do polygonize_full, error: {}", e)
+                    })?;
+
+                    Ok(Some(result))
+                }
+                _ => Ok(None),
+            },
+        )
+        .collect::<DFResult<Vec<Option<geos::Geometry>>>>()?;
+    // Pass a slice explicitly: the builder's `From` impl is defined for
+    // `&[Option<geos::Geometry>]`, not for `&Vec<_>`.
+    let builder = GeometryArrayBuilder::<O>::from(geom_vec.as_slice());
+    Ok(ColumnarValue::Array(Arc::new(builder.build())))
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::function::{GeomFromTextUdf, SplitUdf};
+    use arrow::util::pretty::pretty_format_batches;
+    use datafusion::prelude::SessionContext;
+    use datafusion_expr::ScalarUDF;
+
+    #[tokio::test]
+    async fn split() {
+        let ctx = SessionContext::new();
+        ctx.register_udf(ScalarUDF::from(GeomFromTextUdf::new()));
+        ctx.register_udf(ScalarUDF::from(SplitUdf::new()));
+        let df = ctx
+            .sql("select ST_Split(ST_GeomFromText('LINESTRING ( 0 0, 1 1, 2 2 )'), ST_GeomFromText('POINT(1 1)'))")
+            .await
+            .unwrap();
+        // NOTE(review): the expected string is an empty placeholder; replace it
+        // with the actual formatted table once the output has been verified.
+        assert_eq!(
+            pretty_format_batches(&df.collect().await.unwrap())
+                .unwrap()
+                .to_string(),
+            ""
+        );
+    }
+}
diff --git a/src/geo/builder.rs b/src/geo/builder.rs
index 1c3a757..f8a5156 100644
--- a/src/geo/builder.rs
+++ b/src/geo/builder.rs
@@ -190,3 +190,21 @@ impl From<&[Option]> for GeometryAr
         geo_vec.as_slice().into()
     }
 }
+
+/// Builds an EWKB geometry array builder from optional GEOS geometries.
+///
+/// # Panics
+///
+/// Panics if `append_geos_geometry` returns an error for any element.
+#[cfg(feature = "geos")]
+impl<O: OffsetSizeTrait> From<&[Option<geos::Geometry>]> for GeometryArrayBuilder<O> {
+    fn from(value: &[Option<geos::Geometry>]) -> Self {
+        let mut builder = GeometryArrayBuilder::<O>::new(WkbDialect::Ewkb, value.len());
+        for geom in value {
+            builder
+                .append_geos_geometry(geom)
+                .expect("geometry data is valid");
+        }
+        builder
+    }
+}