Skip to content

Commit

Permalink
Add ST_Extent
Browse files Browse the repository at this point in the history
  • Loading branch information
lewiszlw committed Mar 7, 2024
1 parent 4233ff4 commit b4053fc
Show file tree
Hide file tree
Showing 10 changed files with 275 additions and 39 deletions.
12 changes: 6 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,17 @@ criterion = { version = "0.5.1", features = ["async_tokio"] }
geoarrow = { git = "https://github.com/geoarrow/geoarrow-rs.git", rev = "0e4473e546248d2c2cbfb44df76d508660761261" }

[[bench]]
name = "geo_bench"
path = "benches/geo_bench.rs"
name = "geo_lib"
path = "benches/geo_lib.rs"
harness = false

[[bench]]
name = "geos_bench"
path = "benches/geos_bench.rs"
name = "geos_lib"
path = "benches/geos_lib.rs"
harness = false
required-features = ["geos"]

[[bench]]
name = "geoarrow_bench"
path = "benches/geoarrow_bench.rs"
name = "geoarrow"
path = "benches/geoarrow.rs"
harness = false
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ Add geo functionality extension to datafusion query engine.
**Goals**
1. Support multiple wkb dialects
2. Provide DataFusion user defined functions similar with PostGIS
3. Prefer using geos library if possible

P.S. Please see each function unit test to know how to use them.

Expand Down
4 changes: 2 additions & 2 deletions benches/geo_bench.rs → benches/geo_lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use datafusion_geo::function::{GeomFromTextUdf, IntersectsUdf};

mod util;

