Skip to content

Commit

Permalink
Fix ST_Intersects bug
Browse files Browse the repository at this point in the history
  • Loading branch information
lewiszlw committed Mar 6, 2024
1 parent 5bfc4b0 commit 346cc81
Showing 1 changed file with 79 additions and 10 deletions.
89 changes: 79 additions & 10 deletions src/function/intersects.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::DFResult;
use arrow_array::cast::AsArray;
use arrow_array::{Array, BooleanArray, GenericBinaryArray, OffsetSizeTrait};
use arrow_schema::DataType;
use datafusion_common::DataFusionError;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;
Expand Down Expand Up @@ -44,45 +45,56 @@ impl ScalarUDFImpl for IntersectsUdf {
}

fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
let arr0 = args[0].clone().into_array(1)?;
let arr1 = args[1].clone().into_array(1)?;
let (arr0, arr1) = match (args[0].clone(), args[1].clone()) {
(ColumnarValue::Array(arr0), ColumnarValue::Array(arr1)) => (arr0, arr1),
(ColumnarValue::Array(arr0), ColumnarValue::Scalar(scalar)) => {
(arr0.clone(), scalar.to_array_of_size(arr0.len())?)
}
(ColumnarValue::Scalar(scalar), ColumnarValue::Array(arr1)) => {
(scalar.to_array_of_size(arr1.len())?, arr1)
}
(ColumnarValue::Scalar(scalar0), ColumnarValue::Scalar(scalar1)) => {
(scalar0.to_array_of_size(1)?, scalar1.to_array_of_size(1)?)
}
};
if arr0.len() != arr1.len() {
return Err(DataFusionError::Internal(
"Two arrays length is not same".to_string(),
));
}
match (arr0.data_type(), arr1.data_type()) {
(DataType::Binary, DataType::Binary) => {
let arr0 = arr0.as_binary::<i32>();
let arr1 = arr1.as_binary::<i32>();
let geom_len = std::cmp::min(arr0.geom_len(), arr1.geom_len());
let mut bool_vec = vec![];
for i in 0..geom_len {
for i in 0..arr0.len() {
bool_vec.push(intersects::<i32, i32>(arr0, arr1, i)?);
}
Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(bool_vec))))
}
(DataType::LargeBinary, DataType::Binary) => {
let arr0 = arr0.as_binary::<i64>();
let arr1 = arr1.as_binary::<i32>();
let geom_len = std::cmp::min(arr0.geom_len(), arr1.geom_len());
let mut bool_vec = vec![];
for i in 0..geom_len {
for i in 0..arr0.len() {
bool_vec.push(intersects::<i64, i32>(arr0, arr1, i)?);
}
Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(bool_vec))))
}
(DataType::Binary, DataType::LargeBinary) => {
let arr0 = arr0.as_binary::<i32>();
let arr1 = arr1.as_binary::<i64>();
let geom_len = std::cmp::min(arr0.geom_len(), arr1.geom_len());
let mut bool_vec = vec![];
for i in 0..geom_len {
for i in 0..arr0.len() {
bool_vec.push(intersects::<i32, i64>(arr0, arr1, i)?);
}
Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(bool_vec))))
}
(DataType::LargeBinary, DataType::LargeBinary) => {
let arr0 = arr0.as_binary::<i64>();
let arr1 = arr1.as_binary::<i64>();
let geom_len = std::cmp::min(arr0.geom_len(), arr1.geom_len());
let mut bool_vec = vec![];
for i in 0..geom_len {
for i in 0..arr0.len() {
bool_vec.push(intersects::<i64, i64>(arr0, arr1, i)?);
}
Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(bool_vec))))
Expand Down Expand Up @@ -134,9 +146,15 @@ fn intersects<O: OffsetSizeTrait, F: OffsetSizeTrait>(
#[cfg(test)]
mod tests {
use crate::function::{GeomFromTextUdf, IntersectsUdf};
use crate::geo::GeometryArrayBuilder;
use arrow::util::pretty::pretty_format_batches;
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Field, Schema};
use datafusion::datasource::MemTable;
use datafusion::logical_expr::ScalarUDF;
use datafusion::prelude::SessionContext;
use geo::line_string;
use std::sync::Arc;

#[tokio::test]
async fn intersects() {
Expand All @@ -158,4 +176,55 @@ mod tests {
+-----------------------------------------------------------------------------------------------------+"
);
}

/// Runs `ST_Intersects` over a table column paired with a scalar geometry.
///
/// An in-memory table with a single `geom` column is filled with three line
/// strings. The same record batch is supplied twice, so the query against
/// `POINT(0 1)` is expected to yield six rows.
#[tokio::test]
async fn intersects_table() {
    let ctx = SessionContext::new();
    ctx.register_udf(ScalarUDF::from(GeomFromTextUdf::new()));
    ctx.register_udf(ScalarUDF::from(IntersectsUdf::new()));

    // One nullable binary column holding the encoded geometries.
    let geom_field = Field::new("geom", DataType::Binary, true);
    let schema = Arc::new(Schema::new(vec![geom_field]));

    // Three line strings shifted by 0, 1 and 2; only the first starts at (0, 1).
    let linestrings: Vec<_> = (0..3)
        .map(|offset| {
            let base = offset as f64;
            Some(line_string![
                (x: base, y: base + 1.0),
                (x: base + 2.0, y: base + 3.0),
                (x: base + 4.0, y: base + 5.0),
            ])
        })
        .collect();
    let builder: GeometryArrayBuilder<i32> = linestrings.as_slice().into();
    let record = RecordBatch::try_new(schema.clone(), vec![Arc::new(builder.build())]).unwrap();

    // Supply the batch twice so the UDF is invoked on more than one batch.
    let batches = vec![vec![record.clone()], vec![record]];
    let mem_table = MemTable::try_new(schema.clone(), batches).unwrap();
    ctx.register_table("geom_table", Arc::new(mem_table))
        .unwrap();

    let df = ctx
        .sql("select ST_Intersects(geom, ST_GeomFromText('POINT(0 1)')) from geom_table")
        .await
        .unwrap();
    let results = df.collect().await.unwrap();
    let rendered = pretty_format_batches(&results).unwrap().to_string();
    assert_eq!(
        rendered,
        "+--------------------------------------------------------------------+
| ST_Intersects(geom_table.geom,ST_GeomFromText(Utf8(\"POINT(0 1)\"))) |
+--------------------------------------------------------------------+
| true |
| false |
| false |
| true |
| false |
| false |
+--------------------------------------------------------------------+"
    );
}
}

0 comments on commit 346cc81

Please sign in to comment.