Patch for src/function/intersects.rs (IntersectsUdf::invoke and its tests).

What the hunks below change:
  * invoke no longer calls `into_array(1)` on both arguments. It matches on the
    Array/Scalar combination and expands a scalar argument to the length of the
    array argument (to length 1 when both arguments are scalars).
  * invoke returns `DataFusionError::Internal` when the two resulting arrays
    differ in length, and every match arm now iterates `0..arr0.len()` instead
    of the minimum of the two `geom_len()` values.
  * The test module gains `intersects_table`, which registers a MemTable built
    from the same three-row linestring batch twice and checks the six-row
    result of ST_Intersects against POINT(0 1).

NOTE(review): this copy of the patch was reconstructed from a whitespace-collapsed
extract. Generic arguments (`Result<ColumnarValue>`, `as_binary::<i32|i64>`,
`intersects::<_, _>`) and the column padding inside the expected-table string
were restored by inference; `GeometryArrayBuilder` in the test may also have
lost a generic argument. Confirm against the repository before applying.

diff --git a/src/function/intersects.rs b/src/function/intersects.rs
index 0f1c4e9..b719862 100644
--- a/src/function/intersects.rs
+++ b/src/function/intersects.rs
@@ -3,6 +3,7 @@ use crate::DFResult;
 use arrow_array::cast::AsArray;
 use arrow_array::{Array, BooleanArray, GenericBinaryArray, OffsetSizeTrait};
 use arrow_schema::DataType;
+use datafusion_common::DataFusionError;
 use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
 use std::any::Any;
 use std::sync::Arc;
@@ -44,15 +45,29 @@ impl ScalarUDFImpl for IntersectsUdf {
     }
 
     fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
-        let arr0 = args[0].clone().into_array(1)?;
-        let arr1 = args[1].clone().into_array(1)?;
+        let (arr0, arr1) = match (args[0].clone(), args[1].clone()) {
+            (ColumnarValue::Array(arr0), ColumnarValue::Array(arr1)) => (arr0, arr1),
+            (ColumnarValue::Array(arr0), ColumnarValue::Scalar(scalar)) => {
+                (arr0.clone(), scalar.to_array_of_size(arr0.len())?)
+            }
+            (ColumnarValue::Scalar(scalar), ColumnarValue::Array(arr1)) => {
+                (scalar.to_array_of_size(arr1.len())?, arr1)
+            }
+            (ColumnarValue::Scalar(scalar0), ColumnarValue::Scalar(scalar1)) => {
+                (scalar0.to_array_of_size(1)?, scalar1.to_array_of_size(1)?)
+            }
+        };
+        if arr0.len() != arr1.len() {
+            return Err(DataFusionError::Internal(
+                "Two arrays length is not same".to_string(),
+            ));
+        }
         match (arr0.data_type(), arr1.data_type()) {
             (DataType::Binary, DataType::Binary) => {
                 let arr0 = arr0.as_binary::<i32>();
                 let arr1 = arr1.as_binary::<i32>();
-                let geom_len = std::cmp::min(arr0.geom_len(), arr1.geom_len());
                 let mut bool_vec = vec![];
-                for i in 0..geom_len {
+                for i in 0..arr0.len() {
                     bool_vec.push(intersects::<i32, i32>(arr0, arr1, i)?);
                 }
                 Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(bool_vec))))
@@ -60,9 +75,8 @@ impl ScalarUDFImpl for IntersectsUdf {
             (DataType::LargeBinary, DataType::Binary) => {
                 let arr0 = arr0.as_binary::<i64>();
                 let arr1 = arr1.as_binary::<i32>();
-                let geom_len = std::cmp::min(arr0.geom_len(), arr1.geom_len());
                 let mut bool_vec = vec![];
-                for i in 0..geom_len {
+                for i in 0..arr0.len() {
                     bool_vec.push(intersects::<i64, i32>(arr0, arr1, i)?);
                 }
                 Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(bool_vec))))
