Skip to content

Commit

Permalink
Add support for ODBC functions (#1585)
Browse files Browse the repository at this point in the history
  • Loading branch information
iffyio authored Dec 11, 2024
1 parent 04271b0 commit a13f8c6
Show file tree
Hide file tree
Showing 15 changed files with 142 additions and 6 deletions.
17 changes: 17 additions & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5523,6 +5523,15 @@ impl fmt::Display for CloseCursor {
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Function {
pub name: ObjectName,
/// Flags whether this function call uses the [ODBC syntax].
///
/// Example:
/// ```sql
/// SELECT {fn CONCAT('foo', 'bar')}
/// ```
///
/// [ODBC syntax]: https://learn.microsoft.com/en-us/sql/odbc/reference/develop-app/scalar-function-calls?view=sql-server-2017
pub uses_odbc_syntax: bool,
/// The parameters to the function, including any options specified within the
/// delimiting parentheses.
///
Expand Down Expand Up @@ -5561,6 +5570,10 @@ pub struct Function {

impl fmt::Display for Function {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if self.uses_odbc_syntax {
write!(f, "{{fn ")?;
}

write!(f, "{}{}{}", self.name, self.parameters, self.args)?;

if !self.within_group.is_empty() {
Expand All @@ -5583,6 +5596,10 @@ impl fmt::Display for Function {
write!(f, " OVER {o}")?;
}

if self.uses_odbc_syntax {
write!(f, "}}")?;
}

Ok(())
}
}
Expand Down
1 change: 1 addition & 0 deletions src/ast/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,7 @@ impl Spanned for Function {
fn span(&self) -> Span {
let Function {
name,
uses_odbc_syntax: _,
parameters,
args,
filter,
Expand Down
1 change: 1 addition & 0 deletions src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ where
/// let old_expr = std::mem::replace(expr, Expr::Value(Value::Null));
/// *expr = Expr::Function(Function {
/// name: ObjectName(vec![Ident::new("f")]),
/// uses_odbc_syntax: false,
/// args: FunctionArguments::List(FunctionArgumentList {
/// duplicate_treatment: None,
/// args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(old_expr))],
Expand Down
1 change: 1 addition & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ define_keywords!(
FLOAT8,
FLOOR,
FLUSH,
FN,
FOLLOWING,
FOR,
FORCE,
Expand Down
65 changes: 59 additions & 6 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,7 @@ impl<'a> Parser<'a> {
{
Ok(Some(Expr::Function(Function {
name: ObjectName(vec![w.to_ident(w_span)]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::None,
null_treatment: None,
Expand Down Expand Up @@ -1111,6 +1112,7 @@ impl<'a> Parser<'a> {
self.expect_token(&Token::RParen)?;
Ok(Some(Expr::Function(Function {
name: ObjectName(vec![w.to_ident(w_span)]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::Subquery(query),
filter: None,
Expand Down Expand Up @@ -1408,9 +1410,9 @@ impl<'a> Parser<'a> {
self.prev_token();
Ok(Expr::Value(self.parse_value()?))
}
Token::LBrace if self.dialect.supports_dictionary_syntax() => {
Token::LBrace => {
self.prev_token();
self.parse_duckdb_struct_literal()
self.parse_lbrace_expr()
}
_ => self.expected("an expression", next_token),
}?;
Expand Down Expand Up @@ -1509,23 +1511,46 @@ impl<'a> Parser<'a> {
}
}

/// Tries to parse the body of an [ODBC function] call.
/// i.e. without the enclosing braces
///
/// ```sql
/// fn myfunc(1,2,3)
/// ```
///
/// [ODBC function]: https://learn.microsoft.com/en-us/sql/odbc/reference/develop-app/scalar-function-calls?view=sql-server-2017
fn maybe_parse_odbc_fn_body(&mut self) -> Result<Option<Expr>, ParserError> {
self.maybe_parse(|p| {
p.expect_keyword(Keyword::FN)?;
let fn_name = p.parse_object_name(false)?;
let mut fn_call = p.parse_function_call(fn_name)?;
fn_call.uses_odbc_syntax = true;
Ok(Expr::Function(fn_call))
})
}

pub fn parse_function(&mut self, name: ObjectName) -> Result<Expr, ParserError> {
self.parse_function_call(name).map(Expr::Function)
}

fn parse_function_call(&mut self, name: ObjectName) -> Result<Function, ParserError> {
self.expect_token(&Token::LParen)?;

// Snowflake permits a subquery to be passed as an argument without
// an enclosing set of parens if it's the only argument.
if dialect_of!(self is SnowflakeDialect) && self.peek_sub_query() {
let subquery = self.parse_query()?;
self.expect_token(&Token::RParen)?;
return Ok(Expr::Function(Function {
return Ok(Function {
name,
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::Subquery(subquery),
filter: None,
null_treatment: None,
over: None,
within_group: vec![],
}));
});
}

let mut args = self.parse_function_argument_list()?;
Expand Down Expand Up @@ -1584,15 +1609,16 @@ impl<'a> Parser<'a> {
None
};

Ok(Expr::Function(Function {
Ok(Function {
name,
uses_odbc_syntax: false,
parameters,
args: FunctionArguments::List(args),
null_treatment,
filter,
over,
within_group,
}))
})
}

/// Optionally parses a null treatment clause.
Expand All @@ -1619,6 +1645,7 @@ impl<'a> Parser<'a> {
};
Ok(Expr::Function(Function {
name,
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args,
filter: None,
Expand Down Expand Up @@ -2211,6 +2238,31 @@ impl<'a> Parser<'a> {
}
}

/// Parse expression types that start with a left brace '{'.
/// Examples:
/// ```sql
/// -- Dictionary expr.
/// {'key1': 'value1', 'key2': 'value2'}
///
/// -- Function call using the ODBC syntax.
/// { fn CONCAT('foo', 'bar') }
/// ```
fn parse_lbrace_expr(&mut self) -> Result<Expr, ParserError> {
let token = self.expect_token(&Token::LBrace)?;

if let Some(fn_expr) = self.maybe_parse_odbc_fn_body()? {
self.expect_token(&Token::RBrace)?;
return Ok(fn_expr);
}

if self.dialect.supports_dictionary_syntax() {
self.prev_token(); // Put back the '{'
return self.parse_duckdb_struct_literal();
}

self.expected("an expression", token)
}

/// Parses fulltext expressions [`sqlparser::ast::Expr::MatchAgainst`]
///
/// # Errors
Expand Down Expand Up @@ -7578,6 +7630,7 @@ impl<'a> Parser<'a> {
} else {
Ok(Statement::Call(Function {
name: object_name,
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::None,
over: None,
Expand Down
1 change: 1 addition & 0 deletions src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ pub fn join(relation: TableFactor) -> Join {
pub fn call(function: &str, args: impl IntoIterator<Item = Expr>) -> Expr {
Expr::Function(Function {
name: ObjectName(vec![Ident::new(function)]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: None,
Expand Down
4 changes: 4 additions & 0 deletions tests/sqlparser_clickhouse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ fn parse_delimited_identifiers() {
assert_eq!(
&Expr::Function(Function {
name: ObjectName(vec![Ident::with_quote('"', "myfun")]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: None,
Expand Down Expand Up @@ -821,6 +822,7 @@ fn parse_create_table_with_variant_default_expressions() {
name: None,
option: ColumnOption::Materialized(Expr::Function(Function {
name: ObjectName(vec![Ident::new("now")]),
uses_odbc_syntax: false,
args: FunctionArguments::List(FunctionArgumentList {
args: vec![],
duplicate_treatment: None,
Expand All @@ -842,6 +844,7 @@ fn parse_create_table_with_variant_default_expressions() {
name: None,
option: ColumnOption::Ephemeral(Some(Expr::Function(Function {
name: ObjectName(vec![Ident::new("now")]),
uses_odbc_syntax: false,
args: FunctionArguments::List(FunctionArgumentList {
args: vec![],
duplicate_treatment: None,
Expand Down Expand Up @@ -872,6 +875,7 @@ fn parse_create_table_with_variant_default_expressions() {
name: None,
option: ColumnOption::Alias(Expr::Function(Function {
name: ObjectName(vec![Ident::new("toString")]),
uses_odbc_syntax: false,
args: FunctionArguments::List(FunctionArgumentList {
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(
Identifier(Ident::new("c"))
Expand Down
43 changes: 43 additions & 0 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,7 @@ fn parse_select_count_wildcard() {
assert_eq!(
&Expr::Function(Function {
name: ObjectName(vec![Ident::new("COUNT")]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: None,
Expand All @@ -1130,6 +1131,7 @@ fn parse_select_count_distinct() {
assert_eq!(
&Expr::Function(Function {
name: ObjectName(vec![Ident::new("COUNT")]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: Some(DuplicateTreatment::Distinct),
Expand Down Expand Up @@ -2366,6 +2368,7 @@ fn parse_select_having() {
Some(Expr::BinaryOp {
left: Box::new(Expr::Function(Function {
name: ObjectName(vec![Ident::new("COUNT")]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: None,
Expand Down Expand Up @@ -2396,6 +2399,7 @@ fn parse_select_qualify() {
Some(Expr::BinaryOp {
left: Box::new(Expr::Function(Function {
name: ObjectName(vec![Ident::new("ROW_NUMBER")]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: None,
Expand Down Expand Up @@ -2802,6 +2806,7 @@ fn parse_listagg() {
assert_eq!(
&Expr::Function(Function {
name: ObjectName(vec![Ident::new("LISTAGG")]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: Some(DuplicateTreatment::Distinct),
Expand Down Expand Up @@ -4603,6 +4608,7 @@ fn parse_named_argument_function() {
assert_eq!(
&Expr::Function(Function {
name: ObjectName(vec![Ident::new("FUN")]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: None,
Expand Down Expand Up @@ -4642,6 +4648,7 @@ fn parse_named_argument_function_with_eq_operator() {
assert_eq!(
&Expr::Function(Function {
name: ObjectName(vec![Ident::new("FUN")]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: None,
Expand Down Expand Up @@ -4716,6 +4723,7 @@ fn parse_window_functions() {
assert_eq!(
&Expr::Function(Function {
name: ObjectName(vec![Ident::new("row_number")]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: None,
Expand Down Expand Up @@ -4846,6 +4854,7 @@ fn test_parse_named_window() {
quote_style: None,
span: Span::empty(),
}]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: None,
Expand Down Expand Up @@ -4880,6 +4889,7 @@ fn test_parse_named_window() {
quote_style: None,
span: Span::empty(),
}]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: None,
Expand Down Expand Up @@ -9008,6 +9018,7 @@ fn parse_time_functions() {
let select = verified_only_select(&sql);
let select_localtime_func_call_ast = Function {
name: ObjectName(vec![Ident::new(func_name)]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: None,
Expand Down Expand Up @@ -10021,6 +10032,7 @@ fn parse_call() {
assert_eq!(
verified_stmt("CALL my_procedure('a')"),
Statement::Call(Function {
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: None,
Expand Down Expand Up @@ -10511,6 +10523,7 @@ fn test_selective_aggregation() {
vec![
SelectItem::UnnamedExpr(Expr::Function(Function {
name: ObjectName(vec![Ident::new("ARRAY_AGG")]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: None,
Expand All @@ -10529,6 +10542,7 @@ fn test_selective_aggregation() {
SelectItem::ExprWithAlias {
expr: Expr::Function(Function {
name: ObjectName(vec![Ident::new("ARRAY_AGG")]),
uses_odbc_syntax: false,
parameters: FunctionArguments::None,
args: FunctionArguments::List(FunctionArgumentList {
duplicate_treatment: None,
Expand Down Expand Up @@ -10968,6 +10982,35 @@ fn insert_into_with_parentheses() {
dialects.verified_stmt(r#"INSERT INTO t1 ("select", name) (SELECT t2.name FROM t2)"#);
}

#[test]
fn parse_odbc_scalar_function() {
let select = verified_only_select("SELECT {fn my_func(1, 2)}");
let Expr::Function(Function {
name,
uses_odbc_syntax,
args,
..
}) = expr_from_projection(only(&select.projection))
else {
unreachable!("expected function")
};
assert_eq!(name, &ObjectName(vec![Ident::new("my_func")]));
assert!(uses_odbc_syntax);
matches!(args, FunctionArguments::List(l) if l.args.len() == 2);

verified_stmt("SELECT {fn fna()} AS foo, fnb(1)");

// Testing invalid SQL with any-one dialect is intentional.
// Depending on dialect flags the error message may be different.
let pg = TestedDialects::new(vec![Box::new(PostgreSqlDialect {})]);
assert_eq!(
pg.parse_sql_statements("SELECT {fn2 my_func()}")
.unwrap_err()
.to_string(),
"sql parser error: Expected: an expression, found: {"
);
}

#[test]
fn test_dictionary_syntax() {
fn check(sql: &str, expect: Expr) {
Expand Down
Loading

0 comments on commit a13f8c6

Please sign in to comment.