async fn geo_intersects(ctx: SessionContext, sql: &str) {
async fn geo_computation(ctx: SessionContext, sql: &str) {
#[cfg(feature = "geos")]
{
panic!("geo bench needs disabling geos feature flag")
Expand All @@ -21,7 +21,7 @@ fn criterion_benchmark(c: &mut Criterion) {
ctx.register_udf(ScalarUDF::from(GeomFromTextUdf::new()));
let sql = "select ST_Intersects(geom, ST_GeomFromText('POINT(10 11)')) from geom_table";
c.bench_function(&format!("geo_bench with sql: {}", sql), |b| {
b.to_async(&rt).iter(|| geo_intersects(ctx.clone(), sql))
b.to_async(&rt).iter(|| geo_computation(ctx.clone(), sql))
});
}

Expand Down
File renamed without changes.
6 changes: 3 additions & 3 deletions benches/geos_bench.rs → benches/geos_lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use datafusion_geo::function::{GeomFromTextUdf, IntersectsUdf};

mod util;

async fn geos_intersects(ctx: SessionContext, sql: &str) {
async fn geos_computation(ctx: SessionContext, sql: &str) {
#[cfg(not(feature = "geos"))]
{
panic!("geo bench needs disabling geos feature flag")
panic!("geos bench needs enabling geos feature flag")
}
let df = ctx.sql(sql).await.unwrap();
let _ = df.collect().await.unwrap();
Expand All @@ -21,7 +21,7 @@ fn criterion_benchmark(c: &mut Criterion) {
ctx.register_udf(ScalarUDF::from(GeomFromTextUdf::new()));
let sql = "select ST_Intersects(geom, ST_GeomFromText('POINT(10 11)')) from geom_table";
c.bench_function(&format!("geos_bench with sql: {}", sql), |b| {
b.to_async(&rt).iter(|| geos_intersects(ctx.clone(), sql))
b.to_async(&rt).iter(|| geos_computation(ctx.clone(), sql))
});
}

Expand Down
12 changes: 6 additions & 6 deletions src/function/box2d.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::geo::{build_box2d_array, Box2D, GeometryArray};
use crate::geo::{build_box2d_array, Box2d, GeometryArray};
use arrow_array::cast::AsArray;
use arrow_array::Array;
use arrow_schema::DataType;
Expand Down Expand Up @@ -40,33 +40,33 @@ impl ScalarUDFImpl for Box2dUdf {
}

fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
Ok(Box2D::data_type())
Ok(Box2d::data_type())
}

fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
let arr = args[0].clone().into_array(1)?;
match arr.data_type() {
DataType::Binary => {
let wkb_arr = arr.as_binary::<i32>();
let mut box2d_vec: Vec<Option<Box2D>> = vec![];
let mut box2d_vec: Vec<Option<Box2d>> = vec![];
for i in 0..wkb_arr.geom_len() {
box2d_vec.push(
wkb_arr
.geo_value(i)?
.and_then(|geom| geom.bounding_rect().map(Box2D::from)),
.and_then(|geom| geom.bounding_rect().map(Box2d::from)),
);
}
let arr = build_box2d_array(box2d_vec);
Ok(ColumnarValue::Array(Arc::new(arr)))
}
DataType::LargeBinary => {
let wkb_arr = arr.as_binary::<i64>();
let mut box2d_vec: Vec<Option<Box2D>> = vec![];
let mut box2d_vec: Vec<Option<Box2d>> = vec![];
for i in 0..wkb_arr.geom_len() {
box2d_vec.push(
wkb_arr
.geo_value(i)?
.and_then(|geom| geom.bounding_rect().map(Box2D::from)),
.and_then(|geom| geom.bounding_rect().map(Box2d::from)),
);
}
let arr = build_box2d_array(box2d_vec);
Expand Down
229 changes: 229 additions & 0 deletions src/function/extent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
use crate::geo::{Box2d, GeometryArray};
use crate::DFResult;
use arrow_array::cast::AsArray;
use arrow_array::{Array, ArrayRef, GenericBinaryArray, OffsetSizeTrait};
use arrow_schema::DataType;
use datafusion_common::ScalarValue;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
use geo::BoundingRect;
use std::any::Any;

#[derive(Debug)]
pub struct ExtentUdaf {
signature: Signature,
}

impl ExtentUdaf {
pub fn new() -> Self {
Self {
signature: Signature::uniform(
1,
vec![DataType::Binary, DataType::LargeBinary],
Volatility::Immutable,
),
}
}
}

impl AggregateUDFImpl for ExtentUdaf {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
// uadf not support alias
"st_extent"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
Ok(Box2d::data_type())
}

fn accumulator(&self, _arg: &DataType) -> datafusion_common::Result<Box<dyn Accumulator>> {
Ok(Box::new(ExtentAccumulator::new()))
}

fn state_type(&self, _return_type: &DataType) -> datafusion_common::Result<Vec<DataType>> {
Ok(vec![Box2d::data_type()])
}
}

impl Default for ExtentUdaf {
fn default() -> Self {
Self::new()
}
}

#[derive(Debug)]
pub struct ExtentAccumulator {
box2d: Box2d,
}

impl ExtentAccumulator {
pub fn new() -> Self {
Self {
box2d: Box2d {
xmin: f64::MAX,
ymin: f64::MAX,
xmax: f64::MIN,
ymax: f64::MIN,
},
}
}
}

impl Accumulator for ExtentAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
if values.is_empty() {
return Ok(());
}
let arr = &values[0];
match arr.data_type() {
DataType::Binary => {
let wkb_arr = arr.as_binary::<i32>();
let box2d = compute_extent::<i32>(wkb_arr)?;
self.box2d = compute_bounding_box2d(self.box2d.clone(), box2d);
}
DataType::LargeBinary => {
let wkb_arr = arr.as_binary::<i64>();
let box2d = compute_extent::<i64>(wkb_arr)?;
self.box2d = compute_bounding_box2d(self.box2d.clone(), box2d);
}
_ => unreachable!(),
}
Ok(())
}

fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
Ok(self.box2d.clone().into())
}

fn size(&self) -> usize {
std::mem::size_of_val(self)
}

fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
Ok(vec![self.box2d.clone().into()])
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
if states.is_empty() {
return Ok(());
}
let arr = &states[0];
(0..arr.len()).try_for_each(|index| {
let v = states
.iter()
.map(|array| ScalarValue::try_from_array(array, index))
.collect::<datafusion_common::Result<Vec<_>>>()?;
if let ScalarValue::Struct(arr) = &v[0] {
if let Some(box2d) = Box2d::value(arr, 0)? {
self.box2d = compute_bounding_box2d(self.box2d.clone(), box2d);
}
} else {
unreachable!("")
}
Ok(())
})
}
}

fn compute_extent<O: OffsetSizeTrait>(arr: &GenericBinaryArray<O>) -> DFResult<Box2d> {
let mut box2d = Box2d {
xmin: f64::MAX,
ymin: f64::MAX,
xmax: f64::MIN,
ymax: f64::MIN,
};
for i in 0..arr.geom_len() {
if let Some(value) = arr
.geo_value(i)?
.and_then(|geom| geom.bounding_rect().map(Box2d::from))
{
box2d = compute_bounding_box2d(box2d, value);
}
}
Ok(box2d)
}

fn compute_bounding_box2d(b0: Box2d, b1: Box2d) -> Box2d {
let xmin = b0.xmin.min(b1.xmin);
let ymin = b0.ymin.min(b1.ymin);
let xmax = b0.xmax.max(b1.xmax);
let ymax = b0.ymax.max(b1.ymax);
Box2d {
xmin,
ymin,
xmax,
ymax,
}
}

#[cfg(test)]
mod tests {
use crate::function::extent::ExtentUdaf;
use crate::geo::GeometryArrayBuilder;
use arrow::util::pretty::pretty_format_batches;
use arrow_array::{RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use datafusion::datasource::MemTable;
use datafusion::prelude::SessionContext;
use datafusion_expr::AggregateUDF;
use geo::line_string;
use std::sync::Arc;

#[tokio::test]
async fn extent() {
let schema = Arc::new(Schema::new(vec![
Field::new("geom", DataType::Binary, true),
Field::new("name", DataType::Utf8, true),
]));

let mut linestrint_vec = vec![];
for i in 0..4 {
let i = i as f64;
let linestring = line_string![
(x: i, y: i + 1.0),
(x: i + 2.0, y: i + 3.0),
(x: i + 4.0, y: i + 5.0),
];
linestrint_vec.push(Some(linestring));
}
let builder: GeometryArrayBuilder<i32> = linestrint_vec.as_slice().into();

let record = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(builder.build()),
Arc::new(StringArray::from(vec!["a", "a", "b", "b"])),
],
)
.unwrap();

let mem_table = MemTable::try_new(schema.clone(), vec![vec![record]]).unwrap();

let ctx = SessionContext::new();
ctx.register_table("geom_table", Arc::new(mem_table))
.unwrap();
ctx.register_udaf(AggregateUDF::from(ExtentUdaf::new()));
let df = ctx
.sql("select ST_Extent(geom), name from geom_table group by name")
.await
.unwrap();
assert_eq!(
pretty_format_batches(&df.collect().await.unwrap())
.unwrap()
.to_string(),
"+----------------------------------------------+------+
| st_extent(geom_table.geom) | name |
+----------------------------------------------+------+
| {xmin: 2.0, ymin: 3.0, xmax: 7.0, ymax: 8.0} | b |
| {xmin: 0.0, ymin: 1.0, xmax: 5.0, ymax: 6.0} | a |
+----------------------------------------------+------+"
);
}
}
2 changes: 1 addition & 1 deletion src/function/intersects.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ fn intersects<O: OffsetSizeTrait, F: OffsetSizeTrait>(
arr0: &GenericBinaryArray<O>,
arr1: &GenericBinaryArray<F>,
) -> DFResult<ColumnarValue> {
let bool_vec = (0..arr0.len())
let bool_vec = (0..arr0.geom_len())
.into_par_iter()
.map(|geom_index| {
#[cfg(feature = "geos")]
Expand Down
1 change: 1 addition & 0 deletions src/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod as_ewkt;
mod as_text;
mod box2d;
mod extent;
mod geom_from_text;
mod geom_from_wkb;
mod geometry_type;
Expand Down
Loading

0 comments on commit b4053fc

Please sign in to comment.