Skip to content

Commit

Permalink
Fix: generate_series function support string type (apache#12002)
Browse files Browse the repository at this point in the history
* fix: sqllogictest

* Revert "fix: sqllogictest"

This reverts commit 4957a1d.

* fix: sqllogictest

* remove any type signature

* coerce type from null  to date32

* fmt

* slt

* Revert "coerce type from null  to date32"

This reverts commit bccdc2e.

* replace type coerce by `coerce_types` method

* fmt

* fix underscored param
  • Loading branch information
getChan authored Aug 19, 2024
1 parent a91be04 commit 574dfeb
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 35 deletions.
93 changes: 62 additions & 31 deletions datafusion/functions-nested/src/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@ use arrow::datatypes::{DataType, Field};
use arrow_array::types::{Date32Type, IntervalMonthDayNanoType};
use arrow_array::NullArray;
use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer};
use arrow_schema::DataType::{Date32, Int64, Interval, List};
use arrow_schema::DataType::*;
use arrow_schema::IntervalUnit::MonthDayNano;
use datafusion_common::cast::{as_date32_array, as_int64_array, as_interval_mdn_array};
use datafusion_common::{exec_err, not_impl_datafusion_err, Result};
use datafusion_expr::{
ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use itertools::Itertools;
use std::any::Any;
use std::iter::from_fn;
use std::sync::Arc;
Expand All @@ -49,16 +48,7 @@ pub(super) struct Range {
impl Range {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![
TypeSignature::Exact(vec![Int64]),
TypeSignature::Exact(vec![Int64, Int64]),
TypeSignature::Exact(vec![Int64, Int64, Int64]),
TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]),
TypeSignature::Any(3),
],
Volatility::Immutable,
),
signature: Signature::user_defined(Volatility::Immutable),
aliases: vec![],
}
}
Expand All @@ -75,9 +65,34 @@ impl ScalarUDFImpl for Range {
&self.signature
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
arg_types
.iter()
.map(|arg_type| match arg_type {
Null => Ok(Null),
Int8 => Ok(Int64),
Int16 => Ok(Int64),
Int32 => Ok(Int64),
Int64 => Ok(Int64),
UInt8 => Ok(Int64),
UInt16 => Ok(Int64),
UInt32 => Ok(Int64),
UInt64 => Ok(Int64),
Timestamp(_, _) => Ok(Date32),
Date32 => Ok(Date32),
Date64 => Ok(Date32),
Utf8 => Ok(Date32),
LargeUtf8 => Ok(Date32),
Utf8View => Ok(Date32),
Interval(_) => Ok(Interval(MonthDayNano)),
_ => exec_err!("Unsupported DataType"),
})
.try_collect()
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types.iter().any(|t| t.eq(&DataType::Null)) {
Ok(DataType::Null)
if arg_types.iter().any(|t| t.is_null()) {
Ok(Null)
} else {
Ok(List(Arc::new(Field::new(
"item",
Expand All @@ -88,7 +103,7 @@ impl ScalarUDFImpl for Range {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.iter().any(|arg| arg.data_type() == DataType::Null) {
if args.iter().any(|arg| arg.data_type().is_null()) {
return Ok(ColumnarValue::Array(Arc::new(NullArray::new(1))));
}
match args[0].data_type() {
Expand Down Expand Up @@ -120,16 +135,7 @@ pub(super) struct GenSeries {
impl GenSeries {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![
TypeSignature::Exact(vec![Int64]),
TypeSignature::Exact(vec![Int64, Int64]),
TypeSignature::Exact(vec![Int64, Int64, Int64]),
TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]),
TypeSignature::Any(3),
],
Volatility::Immutable,
),
signature: Signature::user_defined(Volatility::Immutable),
aliases: vec![],
}
}
Expand All @@ -146,9 +152,34 @@ impl ScalarUDFImpl for GenSeries {
&self.signature
}

fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
_arg_types
.iter()
.map(|arg_type| match arg_type {
Null => Ok(Null),
Int8 => Ok(Int64),
Int16 => Ok(Int64),
Int32 => Ok(Int64),
Int64 => Ok(Int64),
UInt8 => Ok(Int64),
UInt16 => Ok(Int64),
UInt32 => Ok(Int64),
UInt64 => Ok(Int64),
Timestamp(_, _) => Ok(Date32),
Date32 => Ok(Date32),
Date64 => Ok(Date32),
Utf8 => Ok(Date32),
LargeUtf8 => Ok(Date32),
Utf8View => Ok(Date32),
Interval(_) => Ok(Interval(MonthDayNano)),
_ => exec_err!("Unsupported DataType"),
})
.try_collect()
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types.iter().any(|t| t.eq(&DataType::Null)) {
Ok(DataType::Null)
if arg_types.iter().any(|t| t.is_null()) {
Ok(Null)
} else {
Ok(List(Arc::new(Field::new(
"item",
Expand All @@ -159,15 +190,15 @@ impl ScalarUDFImpl for GenSeries {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.iter().any(|arg| arg.data_type() == DataType::Null) {
if args.iter().any(|arg| arg.data_type().is_null()) {
return Ok(ColumnarValue::Array(Arc::new(NullArray::new(1))));
}
match args[0].data_type() {
Int64 => make_scalar_function(|args| gen_range_inner(args, true))(args),
Date32 => make_scalar_function(|args| gen_range_date(args, true))(args),
dt => {
exec_err!(
"unsupported type for range. Expected Int64 or Date32, got: {}",
"unsupported type for gen_series. Expected Int64 or Date32, got: {}",
dt
)
}
Expand Down
9 changes: 5 additions & 4 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -5804,7 +5804,7 @@ select generate_series(5),
----
[0, 1, 2, 3, 4, 5] [2, 3, 4, 5] [2, 5, 8] [1, 2, 3, 4, 5] [5, 4, 3, 2, 1] [10, 7, 4] [1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01, 1993-03-01] [1993-02-01, 1993-01-31, 1993-01-30, 1993-01-29, 1993-01-28, 1993-01-27, 1993-01-26, 1993-01-25, 1993-01-24, 1993-01-23, 1993-01-22, 1993-01-21, 1993-01-20, 1993-01-19, 1993-01-18, 1993-01-17, 1993-01-16, 1993-01-15, 1993-01-14, 1993-01-13, 1993-01-12, 1993-01-11, 1993-01-10, 1993-01-09, 1993-01-08, 1993-01-07, 1993-01-06, 1993-01-05, 1993-01-04, 1993-01-03, 1993-01-02, 1993-01-01] [1989-04-01, 1990-04-01, 1991-04-01, 1992-04-01]

query error DataFusion error: Execution error: unsupported type for range. Expected Int64 or Date32, got: Timestamp\(Nanosecond, None\)
query error DataFusion error: Execution error: Cannot generate date range less than 1 day\.
select generate_series('2021-01-01'::timestamp, '2021-01-02'::timestamp, INTERVAL '1' HOUR);

## should return NULL
Expand Down Expand Up @@ -5936,11 +5936,12 @@ select generate_series(start, '1993-03-01'::date, INTERVAL '1 year') from date_t


# https://github.com/apache/datafusion/issues/11922
query error
query ?
select generate_series(start, '1993-03-01', INTERVAL '1 year') from date_table;
----
DataFusion error: Internal error: could not cast value to arrow_array::array::primitive_array::PrimitiveArray<arrow_array::types::Date32Type>.
This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker
[1992-01-01, 1993-01-01]
[1993-02-01]
[1989-04-01, 1990-04-01, 1991-04-01, 1992-04-01]


## array_except
Expand Down

0 comments on commit 574dfeb

Please sign in to comment.