@@ -70,9 +84,8 @@ impl ScalarUDFImpl for IntersectsUdf {
             (DataType::Binary, DataType::LargeBinary) => {
                 let arr0 = arr0.as_binary::<i32>();
                 let arr1 = arr1.as_binary::<i64>();
-                let geom_len = std::cmp::min(arr0.geom_len(), arr1.geom_len());
                 let mut bool_vec = vec![];
-                for i in 0..geom_len {
+                for i in 0..arr0.len() {
                     bool_vec.push(intersects::<i32, i64>(arr0, arr1, i)?);
                 }
                 Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(bool_vec))))
@@ -80,9 +93,8 @@ impl ScalarUDFImpl for IntersectsUdf {
             (DataType::LargeBinary, DataType::LargeBinary) => {
                 let arr0 = arr0.as_binary::<i64>();
                 let arr1 = arr1.as_binary::<i64>();
-                let geom_len = std::cmp::min(arr0.geom_len(), arr1.geom_len());
                 let mut bool_vec = vec![];
-                for i in 0..geom_len {
+                for i in 0..arr0.len() {
                     bool_vec.push(intersects::<i64, i64>(arr0, arr1, i)?);
                 }
                 Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(bool_vec))))
@@ -134,9 +146,15 @@ fn intersects(
 #[cfg(test)]
 mod tests {
     use crate::function::{GeomFromTextUdf, IntersectsUdf};
+    use crate::geo::GeometryArrayBuilder;
     use arrow::util::pretty::pretty_format_batches;
+    use arrow_array::RecordBatch;
+    use arrow_schema::{DataType, Field, Schema};
+    use datafusion::datasource::MemTable;
     use datafusion::logical_expr::ScalarUDF;
     use datafusion::prelude::SessionContext;
+    use geo::line_string;
+    use std::sync::Arc;
 
     #[tokio::test]
     async fn intersects() {
@@ -158,4 +176,55 @@ mod tests {
 +-----------------------------------------------------------------------------------------------------+"
         );
     }
+
+    #[tokio::test]
+    async fn intersects_table() {
+        let ctx = SessionContext::new();
+        ctx.register_udf(ScalarUDF::from(GeomFromTextUdf::new()));
+        ctx.register_udf(ScalarUDF::from(IntersectsUdf::new()));
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "geom",
+            DataType::Binary,
+            true,
+        )]));
+
+        let mut linestrint_vec = vec![];
+        for i in 0..3 {
+            let i = i as f64;
+            let linestring = line_string![
+                (x: i, y: i + 1.0),
+                (x: i + 2.0, y: i + 3.0),
+                (x: i + 4.0, y: i + 5.0),
+            ];
+            linestrint_vec.push(Some(linestring));
+        }
+        let builder: GeometryArrayBuilder = linestrint_vec.as_slice().into();
+        let record = RecordBatch::try_new(schema.clone(), vec![Arc::new(builder.build())]).unwrap();
+
+        let mem_table =
+            MemTable::try_new(schema.clone(), vec![vec![record.clone()], vec![record]]).unwrap();
+        ctx.register_table("geom_table", Arc::new(mem_table))
+            .unwrap();
+
+        let df = ctx
+            .sql("select ST_Intersects(geom, ST_GeomFromText('POINT(0 1)')) from geom_table")
+            .await
+            .unwrap();
+        assert_eq!(
+            pretty_format_batches(&df.collect().await.unwrap())
+                .unwrap()
+                .to_string(),
+            "+--------------------------------------------------------------------+
+| ST_Intersects(geom_table.geom,ST_GeomFromText(Utf8(\"POINT(0 1)\"))) |
++--------------------------------------------------------------------+
+| true                                                               |
+| false                                                              |
+| false                                                              |
+| true                                                               |
+| false                                                              |
+| false                                                              |
++--------------------------------------------------------------------+"
+        );
+    }
 }