From 6a18bb76ed3bd3814397778aa3c00de6c51e8358 Mon Sep 17 00:00:00 2001 From: "R. C. Howell" Date: Thu, 28 Sep 2023 10:09:11 -0700 Subject: [PATCH] Adds configurable AST -> SQL printer (#1183) --- CHANGELOG.md | 1 + .../src/main/kotlin/org/partiql/ast/Ast.kt | 13 + .../main/kotlin/org/partiql/ast/sql/Sql.kt | 39 + .../kotlin/org/partiql/ast/sql/SqlBlock.kt | 89 + .../kotlin/org/partiql/ast/sql/SqlDialect.kt | 766 ++++++++ .../kotlin/org/partiql/ast/sql/SqlLayout.kt | 96 + .../org/partiql/ast/sql/SqlBlockWriterTest.kt | 77 + .../org/partiql/ast/sql/SqlDialectTest.kt | 1656 +++++++++++++++++ .../value/io/PartiQLValueTextWriter.kt | 2 +- 9 files changed, 2738 insertions(+), 1 deletion(-) create mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/sql/Sql.kt create mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlBlock.kt create mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt create mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlLayout.kt create mode 100644 partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlBlockWriterTest.kt create mode 100644 partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt diff --git a/CHANGELOG.md b/CHANGELOG.md index bb65d261d..158497cc4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Adds overridden `toString()` method for Sprout-generated code. - Adds CURRENT_DATE session variable to PartiQL.g4 and PartiQLParser +- Adds configurable AST to SQL pretty printer. Usage in Java `AstKt.sql(ast)` or in Kotlin `ast.sql()`. ### Changed diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/Ast.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/Ast.kt index 380a92465..fa3434d1a 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/Ast.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/Ast.kt @@ -1,6 +1,10 @@ package org.partiql.ast import org.partiql.ast.builder.AstFactoryImpl +import org.partiql.ast.sql.SqlBlock +import org.partiql.ast.sql.SqlDialect +import org.partiql.ast.sql.SqlLayout +import org.partiql.ast.sql.sql /** * Singleton instance of the default factory; also accessible via `AstFactory.DEFAULT`. @@ -13,3 +17,12 @@ object Ast : AstBaseFactory() public abstract class AstBaseFactory : AstFactoryImpl() { // internal default overrides here } + +/** + * Pretty-print this [AstNode] as SQL text with the given [SqlLayout] + */ +@JvmOverloads +public fun AstNode.sql( + layout: SqlLayout = SqlLayout.DEFAULT, + dialect: SqlDialect = SqlDialect.PARTIQL, +): String = accept(dialect, SqlBlock.Nil).sql(layout) diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/Sql.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/Sql.kt new file mode 100644 index 000000000..d360f59e3 --- /dev/null +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/Sql.kt @@ -0,0 +1,39 @@ +package org.partiql.ast.sql + +// a <> b <-> a concat b + +internal infix fun SqlBlock.concat(rhs: SqlBlock): SqlBlock = link(this, rhs) + +internal infix fun SqlBlock.concat(text: String): SqlBlock = link(this, text(text)) + +internal infix operator fun SqlBlock.plus(rhs: SqlBlock): SqlBlock = link(this, rhs) + +internal infix operator fun SqlBlock.plus(text: String): SqlBlock = link(this, text(text)) + +// Shorthand + +internal val NIL = SqlBlock.Nil + +internal val NL = SqlBlock.NL + +internal fun text(text: String) = SqlBlock.Text(text) + +internal fun link(lhs: SqlBlock, rhs: SqlBlock) = SqlBlock.Link(lhs, rhs) + +internal fun nest(block: () -> SqlBlock) = SqlBlock.Nest(block()) + +internal fun list(start: String?, end: String?, delimiter: String? = ",", items: () -> List): SqlBlock { + var h: SqlBlock = NIL + h = if (start != null) h + start else h + h += nest { + val kids = items() + var list: SqlBlock = NIL + kids.foldIndexed(list) { i, a, item -> + list += item + list = if (delimiter != null && (i + 1) < kids.size) a + delimiter else a + list + } + } + h = if (end != null) h + end else h + return h +} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlBlock.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlBlock.kt new file mode 100644 index 000000000..c163e8998 --- /dev/null +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlBlock.kt @@ -0,0 +1,89 @@ +package org.partiql.ast.sql + +/** + * Write this [SqlBlock] tree as SQL text with the given [SqlLayout]. + * + * @param layout SQL formatting ruleset + * @return SQL text + */ +public fun SqlBlock.sql(layout: SqlLayout = SqlLayout.DEFAULT): String = layout.format(this) + +/** + * Representation of some textual corpus; akin to Wadler's "A prettier printer" Document type. + */ +sealed interface SqlBlock { + + public override fun toString(): String + + public fun accept(visitor: BlockVisitor, ctx: C): R + + public object Nil : SqlBlock { + + override fun toString() = "" + + override fun accept(visitor: BlockVisitor, ctx: C): R = visitor.visitNil(this, ctx) + } + + public object NL : SqlBlock { + + override fun toString() = "\n" + + override fun accept(visitor: BlockVisitor, ctx: C): R = visitor.visitNewline(this, ctx) + } + + public class Text(val text: String) : SqlBlock { + + override fun toString() = text + + override fun accept(visitor: BlockVisitor, ctx: C): R = visitor.visitText(this, ctx) + } + + public class Nest(val child: SqlBlock) : SqlBlock { + + override fun toString() = child.toString() + + override fun accept(visitor: BlockVisitor, ctx: C): R = visitor.visitNest(this, ctx) + } + + // Use link block rather than linked-list block.next as it makes pre-order traversal trivial + public class Link(val lhs: SqlBlock, val rhs: SqlBlock) : SqlBlock { + + override fun toString() = lhs.toString() + rhs.toString() + + override fun accept(visitor: BlockVisitor, ctx: C): R = visitor.visitLink(this, ctx) + } +} + +public interface BlockVisitor { + + public fun visit(block: SqlBlock, ctx: C): R + + public fun visitNil(block: SqlBlock.Nil, ctx: C): R + + public fun visitNewline(block: SqlBlock.NL, ctx: C): R + + public fun visitText(block: SqlBlock.Text, ctx: C): R + + public fun visitNest(block: SqlBlock.Nest, ctx: C): R + + public fun visitLink(block: SqlBlock.Link, ctx: C): R +} + +public abstract class BlockBaseVisitor : BlockVisitor { + + public abstract fun defaultReturn(block: SqlBlock, ctx: C): R + + public open fun defaultVisit(block: SqlBlock, ctx: C) = defaultReturn(block, ctx) + + public override fun visit(block: SqlBlock, ctx: C): R = block.accept(this, ctx) + + public override fun visitNil(block: SqlBlock.Nil, ctx: C): R = defaultVisit(block, ctx) + + public override fun visitNewline(block: SqlBlock.NL, ctx: C): R = defaultVisit(block, ctx) + + public override fun visitText(block: SqlBlock.Text, ctx: C): R = defaultVisit(block, ctx) + + public override fun visitNest(block: SqlBlock.Nest, ctx: C): R = defaultVisit(block, ctx) + + public override fun visitLink(block: SqlBlock.Link, ctx: C): R = defaultVisit(block, ctx) +} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt new file mode 100644 index 000000000..b1a623a96 --- /dev/null +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt @@ -0,0 +1,766 @@ +package org.partiql.ast.sql + +import org.partiql.ast.AstNode +import org.partiql.ast.Expr +import org.partiql.ast.From +import org.partiql.ast.GroupBy +import org.partiql.ast.Identifier +import org.partiql.ast.Let +import org.partiql.ast.OrderBy +import org.partiql.ast.Path +import org.partiql.ast.Select +import org.partiql.ast.SetOp +import org.partiql.ast.SetQuantifier +import org.partiql.ast.Sort +import org.partiql.ast.Statement +import org.partiql.ast.Type +import org.partiql.ast.visitor.AstBaseVisitor +import org.partiql.value.MissingValue +import org.partiql.value.NullValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.TextValue +import org.partiql.value.io.PartiQLValueTextWriter +import java.io.ByteArrayOutputStream +import java.io.PrintStream + +/** + * SqlDialect represents the base behavior for transforming an [AstNode] tree into a [SqlBlock] tree. + */ +@Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") +public abstract class SqlDialect : AstBaseVisitor() { + + /** + * Default entry-point, can also be us. + */ + public fun apply(node: AstNode): SqlBlock = node.accept(this, SqlBlock.Nil) + + companion object { + + @JvmStatic + public val PARTIQL = object : SqlDialect() {} + } + + override fun defaultReturn(node: AstNode, head: SqlBlock) = throw UnsupportedOperationException("Cannot print $node") + + // STATEMENTS + + override fun visitStatementQuery(node: Statement.Query, head: SqlBlock) = visitExpr(node.expr, head) + + // IDENTIFIERS & PATHS + + override fun visitIdentifierSymbol(node: Identifier.Symbol, head: SqlBlock) = head concat r(node.sql()) + + override fun visitIdentifierQualified(node: Identifier.Qualified, head: SqlBlock): SqlBlock { + val path = node.steps.fold(node.root.sql()) { p, step -> p + "." + step.sql() } + return head concat r(path) + } + + override fun visitPath(node: Path, head: SqlBlock): SqlBlock { + val path = node.steps.fold(node.root.sql()) { p, step -> + when (step) { + is Path.Step.Index -> p + "[${step.index}]" + is Path.Step.Symbol -> p + "." + step.symbol.sql() + } + } + return head concat r(path) + } + + // cannot write path step outside the context of a path as we don't want it to reflow + override fun visitPathStep(node: Path.Step, head: SqlBlock) = error("path step cannot be written directly") + + override fun visitPathStepSymbol(node: Path.Step.Symbol, head: SqlBlock) = visitPathStep(node, head) + + override fun visitPathStepIndex(node: Path.Step.Index, head: SqlBlock) = visitPathStep(node, head) + + // TYPES + + override fun visitTypeNullType(node: Type.NullType, head: SqlBlock) = head concat r("NULL") + + override fun visitTypeMissing(node: Type.Missing, head: SqlBlock) = head concat r("MISSING") + + override fun visitTypeBool(node: Type.Bool, head: SqlBlock) = head concat r("BOOL") + + override fun visitTypeTinyint(node: Type.Tinyint, head: SqlBlock) = head concat r("TINYINT") + + override fun visitTypeSmallint(node: Type.Smallint, head: SqlBlock) = head concat r("SMALLINT") + + override fun visitTypeInt2(node: Type.Int2, head: SqlBlock) = head concat r("INT2") + + override fun visitTypeInt4(node: Type.Int4, head: SqlBlock) = head concat r("INT4") + + override fun visitTypeBigint(node: Type.Bigint, head: SqlBlock) = head concat r("BIGINT") + + override fun visitTypeInt8(node: Type.Int8, head: SqlBlock) = head concat r("INT8") + + override fun visitTypeInt(node: Type.Int, head: SqlBlock) = head concat r("INT") + + override fun visitTypeReal(node: Type.Real, head: SqlBlock) = head concat r("REAL") + + override fun visitTypeFloat32(node: Type.Float32, head: SqlBlock) = head concat r("FLOAT32") + + override fun visitTypeFloat64(node: Type.Float64, head: SqlBlock) = head concat r("DOUBLE PRECISION") + + override fun visitTypeDecimal(node: Type.Decimal, head: SqlBlock) = + head concat type("DECIMAL", node.precision, node.scale) + + override fun visitTypeNumeric(node: Type.Numeric, head: SqlBlock) = + head concat type("NUMERIC", node.precision, node.scale) + + override fun visitTypeChar(node: Type.Char, head: SqlBlock) = head concat type("CHAR", node.length) + + override fun visitTypeVarchar(node: Type.Varchar, head: SqlBlock) = head concat type("VARCHAR", node.length) + + override fun visitTypeString(node: Type.String, head: SqlBlock) = head concat r("STRING") + + override fun visitTypeSymbol(node: Type.Symbol, head: SqlBlock) = head concat r("SYMBOL") + + override fun visitTypeBit(node: Type.Bit, head: SqlBlock) = head concat type("BIT", node.length) + + override fun visitTypeBitVarying(node: Type.BitVarying, head: SqlBlock) = head concat type("BINARY", node.length) + + override fun visitTypeByteString(node: Type.ByteString, head: SqlBlock) = head concat type("BYTE", node.length) + + override fun visitTypeBlob(node: Type.Blob, head: SqlBlock) = head concat type("BLOB", node.length) + + override fun visitTypeClob(node: Type.Clob, head: SqlBlock) = head concat type("CLOB", node.length) + + override fun visitTypeBag(node: Type.Bag, head: SqlBlock) = head concat r("BAG") + + override fun visitTypeList(node: Type.List, head: SqlBlock) = head concat r("LIST") + + override fun visitTypeSexp(node: Type.Sexp, head: SqlBlock) = head concat r("SEXP") + + override fun visitTypeTuple(node: Type.Tuple, head: SqlBlock) = head concat r("TUPLE") + + override fun visitTypeStruct(node: Type.Struct, head: SqlBlock) = head concat r("STRUCT") + + override fun visitTypeAny(node: Type.Any, head: SqlBlock) = head concat r("ANY") + + override fun visitTypeDate(node: Type.Date, head: SqlBlock) = head concat r("DATE") + + override fun visitTypeTime(node: Type.Time, head: SqlBlock): SqlBlock = head concat type("TIME", node.precision) + + override fun visitTypeTimeWithTz(node: Type.TimeWithTz, head: SqlBlock) = + head concat type("TIME WITH TIMEZONE", node.precision, gap = true) + + override fun visitTypeTimestamp(node: Type.Timestamp, head: SqlBlock) = head concat type("TIMESTAMP", node.precision) + + override fun visitTypeTimestampWithTz(node: Type.TimestampWithTz, head: SqlBlock) = + head concat type("TIMESTAMP WITH TIMEZONE", node.precision, gap = true) + + override fun visitTypeInterval(node: Type.Interval, head: SqlBlock) = head concat type("INTERVAL", node.precision) + + // unsupported + override fun visitTypeCustom(node: Type.Custom, head: SqlBlock) = defaultReturn(node, head) + + // Expressions + + @OptIn(PartiQLValueExperimental::class) + override fun visitExprLit(node: Expr.Lit, head: SqlBlock): SqlBlock { + // Simplified PartiQL Value writing, as this intentionally omits formatting + val value = when (node.value) { + is MissingValue -> "MISSING" // force uppercase + is NullValue -> "NULL" // force uppercase + else -> { + val buffer = ByteArrayOutputStream() + val valueWriter = PartiQLValueTextWriter(PrintStream(buffer), false) + valueWriter.append(node.value) + buffer.toString() + } + } + return head concat r(value) + } + + override fun visitExprIon(node: Expr.Ion, head: SqlBlock): SqlBlock { + // simplified Ion value writing, as this intentionally omits formatting + val value = node.value.toString() + return head concat r("`$value`") + } + + override fun visitExprUnary(node: Expr.Unary, head: SqlBlock): SqlBlock { + val op = when (node.op) { + Expr.Unary.Op.NOT -> "NOT " + Expr.Unary.Op.POS -> "+" + Expr.Unary.Op.NEG -> "-" + } + var h = head + h = h concat r(op) + return visitExpr(node.expr, h) + } + + override fun visitExprBinary(node: Expr.Binary, head: SqlBlock): SqlBlock { + val op = when (node.op) { + Expr.Binary.Op.PLUS -> "+" + Expr.Binary.Op.MINUS -> "-" + Expr.Binary.Op.TIMES -> "*" + Expr.Binary.Op.DIVIDE -> "/" + Expr.Binary.Op.MODULO -> "%" + Expr.Binary.Op.CONCAT -> "||" + Expr.Binary.Op.AND -> "AND" + Expr.Binary.Op.OR -> "OR" + Expr.Binary.Op.EQ -> "=" + Expr.Binary.Op.NE -> "<>" + Expr.Binary.Op.GT -> ">" + Expr.Binary.Op.GTE -> ">=" + Expr.Binary.Op.LT -> "<" + Expr.Binary.Op.LTE -> "<=" + } + var h = head + h = visitExpr(node.lhs, h) + h = h concat r(" $op ") + h = visitExpr(node.rhs, h) + return h + } + + override fun visitExprVar(node: Expr.Var, head: SqlBlock): SqlBlock { + var h = head + // Prepend @ + if (node.scope == Expr.Var.Scope.LOCAL) { + h = h concat r("@") + } + h = visitIdentifier(node.identifier, h) + return h + } + + override fun visitExprSessionAttribute(node: Expr.SessionAttribute, head: SqlBlock) = + head concat r(node.attribute.name) + + override fun visitExprPath(node: Expr.Path, head: SqlBlock): SqlBlock { + var h = visitExpr(node.root, head) + h = node.steps.fold(h) { b, step -> visitExprPathStep(step, b) } + return h + } + + override fun visitExprPathStepSymbol(node: Expr.Path.Step.Symbol, head: SqlBlock) = + head concat r(".${node.symbol.sql()}") + + @OptIn(PartiQLValueExperimental::class) + override fun visitExprPathStepIndex(node: Expr.Path.Step.Index, head: SqlBlock): SqlBlock { + var h = head + val key = node.key + if (key is Expr.Lit && key.value is TextValue<*>) { + // use . syntax + h = h concat r(".") + h = h concat r((key.value as TextValue<*>).string!!) + } else { + // use [ ] syntax + h = h concat r("[") + h = visitExpr(node.key, h) + h = h concat r("]") + } + return h + } + + override fun visitExprPathStepWildcard(node: Expr.Path.Step.Wildcard, head: SqlBlock) = head concat r("[*]") + + override fun visitExprPathStepUnpivot(node: Expr.Path.Step.Unpivot, head: SqlBlock) = head concat r(".*") + + override fun visitExprCall(node: Expr.Call, head: SqlBlock): SqlBlock { + var h = head + h = visitIdentifier(node.function, h) + h = h concat list { node.args } + return h + } + + override fun visitExprAgg(node: Expr.Agg, head: SqlBlock): SqlBlock { + var h = head + val f = node.function + // Special case + if (f is Identifier.Symbol && f.symbol == "COUNT_STAR") { + return h concat r("COUNT(*)") + } + val start = if (node.setq != null) "(${node.setq!!.name} " else "(" + h = h concat visitIdentifier(f, h) + h = h concat list(start) { node.args } + return h + } + + override fun visitExprParameter(node: Expr.Parameter, head: SqlBlock) = head concat r("?") + + override fun visitExprValues(node: Expr.Values, head: SqlBlock) = head concat list("VALUES (") { node.rows } + + override fun visitExprValuesRow(node: Expr.Values.Row, head: SqlBlock) = head concat list { node.items } + + override fun visitExprCollection(node: Expr.Collection, head: SqlBlock): SqlBlock { + val (start, end) = when (node.type) { + Expr.Collection.Type.BAG -> "<<" to ">>" + Expr.Collection.Type.ARRAY -> "[" to "]" + Expr.Collection.Type.VALUES -> "VALUES (" to ")" + Expr.Collection.Type.LIST -> "(" to ")" + Expr.Collection.Type.SEXP -> "SEXP (" to ")" + } + return head concat list(start, end) { node.values } + } + + override fun visitExprStruct(node: Expr.Struct, head: SqlBlock) = head concat list("{", "}") { node.fields } + + override fun visitExprStructField(node: Expr.Struct.Field, head: SqlBlock): SqlBlock { + var h = head + h = visitExpr(node.name, h) + h = h concat r(": ") + h = visitExpr(node.value, h) + return h + } + + override fun visitExprLike(node: Expr.Like, head: SqlBlock): SqlBlock { + var h = head + h = visitExpr(node.value, h) + h = h concat if (node.not == true) r(" NOT LIKE ") else r(" LIKE ") + h = visitExpr(node.pattern, h) + if (node.escape != null) { + h = h concat r(" ESCAPE ") + h = visitExpr(node.escape!!, h) + } + return h + } + + override fun visitExprBetween(node: Expr.Between, head: SqlBlock): SqlBlock { + var h = head + h = visitExpr(node.value, h) + h = h concat if (node.not == true) r(" NOT BETWEEN ") else r(" BETWEEN ") + h = visitExpr(node.from, h) + h = h concat r(" AND ") + h = visitExpr(node.to, h) + return h + } + + override fun visitExprInCollection(node: Expr.InCollection, head: SqlBlock): SqlBlock { + var h = head + h = visitExpr(node.lhs, h) + h = h concat if (node.not == true) r(" NOT IN ") else r(" IN ") + h = visitExpr(node.rhs, h) + return h + } + + override fun visitExprIsType(node: Expr.IsType, head: SqlBlock): SqlBlock { + var h = head + h = visitExpr(node.value, h) + h = h concat if (node.not == true) r(" IS NOT ") else r(" IS ") + h = visitType(node.type, h) + return h + } + + override fun visitExprCase(node: Expr.Case, head: SqlBlock): SqlBlock { + var h = head + h = h concat r("CASE") + h = when (node.expr) { + null -> h + else -> visitExpr(node.expr!!, h concat r(" ")) + } + // WHEN(s) + h = node.branches.fold(h) { acc, branch -> visitExprCaseBranch(branch, acc) } + // ELSE + h = when (node.default) { + null -> h + else -> { + h = h concat r(" ELSE ") + visitExpr(node.default!!, h) + } + } + h = h concat r(" END") + return h + } + + override fun visitExprCaseBranch(node: Expr.Case.Branch, head: SqlBlock): SqlBlock { + var h = head + h = h concat r(" WHEN ") + h = visitExpr(node.condition, h) + h = h concat r(" THEN ") + h = visitExpr(node.expr, h) + return h + } + + override fun visitExprCoalesce(node: Expr.Coalesce, head: SqlBlock): SqlBlock { + var h = head + h = h concat r("COALESCE") + h = h concat list { node.args } + return h + } + + override fun visitExprNullIf(node: Expr.NullIf, head: SqlBlock): SqlBlock { + val args = listOf(node.value, node.nullifier) + var h = head + h = h concat r("NULLIF") + h = h concat list { args } + return h + } + + override fun visitExprSubstring(node: Expr.Substring, head: SqlBlock): SqlBlock { + var h = head + h = h concat r("SUBSTRING(") + h = visitExpr(node.value, h) + if (node.start != null) { + h = h concat r(" FROM ") + h = visitExpr(node.start!!, h) + } + if (node.length != null) { + h = h concat r(" FOR ") + h = visitExpr(node.length!!, h) + } + h = h concat r(")") + return h + } + + override fun visitExprPosition(node: Expr.Position, head: SqlBlock): SqlBlock { + var h = head + h = h concat r("POSITION(") + h = visitExpr(node.lhs, h) + h = h concat r(" IN ") + h = visitExpr(node.rhs, h) + h = h concat r(")") + return h + } + + override fun visitExprTrim(node: Expr.Trim, head: SqlBlock): SqlBlock { + var h = head + h = h concat r("TRIM(") + // [LEADING|TRAILING|BOTH] + if (node.spec != null) { + h = h concat r("${node.spec!!.name} ") + } + // [ FROM] + if (node.chars != null) { + h = visitExpr(node.chars!!, h) + h = h concat r(" FROM ") + } + h = visitExpr(node.value, h) + h = h concat r(")") + return h + } + + override fun visitExprOverlay(node: Expr.Overlay, head: SqlBlock): SqlBlock { + var h = head + h = h concat r("OVERLAY(") + h = visitExpr(node.value, h) + h = h concat r(" PLACING ") + h = visitExpr(node.overlay, h) + h = h concat r(" FROM ") + h = visitExpr(node.start, h) + if (node.length != null) { + h = h concat r(" FOR ") + h = visitExpr(node.length!!, h) + } + h = h concat r(")") + return h + } + + override fun visitExprExtract(node: Expr.Extract, head: SqlBlock): SqlBlock { + var h = head + h = h concat r("EXTRACT(") + h = h concat r(node.field.name) + h = h concat r(" FROM ") + h = visitExpr(node.source, h) + h = h concat r(")") + return h + } + + override fun visitExprCast(node: Expr.Cast, head: SqlBlock): SqlBlock { + var h = head + h = h concat r("CAST(") + h = visitExpr(node.value, h) + h = h concat r(" AS ") + h = visitType(node.asType, h) + h = h concat r(")") + return h + } + + override fun visitExprCanCast(node: Expr.CanCast, head: SqlBlock): SqlBlock { + var h = head + h = h concat r("CAN_CAST(") + h = visitExpr(node.value, h) + h = h concat r(" AS ") + h = visitType(node.asType, h) + h = h concat r(")") + return h + } + + override fun visitExprCanLosslessCast(node: Expr.CanLosslessCast, head: SqlBlock): SqlBlock { + var h = head + h = h concat r("CAN_LOSSLESS_CAST(") + h = visitExpr(node.value, h) + h = h concat r(" AS ") + h = visitType(node.asType, h) + h = h concat r(")") + return h + } + + override fun visitExprDateAdd(node: Expr.DateAdd, head: SqlBlock): SqlBlock { + var h = head + h = h concat r("DATE_ADD(") + h = h concat r(node.field.name) + h = h concat r(", ") + h = visitExpr(node.lhs, h) + h = h concat r(", ") + h = visitExpr(node.rhs, h) + h = h concat r(")") + return h + } + + override fun visitExprDateDiff(node: Expr.DateDiff, head: SqlBlock): SqlBlock { + var h = head + h = h concat r("DATE_DIFF(") + h = h concat r(node.field.name) + h = h concat r(", ") + h = visitExpr(node.lhs, h) + h = h concat r(", ") + h = visitExpr(node.rhs, h) + h = h concat r(")") + return h + } + + override fun visitExprBagOp(node: Expr.BagOp, head: SqlBlock): SqlBlock { + // [OUTER] [UNION|INTERSECT|EXCEPT] [ALL|DISTINCT] + val op = mutableListOf() + when (node.outer) { + true -> op.add("OUTER") + else -> {} + } + when (node.type.type) { + SetOp.Type.UNION -> op.add("UNION") + SetOp.Type.INTERSECT -> op.add("INTERSECT") + SetOp.Type.EXCEPT -> op.add("EXCEPT") + } + when (node.type.setq) { + SetQuantifier.ALL -> op.add("ALL") + SetQuantifier.DISTINCT -> op.add("DISTINCT") + null -> {} + } + var h = head + h = visitExpr(node.lhs, h) + h = h concat r(" ${op.joinToString(" ")} ") + h = visitExpr(node.rhs, h) + return h + } + + // SELECT-FROM-WHERE + + override fun visitExprSFW(node: Expr.SFW, head: SqlBlock): SqlBlock { + var h = head + // SELECT + h = visit(node.select, h) + // FROM + h = visit(node.from, h concat r(" FROM ")) + // LET + h = if (node.let != null) visitLet(node.let!!, h concat r(" ")) else h + // WHERE + h = if (node.where != null) visitExpr(node.where!!, h concat r(" WHERE ")) else h + // GROUP BY + h = if (node.groupBy != null) visitGroupBy(node.groupBy!!, h concat r(" ")) else h + // HAVING + h = if (node.having != null) visitExpr(node.having!!, h concat r(" HAVING ")) else h + // SET OP + h = if (node.setOp != null) visitExprSFWSetOp(node.setOp!!, h concat r(" ")) else h + // ORDER BY + h = if (node.orderBy != null) visitOrderBy(node.orderBy!!, h concat r(" ")) else h + // LIMIT + h = if (node.limit != null) visitExpr(node.limit!!, h concat r(" LIMIT ")) else h + // OFFSET + h = if (node.offset != null) visitExpr(node.offset!!, h concat r(" OFFSET ")) else h + return h + } + + // SELECT + + override fun visitSelectStar(node: Select.Star, head: SqlBlock): SqlBlock { + val select = when (node.setq) { + SetQuantifier.ALL -> "SELECT ALL *" + SetQuantifier.DISTINCT -> "SELECT DISTINCT *" + null -> "SELECT *" + } + return head concat r(select) + } + + override fun visitSelectProject(node: Select.Project, head: SqlBlock): SqlBlock { + val select = when (node.setq) { + SetQuantifier.ALL -> "SELECT ALL " + SetQuantifier.DISTINCT -> "SELECT DISTINCT " + null -> "SELECT " + } + return head concat list(select, "") { node.items } + } + + override fun visitSelectProjectItemAll(node: Select.Project.Item.All, head: SqlBlock): SqlBlock { + var h = head + h = visitExpr(node.expr, h) + h = h concat r(".*") + return h + } + + override fun visitSelectProjectItemExpression(node: Select.Project.Item.Expression, head: SqlBlock): SqlBlock { + var h = head + h = visitExpr(node.expr, h) + h = if (node.asAlias != null) h concat r(" AS ${node.asAlias!!.sql()}") else h + return h + } + + override fun visitSelectPivot(node: Select.Pivot, head: SqlBlock): SqlBlock { + var h = head + h = h concat r("PIVOT ") + h = visitExpr(node.key, h) + h = h concat r(" AT ") + h = visitExpr(node.value, h) + return h + } + + override fun visitSelectValue(node: Select.Value, head: SqlBlock): SqlBlock { + val select = when (node.setq) { + SetQuantifier.ALL -> "SELECT ALL VALUE " + SetQuantifier.DISTINCT -> "SELECT DISTINCT VALUE " + null -> "SELECT VALUE " + } + var h = head + h = h concat r(select) + h = visitExpr(node.constructor, h) + return h + } + + // FROM + + override fun visitFromValue(node: From.Value, head: SqlBlock): SqlBlock { + var h = head + h = when (node.type) { + From.Value.Type.SCAN -> h + From.Value.Type.UNPIVOT -> h concat r("UNPIVOT ") + } + h = visitExpr(node.expr, h) + h = if (node.asAlias != null) h concat r(" AS ${node.asAlias!!.sql()}") else h + h = if (node.atAlias != null) h concat r(" AT ${node.atAlias!!.sql()}") else h + h = if (node.byAlias != null) h concat r(" BY ${node.byAlias!!.sql()}") else h + return h + } + + override fun visitFromJoin(node: From.Join, head: SqlBlock): SqlBlock { + var h = head + h = visitFrom(node.lhs, h) + h = h concat when (node.type) { + From.Join.Type.INNER -> r(" INNER JOIN ") + From.Join.Type.LEFT -> r(" LEFT JOIN ") + From.Join.Type.LEFT_OUTER -> r(" LEFT OUTER JOIN ") + From.Join.Type.RIGHT -> r(" RIGHT JOIN ") + From.Join.Type.RIGHT_OUTER -> r(" RIGHT OUTER JOIN ") + From.Join.Type.FULL -> r(" FULL JOIN ") + From.Join.Type.FULL_OUTER -> r(" FULL OUTER JOIN ") + From.Join.Type.CROSS -> r(" CROSS JOIN ") + From.Join.Type.COMMA -> r(", ") + null -> r(" JOIN ") + } + h = visitFrom(node.rhs, h) + h = if (node.condition != null) visit(node.condition!!, h concat r(" ON ")) else h + return h + } + + // LET + + override fun visitLet(node: Let, head: SqlBlock) = head concat list("LET ", "") { node.bindings } + + override fun visitLetBinding(node: Let.Binding, head: SqlBlock): SqlBlock { + var h = head + h = visitExpr(node.expr, h) + h = h concat r(" AS ${node.asAlias.sql()}") + return h + } + + // GROUP BY + + override fun visitGroupBy(node: GroupBy, head: SqlBlock): SqlBlock { + var h = head + h = h concat when (node.strategy) { + GroupBy.Strategy.FULL -> r("GROUP BY ") + GroupBy.Strategy.PARTIAL -> r("GROUP PARTIAL BY ") + } + h = h concat list("", "") { node.keys } + h = if (node.asAlias != null) h concat r(" GROUP AS ${node.asAlias!!.sql()}") else h + return h + } + + override fun visitGroupByKey(node: GroupBy.Key, head: SqlBlock): SqlBlock { + var h = head + h = visitExpr(node.expr, h) + h = if (node.asAlias != null) h concat r(" AS ${node.asAlias!!.sql()}") else h + return h + } + + // SET OPERATORS + + override fun visitSetOp(node: SetOp, head: SqlBlock): SqlBlock { + val op = when (node.setq) { + null -> node.type.name + else -> "${node.type.name} ${node.setq!!.name}" + } + return head concat r(op) + } + + override fun visitExprSFWSetOp(node: Expr.SFW.SetOp, head: SqlBlock): SqlBlock { + var h = head + h = visitSetOp(node.type, h) + h = h concat r(" ") + h = h concat r("(") + val subquery = visitExprSFW(node.operand, SqlBlock.Nil) + h = h concat SqlBlock.Nest(subquery) + h = h concat r(")") + return h + } + + // ORDER BY + + override fun visitOrderBy(node: OrderBy, head: SqlBlock) = head concat list("ORDER BY ", "") { node.sorts } + + override fun visitSort(node: Sort, head: SqlBlock): SqlBlock { + var h = head + h = visitExpr(node.expr, h) + h = when (node.dir) { + Sort.Dir.ASC -> h concat r(" ASC") + Sort.Dir.DESC -> h concat r(" DESC") + null -> h + } + h = when (node.nulls) { + Sort.Nulls.FIRST -> h concat r(" NULLS FIRST") + Sort.Nulls.LAST -> h concat r(" NULLS LAST") + null -> h + } + return h + } + + // --- Block Constructor Helpers + + private fun type(symbol: String, vararg args: Int?, gap: Boolean = false): SqlBlock { + val p = args.filterNotNull() + val t = when { + p.isEmpty() -> symbol + else -> { + val a = p.joinToString(",") + when (gap) { + true -> "$symbol ($a)" + else -> "$symbol($a)" + } + } + } + // types are modeled as text; as we don't way to reflow + return r(t) + } + + // > infix fun Block.concat(rhs: String): SqlBlock.Link = Block.Link(this, Block.Raw(rhs)) + // > head concat "foo" + private fun r(text: String): SqlBlock = SqlBlock.Text(text) + + private fun list( + start: String? = "(", + end: String? = ")", + delimiter: String? = ", ", + children: () -> List, + ): SqlBlock { + val kids = children() + var h = start?.let { r(it) } ?: SqlBlock.Nil + kids.forEachIndexed { i, child -> + h = child.accept(this, h) + h = if (delimiter != null && (i + 1) < kids.size) h concat r(delimiter) else h + } + h = if (end != null) h concat r(end) else h + return h + } + + private fun Identifier.Symbol.sql() = when (caseSensitivity) { + Identifier.CaseSensitivity.SENSITIVE -> "\"$symbol\"" + Identifier.CaseSensitivity.INSENSITIVE -> symbol // verbatim .. + } +} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlLayout.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlLayout.kt new file mode 100644 index 000000000..4bfd1dae4 --- /dev/null +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlLayout.kt @@ -0,0 +1,96 @@ +package org.partiql.ast.sql + +/** + * [SqlLayout] determines how an [SqlBlock] tree is transformed in SQL text. + */ +public abstract class SqlLayout { + + abstract val indent: Indent + + public open fun format(root: SqlBlock): String { + val ctx = Ctx.empty() + root.accept(Formatter(), ctx) + return ctx.toString() + } + + public companion object { + + /** + * Default SQL format. + */ + public val DEFAULT = object : SqlLayout() { + + override val indent = Indent(2, Indent.Type.SPACE) + } + + /** + * Write SQL statement on one line. + */ + public val ONELINE = object : SqlLayout() { + + override val indent = Indent(2, Indent.Type.SPACE) + + override fun format(root: SqlBlock): String = root.toString().replace("\n", "") + } + } + + /** + * [SqlLayout] indent configuration. + * + * @property count + * @property type + */ + public class Indent( + public val count: Int, + public val type: Type, + ) { + + enum class Type(val char: Char) { + TAB(Char(9)), + SPACE(Char(32)), ; + } + + override fun toString() = type.char.toString().repeat(count) + } + + private class Ctx private constructor(val out: StringBuilder, val level: Int) { + fun nest() = Ctx(out, level + 1) + + override fun toString() = out.toString() + + companion object { + fun empty() = Ctx(StringBuilder(), 0) + } + } + + private inner class Formatter : BlockBaseVisitor() { + + private inline fun write(ctx: Ctx, f: () -> String) { + if (ctx.level > 0) ctx.out.append(lead(ctx)) + ctx.out.append(f()) + } + + override fun defaultReturn(block: SqlBlock, ctx: Ctx) = write(ctx) { + block.toString() + } + + override fun visitNil(block: SqlBlock.Nil, ctx: Ctx) {} + + override fun visitNewline(block: SqlBlock.NL, ctx: Ctx) { + ctx.out.appendLine() + } + + override fun visitText(block: SqlBlock.Text, ctx: Ctx) = write(ctx) { block.text } + + override fun visitNest(block: SqlBlock.Nest, ctx: Ctx) { + block.child.accept(this, ctx.nest()) + } + + override fun visitLink(block: SqlBlock.Link, ctx: Ctx) { + block.lhs.accept(this, ctx) + block.rhs.accept(this, ctx) + } + + private fun lead(ctx: Ctx) = indent.toString().repeat(ctx.level) + } +} diff --git a/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlBlockWriterTest.kt b/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlBlockWriterTest.kt new file mode 100644 index 000000000..523e72b3e --- /dev/null +++ b/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlBlockWriterTest.kt @@ -0,0 +1,77 @@ +package org.partiql.ast.sql + +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource + +class SqlBlockWriterTest { + + @ParameterizedTest(name = "write #{index}") + @MethodSource("onelineCases") + @Execution(ExecutionMode.CONCURRENT) + fun write(case: Case) = case.assert() + + @ParameterizedTest(name = "format #{index}") + @MethodSource("formatCases") + @Execution(ExecutionMode.CONCURRENT) + fun format(case: Case) = case.assert() + + companion object { + + private fun block(): SqlBlock { + + return NIL + "aaa[" + NL + nest { + NIL + "bbbbb[" + NL + nest { + NIL + "ccc," + NL + "dd" + NL + } + "]," + NL + "eee," + NL + "ffff[" + NL + nest { + NIL + "gg," + NL + "hhh," + NL + "ii" + NL + } + "]" + NL + } + "]" + } + + @JvmStatic + fun onelineCases() = listOf( + oneline("aaa[bbbbb[ccc,dd],eee,ffff[gg,hhh,ii]]") { block() }, + ) + + @JvmStatic + fun formatCases() = listOf( + format( + """ + |aaa[ + | bbbbb[ + | ccc, + | dd + | ], + | eee, + | ffff[ + | gg, + | hhh, + | ii + | ] + |] + """.trimMargin() + ) { block() } + ) + + private fun format(expected: String, block: () -> SqlBlock) = Case(block(), expected, SqlLayout.DEFAULT::format) + + private fun oneline(expected: String, block: () -> SqlBlock) = Case(block(), expected, SqlLayout.ONELINE::format) + + private fun r(text: String) = SqlBlock.Text(text) + } + + class Case( + private val input: SqlBlock, + private val expected: String, + private val action: (SqlBlock) -> String, + ) { + + fun assert() { + val actual = action(input) + Assertions.assertEquals(expected, actual) + } + } +} diff --git a/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt b/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt new file mode 100644 index 000000000..7f9b320ac --- /dev/null +++ b/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt @@ -0,0 +1,1656 @@ +package org.partiql.ast.sql + +import com.amazon.ion.Decimal +import com.amazon.ionelement.api.ionBool +import com.amazon.ionelement.api.ionDecimal +import com.amazon.ionelement.api.ionFloat +import com.amazon.ionelement.api.ionInt +import com.amazon.ionelement.api.ionNull +import com.amazon.ionelement.api.ionString +import com.amazon.ionelement.api.ionSymbol +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.partiql.ast.Ast +import org.partiql.ast.AstNode +import org.partiql.ast.DatetimeField +import org.partiql.ast.Expr +import org.partiql.ast.From +import org.partiql.ast.GroupBy +import org.partiql.ast.Identifier +import org.partiql.ast.SetOp +import org.partiql.ast.SetQuantifier +import org.partiql.ast.Sort +import org.partiql.ast.builder.AstBuilder +import org.partiql.ast.builder.AstFactory +import org.partiql.ast.builder.ast +import org.partiql.ast.sql +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.boolValue +import org.partiql.value.decimalValue +import org.partiql.value.float32Value +import org.partiql.value.float64Value +import org.partiql.value.int16Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.int8Value +import org.partiql.value.intValue +import org.partiql.value.missingValue +import org.partiql.value.nullValue +import org.partiql.value.stringValue +import org.partiql.value.symbolValue +import java.math.BigDecimal +import java.math.BigInteger +import kotlin.test.assertFails + +/** + * This tests the Ast to test via the base SqlDialect. + * + * It does NOT test formatted output. + */ +@OptIn(PartiQLValueExperimental::class) +class SqlDialectTest { + + // Identifiers & Paths + + @ParameterizedTest(name = "identifiers #{index}") + @MethodSource("identifiers") + @Execution(ExecutionMode.CONCURRENT) + fun testIdentifiers(case: Case) = case.assert() + + @ParameterizedTest(name = "paths #{index}") + @MethodSource("paths") + @Execution(ExecutionMode.CONCURRENT) + fun testPaths(case: Case) = case.assert() + + // Types + + @ParameterizedTest(name = "types #{index}") + @MethodSource("types") + @Execution(ExecutionMode.CONCURRENT) + fun testTypes(case: Case) = case.assert() + + // Expressions + + @ParameterizedTest(name = "expr.lit #{index}") + @MethodSource("exprLitCases") + @Execution(ExecutionMode.CONCURRENT) + fun testExprLit(case: Case) = case.assert() + + @ParameterizedTest(name = "expr.ion #{index}") + @MethodSource("exprIonCases") + @Execution(ExecutionMode.CONCURRENT) + fun testExprIon(case: Case) = case.assert() + + @ParameterizedTest(name = "expr.var #{index}") + @MethodSource("exprVarCases") + @Execution(ExecutionMode.CONCURRENT) + fun testExprVar(case: Case) = case.assert() + + @ParameterizedTest(name = "expr.path #{index}") + @MethodSource("exprPathCases") + @Execution(ExecutionMode.CONCURRENT) + fun testExprPath(case: Case) = case.assert() + + @ParameterizedTest() + @MethodSource("exprOperators") + @Execution(ExecutionMode.CONCURRENT) + fun testExprOperators(case: Case) = case.assert() + + @ParameterizedTest(name = "expr.call #{index}") + @MethodSource("exprCallCases") + @Execution(ExecutionMode.CONCURRENT) + fun testExprCall(case: Case) = case.assert() + + @ParameterizedTest(name = "expr.agg #{index}") + @MethodSource("exprAggCases") + @Execution(ExecutionMode.CONCURRENT) + fun testExprAgg(case: Case) = case.assert() + + @ParameterizedTest(name = "expr.collection #{index}") + @MethodSource("exprCollectionCases") + @Execution(ExecutionMode.CONCURRENT) + fun testExprCollection(case: Case) = case.assert() + + @ParameterizedTest(name = "expr.struct #{index}") + @MethodSource("exprStructCases") + @Execution(ExecutionMode.CONCURRENT) + fun testExprStruct(case: Case) = case.assert() + + @ParameterizedTest(name = "special form #{index}") + @MethodSource("exprSpecialFormCases") + @Execution(ExecutionMode.CONCURRENT) + fun testExprSpecialForm(case: Case) = case.assert() + + @ParameterizedTest(name = "expr.case #{index}") + @MethodSource("exprCaseCases") + @Execution(ExecutionMode.CONCURRENT) + fun testExprCase(case: Case) = case.assert() + + // SELECT-FROM-WHERE + + @ParameterizedTest(name = "SELECT Clause #{index}") + @MethodSource("selectClauseCases") + @Execution(ExecutionMode.CONCURRENT) + fun testSelectClause(case: Case) = case.assert() + + @ParameterizedTest(name = "FROM Clause #{index}") + @MethodSource("fromClauseCases") + @Execution(ExecutionMode.CONCURRENT) + fun testFromClause(case: Case) = case.assert() + + @ParameterizedTest(name = "JOIN Clause #{index}") + @MethodSource("joinClauseCases") + @Execution(ExecutionMode.CONCURRENT) + fun testJoinClause(case: Case) = case.assert() + + @ParameterizedTest(name = "GROUP BY Clause #{index}") + @MethodSource("groupByClauseCases") + @Execution(ExecutionMode.CONCURRENT) + fun testGroupByClause(case: Case) = case.assert() + + @ParameterizedTest(name = "UNION Clause #{index}") + @MethodSource("unionClauseCases") + @Execution(ExecutionMode.CONCURRENT) + fun testUnionClause(case: Case) = case.assert() + + @ParameterizedTest(name = "ORDER BY Clause #{index}") + @MethodSource("orderByClauseCases") + @Execution(ExecutionMode.CONCURRENT) + fun testOrderByClause(case: Case) = case.assert() + + @ParameterizedTest(name = "other clauses #{index}") + @MethodSource("otherClausesCases") + @Execution(ExecutionMode.CONCURRENT) + fun testOtherClauses(case: Case) = case.assert() + + companion object { + + private val NULL = Ast.exprLit(nullValue()) + + @JvmStatic + fun types() = listOf( + // SQL + expect("NULL") { typeNullType() }, + expect("BOOL") { typeBool() }, + expect("SMALLINT") { typeSmallint() }, + expect("INT") { typeInt() }, + expect("REAL") { typeReal() }, + expect("FLOAT32") { typeFloat32() }, + expect("DOUBLE PRECISION") { typeFloat64() }, + expect("DECIMAL") { typeDecimal() }, + expect("DECIMAL(2)") { typeDecimal(2) }, + expect("DECIMAL(2,1)") { typeDecimal(2, 1) }, + expect("NUMERIC") { typeNumeric() }, + expect("NUMERIC(2)") { typeNumeric(2) }, + expect("NUMERIC(2,1)") { typeNumeric(2, 1) }, + expect("TIMESTAMP") { typeTimestamp() }, + expect("CHAR") { typeChar() }, + expect("CHAR(1)") { typeChar(1) }, + expect("VARCHAR") { typeVarchar() }, + expect("VARCHAR(1)") { typeVarchar(1) }, + expect("BLOB") { typeBlob() }, + expect("CLOB") { typeClob() }, + expect("DATE") { typeDate() }, + expect("TIME") { typeTime() }, + expect("TIME(1)") { typeTime(1) }, + expect("TIME WITH TIMEZONE") { typeTimeWithTz() }, + expect("TIME WITH TIMEZONE (1)") { typeTimeWithTz(1) }, + // TODO TIMESTAMP + // TODO INTERVAL + // PartiQL + expect("MISSING") { typeMissing() }, + expect("STRING") { typeString() }, + expect("SYMBOL") { typeSymbol() }, + expect("STRUCT") { typeStruct() }, + expect("TUPLE") { typeTuple() }, + expect("LIST") { typeList() }, + expect("SEXP") { typeSexp() }, + expect("BAG") { typeBag() }, + expect("ANY") { typeAny() }, + // Other (??) + expect("INT4") { typeInt4() }, + expect("INT8") { typeInt8() }, + // + fail("PartiQLDialect does not support custom types") { typeCustom("foo") }, + ) + + @JvmStatic + fun exprOperators() = listOf( + expect("NOT NULL") { + exprUnary { + op = Expr.Unary.Op.NOT + expr = NULL + } + }, + expect("+NULL") { + exprUnary { + op = Expr.Unary.Op.POS + expr = NULL + } + }, + expect("-NULL") { + exprUnary { + op = Expr.Unary.Op.NEG + expr = NULL + } + }, + expect("NULL + NULL") { + exprBinary { + op = Expr.Binary.Op.PLUS + lhs = NULL + rhs = NULL + } + }, + ) + + @JvmStatic + fun identifiers() = listOf( + expect("x") { + id("x") + }, + expect("X") { + id("X") + }, + expect("\"x\"") { + id("x", Identifier.CaseSensitivity.SENSITIVE) + }, + expect("x.y.z") { + identifierQualified { + root = id("x") + steps += id("y") + steps += id("z") + } + }, + expect("x.\"y\".z") { + identifierQualified { + root = id("x") + steps += id("y", Identifier.CaseSensitivity.SENSITIVE) + steps += id("z") + } + }, + expect("\"x\".\"y\".\"z\"") { + identifierQualified { + root = id("x", Identifier.CaseSensitivity.SENSITIVE) + steps += id("y", Identifier.CaseSensitivity.SENSITIVE) + steps += id("z", Identifier.CaseSensitivity.SENSITIVE) + } + }, + ) + + @JvmStatic + fun paths() = listOf( + expect("x.y.z") { + path { + root = id("x") + steps += pathStepSymbol(id("y")) + steps += pathStepSymbol(id("z")) + } + }, + expect("x.y[0]") { + path { + root = id("x") + steps += pathStepSymbol(id("y")) + steps += pathStepIndex(0) + } + }, + expect("x[0].y") { + path { + root = id("x") + steps += pathStepIndex(0) + steps += pathStepSymbol(id("y")) + } + }, + expect("\"x\".\"y\".\"z\"") { + path { + root = id("x", Identifier.CaseSensitivity.SENSITIVE) + steps += pathStepSymbol(id("y", Identifier.CaseSensitivity.SENSITIVE)) + steps += pathStepSymbol(id("z", Identifier.CaseSensitivity.SENSITIVE)) + } + }, + ) + + // Expressions + + @JvmStatic + fun exprLitCases() = listOf( + expect("NULL") { + exprLit(nullValue()) + }, + expect("MISSING") { + exprLit(missingValue()) + }, + expect("true") { + exprLit(boolValue(true)) + }, + expect("1") { + exprLit(int8Value(1)) + }, + expect("2") { + exprLit(int16Value(2)) + }, + expect("3") { + exprLit(int32Value(3)) + }, + expect("4") { + exprLit(int64Value(4)) + }, + expect("5") { + exprLit(intValue(BigInteger.valueOf(5))) + }, + // TODO fix PartiQL Text writer for floats + // expect("1.1e0") { + expect("1.1") { + exprLit(float32Value(1.1f)) + }, + // TODO fix PartiQL Text writer for floats + // expect("1.2e0") { + expect("1.2") { + exprLit(float64Value(1.2)) + }, + expect("1.3") { + exprLit(decimalValue(BigDecimal.valueOf(1.3))) + }, + expect("""'hello'""") { + exprLit(stringValue("hello")) + }, + expect("""hello""") { + exprLit(symbolValue("hello")) + }, + // expect("""{{ '''Hello''' '''World''' }}""") { + // exprLit(clobValue("HelloWorld".toByteArray())) + // }, + // expect("""{{ VG8gaW5maW5pdHkuLi4gYW5kIGJleW9uZCE= }}""") { + // exprLit(blobValue("To infinity... and beyond!".toByteArray())) + // }, + ) + + @JvmStatic + fun exprIonCases() = listOf( + expect("`null`") { + exprIon(ionNull()) + }, + expect("`true`") { + exprIon(ionBool(true)) + }, + expect("`1`") { + exprIon(ionInt(1)) + }, + expect("`1.2e0`") { + exprIon(ionFloat(1.2)) + }, + expect("`1.3`") { + exprIon(ionDecimal(Decimal.valueOf(1.3))) + }, + expect("""`"hello"`""") { + exprIon(ionString("hello")) + }, + expect("""`hello`""") { + exprIon(ionSymbol("hello")) + }, + ) + + @JvmStatic + fun exprVarCases() = listOf( + // DEFAULT + expect("x") { + val id = id("x") + exprVar(id, Expr.Var.Scope.DEFAULT) + }, + expect("\"x\"") { + val id = id("x", Identifier.CaseSensitivity.SENSITIVE) + exprVar(id, Expr.Var.Scope.DEFAULT) + }, + expect("x.y.z") { + val id = identifierQualified { + root = id("x") + steps += id("y") + steps += id("z") + } + exprVar(id, Expr.Var.Scope.DEFAULT) + }, + expect("x.\"y\".z") { + val id = identifierQualified { + root = id("x") + steps += id("y", Identifier.CaseSensitivity.SENSITIVE) + steps += id("z") + } + exprVar(id, Expr.Var.Scope.DEFAULT) + }, + expect("\"x\".\"y\".\"z\"") { + val id = identifierQualified { + root = id("x", Identifier.CaseSensitivity.SENSITIVE) + steps += id("y", Identifier.CaseSensitivity.SENSITIVE) + steps += id("z", Identifier.CaseSensitivity.SENSITIVE) + } + exprVar(id, Expr.Var.Scope.DEFAULT) + }, + // LOCAL + expect("@x") { + val id = id("x") + exprVar(id, Expr.Var.Scope.LOCAL) + }, + expect("@\"x\"") { + val id = id("x", Identifier.CaseSensitivity.SENSITIVE) + exprVar(id, Expr.Var.Scope.LOCAL) + }, + expect("@x.y.z") { + val id = identifierQualified { + root = id("x") + steps += id("y") + steps += id("z") + } + exprVar(id, Expr.Var.Scope.LOCAL) + }, + expect("@x.\"y\".z") { + val id = identifierQualified { + root = id("x") + steps += id("y", Identifier.CaseSensitivity.SENSITIVE) + steps += id("z") + } + exprVar(id, Expr.Var.Scope.LOCAL) + }, + expect("@\"x\".\"y\".\"z\"") { + val id = identifierQualified { + root = id("x", Identifier.CaseSensitivity.SENSITIVE) + steps += id("y", Identifier.CaseSensitivity.SENSITIVE) + steps += id("z", Identifier.CaseSensitivity.SENSITIVE) + } + exprVar(id, Expr.Var.Scope.LOCAL) + }, + ) + + @JvmStatic + fun exprPathCases() = listOf( + expect("x.y.*") { + exprPath { + root = exprVar { + identifier = id("x") + scope = Expr.Var.Scope.DEFAULT + } + steps += exprPathStepSymbol(id("y")) + steps += exprPathStepUnpivot() + } + }, + expect("x.y[*]") { + exprPath { + root = exprVar { + identifier = id("x") + scope = Expr.Var.Scope.DEFAULT + } + steps += exprPathStepSymbol(id("y")) + steps += exprPathStepWildcard() + } + }, + expect("x[1 + a]") { + exprPath { + root = exprVar { + identifier = id("x") + scope = Expr.Var.Scope.DEFAULT + } + steps += exprPathStepIndex( + exprBinary { + op = Expr.Binary.Op.PLUS + lhs = exprLit(int32Value(1)) + rhs = exprVar { + identifier = id("a") + scope = Expr.Var.Scope.DEFAULT + } + } + ) + } + }, + ) + + @JvmStatic + fun exprCallCases() = listOf( + expect("foo(1)") { + exprCall { + function = id("foo") + args += exprLit(int32Value(1)) + } + }, + expect("foo(1, 2)") { + exprCall { + function = id("foo") + args += exprLit(int32Value(1)) + args += exprLit(int32Value(2)) + } + }, + expect("foo.bar(1)") { + exprCall { + function = identifierQualified { + root = id("foo") + steps += id("bar") + } + args += exprLit(int32Value(1)) + } + }, + expect("foo.bar(1, 2)") { + exprCall { + function = identifierQualified { + root = id("foo") + steps += id("bar") + } + args += exprLit(int32Value(1)) + args += exprLit(int32Value(2)) + } + }, + ) + + @JvmStatic + fun exprAggCases() = listOf( + expect("FOO(x)") { + exprAgg { + function = id("FOO") + args += exprVar(id("x"), Expr.Var.Scope.DEFAULT) + } + }, + expect("FOO(ALL x)") { + exprAgg { + function = id("FOO") + setq = SetQuantifier.ALL + args += exprVar(id("x"), Expr.Var.Scope.DEFAULT) + } + }, + expect("FOO(DISTINCT x)") { + exprAgg { + function = id("FOO") + setq = SetQuantifier.DISTINCT + args += exprVar(id("x"), Expr.Var.Scope.DEFAULT) + } + }, + expect("FOO(x, y)") { + exprAgg { + function = id("FOO") + args += exprVar(id("x"), Expr.Var.Scope.DEFAULT) + args += exprVar(id("y"), Expr.Var.Scope.DEFAULT) + } + }, + expect("FOO(ALL x, y)") { + exprAgg { + function = id("FOO") + setq = SetQuantifier.ALL + args += exprVar(id("x"), Expr.Var.Scope.DEFAULT) + args += exprVar(id("y"), Expr.Var.Scope.DEFAULT) + } + }, + expect("FOO(DISTINCT x, y)") { + exprAgg { + function = id("FOO") + setq = SetQuantifier.DISTINCT + args += exprVar(id("x"), Expr.Var.Scope.DEFAULT) + args += exprVar(id("y"), Expr.Var.Scope.DEFAULT) + } + }, + expect("COUNT(*)") { + exprAgg { + function = id("COUNT_STAR") + } + } + ) + + @JvmStatic + fun exprCollectionCases() = listOf( + expect("<<>>") { + exprCollection { + type = Expr.Collection.Type.BAG + } + }, + expect("<<1, 2, 3>>") { + exprCollection { + type = Expr.Collection.Type.BAG + values += exprLit(int32Value(1)) + values += exprLit(int32Value(2)) + values += exprLit(int32Value(3)) + } + }, + expect("[]") { + exprCollection { + type = Expr.Collection.Type.ARRAY + } + }, + expect("[1, 2, 3]") { + exprCollection { + type = Expr.Collection.Type.ARRAY + values += exprLit(int32Value(1)) + values += exprLit(int32Value(2)) + values += exprLit(int32Value(3)) + } + }, + expect("VALUES ()") { + exprCollection { + type = Expr.Collection.Type.VALUES + } + }, + expect("VALUES (1, 2, 3)") { + exprCollection { + type = Expr.Collection.Type.VALUES + values += exprLit(int32Value(1)) + values += exprLit(int32Value(2)) + values += exprLit(int32Value(3)) + } + }, + expect("()") { + exprCollection { + type = Expr.Collection.Type.LIST + } + }, + expect("(1, 2, 3)") { + exprCollection { + type = Expr.Collection.Type.LIST + values += exprLit(int32Value(1)) + values += exprLit(int32Value(2)) + values += exprLit(int32Value(3)) + } + }, + expect("SEXP ()") { + exprCollection { + type = Expr.Collection.Type.SEXP + } + }, + expect("SEXP (1, 2, 3)") { + exprCollection { + type = Expr.Collection.Type.SEXP + values += exprLit(int32Value(1)) + values += exprLit(int32Value(2)) + values += exprLit(int32Value(3)) + } + }, + ) + + @JvmStatic + fun exprStructCases() = listOf( + expect("{}") { + exprStruct() + }, + expect("{a: 1}") { + exprStruct { + fields += exprStructField { + name = exprLit(symbolValue("a")) + value = exprLit(int32Value(1)) + } + } + }, + expect("{a: 1, b: false}") { + exprStruct { + fields += exprStructField { + name = exprLit(symbolValue("a")) + value = exprLit(int32Value(1)) + } + fields += exprStructField { + name = exprLit(symbolValue("b")) + value = exprLit(boolValue(false)) + } + } + }, + ) + + @JvmStatic + fun exprSpecialFormCases() = listOf( + expect("x LIKE y") { + exprLike { + value = v("x") + pattern = v("y") + } + }, + expect("x NOT LIKE y") { + exprLike { + value = v("x") + pattern = v("y") + not = true + } + }, + expect("x LIKE y ESCAPE z") { + exprLike { + value = v("x") + pattern = v("y") + escape = v("z") + } + }, + expect("x BETWEEN y AND z") { + exprBetween { + value = v("x") + from = v("y") + to = v("z") + } + }, + expect("x NOT BETWEEN y AND z") { + exprBetween { + value = v("x") + from = v("y") + to = v("z") + not = true + } + }, + expect("x IN y") { + exprInCollection { + lhs = v("x") + rhs = v("y") + } + }, + expect("x NOT IN y") { + exprInCollection { + lhs = v("x") + rhs = v("y") + not = true + } + }, + expect("x IS BOOL") { + exprIsType { + value = v("x") + type = typeBool() + } + }, + expect("x IS NOT BOOL") { + exprIsType { + value = v("x") + type = typeBool() + not = true + } + }, + expect("NULLIF(x, y)") { + exprNullIf { + value = v("x") + nullifier = v("y") + } + }, + expect("COALESCE(x, y, z)") { + exprCoalesce { + args += v("x") + args += v("y") + args += v("z") + } + }, + expect("SUBSTRING(x)") { + exprSubstring { + value = v("x") + } + }, + expect("SUBSTRING(x FROM i)") { + exprSubstring { + value = v("x") + start = v("i") + } + }, + expect("SUBSTRING(x FROM i FOR n)") { + exprSubstring { + value = v("x") + start = v("i") + length = v("n") + } + }, + expect("SUBSTRING(x FOR n)") { + exprSubstring { + value = v("x") + length = v("n") + } + }, + expect("POSITION(x IN y)") { + exprPosition { + lhs = v("x") + rhs = v("y") + } + }, + expect("TRIM(x)") { + exprTrim { + value = v("x") + } + }, + expect("TRIM(BOTH x)") { + exprTrim { + value = v("x") + spec = Expr.Trim.Spec.BOTH + } + }, + expect("TRIM(LEADING y FROM x)") { + exprTrim { + value = v("x") + spec = Expr.Trim.Spec.LEADING + chars = v("y") + } + }, + expect("TRIM(y FROM x)") { + exprTrim { + value = v("x") + chars = v("y") + } + }, + expect("OVERLAY(x PLACING y FROM z)") { + exprOverlay { + value = v("x") + overlay = v("y") + start = v("z") + } + }, + expect("OVERLAY(x PLACING y FROM z FOR n)") { + exprOverlay { + value = v("x") + overlay = v("y") + start = v("z") + length = v("n") + } + }, + expect("EXTRACT(MINUTE FROM x)") { + exprExtract { + field = DatetimeField.MINUTE + source = v("x") + } + }, + expect("CAST(x AS INT)") { + exprCast { + value = v("x") + asType = typeInt() + } + }, + expect("CAN_CAST(x AS INT)") { + exprCanCast { + value = v("x") + asType = typeInt() + } + }, + expect("CAN_LOSSLESS_CAST(x AS INT)") { + exprCanLosslessCast { + value = v("x") + asType = typeInt() + } + }, + expect("DATE_ADD(MINUTE, x, y)") { + exprDateAdd { + field = DatetimeField.MINUTE + lhs = v("x") + rhs = v("y") + } + }, + expect("DATE_DIFF(MINUTE, x, y)") { + exprDateDiff { + field = DatetimeField.MINUTE + lhs = v("x") + rhs = v("y") + } + }, + expect("x UNION y") { + exprBagOp { + type = setOp { + type = SetOp.Type.UNION + setq = null + } + outer = false + lhs = v("x") + rhs = v("y") + } + }, + expect("x UNION ALL y") { + exprBagOp { + type = setOp { + type = SetOp.Type.UNION + setq = SetQuantifier.ALL + } + outer = false + lhs = v("x") + rhs = v("y") + } + }, + expect("x OUTER UNION y") { + exprBagOp { + type = setOp { + type = SetOp.Type.UNION + setq = null + } + outer = true + lhs = v("x") + rhs = v("y") + } + }, + expect("x OUTER UNION ALL y") { + exprBagOp { + type = setOp { + type = SetOp.Type.UNION + setq = SetQuantifier.ALL + } + outer = true + lhs = v("x") + rhs = v("y") + } + }, + ) + + @JvmStatic + fun exprCaseCases() = listOf( + expect("CASE WHEN a THEN x WHEN b THEN y END") { + exprCase { + branches += exprCaseBranch(v("a"), v("x")) + branches += exprCaseBranch(v("b"), v("y")) + } + }, + expect("CASE z WHEN a THEN x WHEN b THEN y END") { + exprCase { + expr = v("z") + branches += exprCaseBranch(v("a"), v("x")) + branches += exprCaseBranch(v("b"), v("y")) + } + }, + expect("CASE z WHEN a THEN x ELSE y END") { + exprCase { + expr = v("z") + branches += exprCaseBranch(v("a"), v("x")) + default = v("y") + } + }, + ) + + @JvmStatic + fun selectClauseCases() = listOf( + expect("SELECT a FROM T") { + exprSFW { + select = selectProject { + items += selectProjectItemExpression(v("a")) + } + from = table("T") + } + }, + expect("SELECT a AS x FROM T") { + exprSFW { + select = selectProject { + items += selectProjectItemExpression(v("a"), id("x")) + } + from = table("T") + } + }, + expect("SELECT a AS x, b AS y FROM T") { + exprSFW { + select = selectProject { + items += selectProjectItemExpression(v("a"), id("x")) + items += selectProjectItemExpression(v("b"), id("y")) + } + from = table("T") + } + }, + expect("SELECT ALL a FROM T") { + exprSFW { + select = selectProject { + setq = SetQuantifier.ALL + items += selectProjectItemExpression(v("a")) + } + from = table("T") + } + }, + expect("SELECT DISTINCT a FROM T") { + exprSFW { + select = selectProject { + setq = SetQuantifier.DISTINCT + items += selectProjectItemExpression(v("a")) + } + from = table("T") + } + }, + expect("SELECT a.* FROM T") { + exprSFW { + select = selectProject { + items += selectProjectItemAll(v("a")) + } + from = table("T") + } + }, + expect("SELECT * FROM T") { + exprSFW { + select = selectStar() + from = table("T") + } + }, + expect("SELECT DISTINCT * FROM T") { + exprSFW { + select = selectStar(SetQuantifier.DISTINCT) + from = table("T") + } + }, + expect("SELECT ALL * FROM T") { + exprSFW { + select = selectStar(SetQuantifier.ALL) + from = table("T") + } + }, + expect("SELECT VALUE a FROM T") { + exprSFW { + select = selectValue { + constructor = v("a") + } + from = table("T") + } + }, + expect("SELECT ALL VALUE a FROM T") { + exprSFW { + select = selectValue { + setq = SetQuantifier.ALL + constructor = v("a") + } + from = table("T") + } + }, + expect("SELECT DISTINCT VALUE a FROM T") { + exprSFW { + select = selectValue { + setq = SetQuantifier.DISTINCT + constructor = v("a") + } + from = table("T") + } + }, + expect("PIVOT a AT b FROM T") { + exprSFW { + select = selectPivot(v("a"), v("b")) + from = table("T") + } + }, + ) + + @JvmStatic + fun fromClauseCases() = listOf( + expect("SELECT a FROM T") { + exprSFW { + select = select("a") + from = fromValue { + expr = v("T") + type = From.Value.Type.SCAN + } + } + }, + expect("SELECT a FROM T AS x") { + exprSFW { + select = select("a") + from = fromValue { + expr = v("T") + type = From.Value.Type.SCAN + asAlias = id("x") + } + } + }, + expect("SELECT a FROM T AS x AT y") { + exprSFW { + select = select("a") + from = fromValue { + expr = v("T") + type = From.Value.Type.SCAN + asAlias = id("x") + atAlias = id("y") + } + } + }, + expect("SELECT a FROM T AS x AT y BY z") { + exprSFW { + select = select("a") + from = fromValue { + expr = v("T") + type = From.Value.Type.SCAN + asAlias = id("x") + atAlias = id("y") + byAlias = id("z") + } + } + }, + expect("SELECT a FROM UNPIVOT T") { + exprSFW { + select = select("a") + from = fromValue { + expr = v("T") + type = From.Value.Type.UNPIVOT + } + } + }, + expect("SELECT a FROM UNPIVOT T AS x") { + exprSFW { + select = select("a") + from = fromValue { + expr = v("T") + type = From.Value.Type.UNPIVOT + asAlias = id("x") + } + } + }, + expect("SELECT a FROM UNPIVOT T AS x AT y") { + exprSFW { + select = select("a") + from = fromValue { + expr = v("T") + type = From.Value.Type.UNPIVOT + asAlias = id("x") + atAlias = id("y") + } + } + }, + expect("SELECT a FROM UNPIVOT T AS x AT y BY z") { + exprSFW { + select = select("a") + from = fromValue { + expr = v("T") + type = From.Value.Type.UNPIVOT + asAlias = id("x") + atAlias = id("y") + byAlias = id("z") + } + } + }, + ) + + @JvmStatic + fun joinClauseCases() = listOf( + expect("SELECT a FROM T JOIN S") { + exprSFW { + select = select("a") + from = fromJoin { + lhs = table("T") + rhs = table("S") + } + } + }, + expect("SELECT a FROM T INNER JOIN S") { + exprSFW { + select = select("a") + from = fromJoin { + type = From.Join.Type.INNER + lhs = table("T") + rhs = table("S") + } + } + }, + // expect("SELECT a FROM T, S") { + // exprSFW { + // select = select("a") + // from = fromJoin { + // type = From.Join.Type.FULL + // lhs = table("T") + // rhs = table("S") + // } + // } + // }, + // expect("SELECT a FROM T CROSS JOIN S") { + // exprSFW { + // select = select("a") + // from = fromJoin { + // type = From.Join.Type.FULL + // lhs = table("T") + // rhs = table("S") + // } + // } + // }, + expect("SELECT a FROM T JOIN S ON NULL") { + exprSFW { + select = select("a") + from = fromJoin { + lhs = table("T") + rhs = table("S") + condition = NULL + } + } + }, + expect("SELECT a FROM T INNER JOIN S ON NULL") { + exprSFW { + select = select("a") + from = fromJoin { + type = From.Join.Type.INNER + lhs = table("T") + rhs = table("S") + condition = NULL + } + } + }, + ) + + // These are simple clauses + @JvmStatic + private fun otherClausesCases() = listOf( + expect("SELECT a FROM T LET x AS i") { + exprSFW { + select = select("a") + from = table("T") + let = let(mutableListOf()) { + bindings += letBinding(v("x"), id("i")) + } + } + }, + expect("SELECT a FROM T LET x AS i, y AS j") { + exprSFW { + select = select("a") + from = table("T") + let = let(mutableListOf()) { + bindings += letBinding(v("x"), id("i")) + bindings += letBinding(v("y"), id("j")) + } + } + }, + expect("SELECT a FROM T WHERE x") { + exprSFW { + select = select("a") + from = table("T") + where = v("x") + } + }, + expect("SELECT a FROM T LIMIT 1") { + exprSFW { + select = select("a") + from = table("T") + limit = exprLit(int32Value(1)) + } + }, + expect("SELECT a FROM T OFFSET 2") { + exprSFW { + select = select("a") + from = table("T") + offset = exprLit(int32Value(2)) + } + }, + expect("SELECT a FROM T LIMIT 1 OFFSET 2") { + exprSFW { + select = select("a") + from = table("T") + limit = exprLit(int32Value(1)) + offset = exprLit(int32Value(2)) + } + }, + expect("SELECT a FROM T GROUP BY x HAVING y") { + exprSFW { + select = select("a") + from = table("T") + groupBy = groupBy { + strategy = GroupBy.Strategy.FULL + keys += groupByKey(v("x")) + } + having = v("y") + } + }, + ) + + @JvmStatic + private fun groupByClauseCases() = listOf( + expect("SELECT a FROM T GROUP BY x") { + exprSFW { + select = select("a") + from = table("T") + groupBy = groupBy { + strategy = GroupBy.Strategy.FULL + keys += groupByKey(v("x")) + } + } + }, + expect("SELECT a FROM T GROUP BY x AS i") { + exprSFW { + select = select("a") + from = table("T") + groupBy = groupBy { + strategy = GroupBy.Strategy.FULL + keys += groupByKey(v("x"), id("i")) + } + } + }, + expect("SELECT a FROM T GROUP BY x, y") { + exprSFW { + select = select("a") + from = table("T") + groupBy = groupBy { + strategy = GroupBy.Strategy.FULL + keys += groupByKey(v("x")) + keys += groupByKey(v("y")) + } + } + }, + expect("SELECT a FROM T GROUP BY x AS i, y AS j") { + exprSFW { + select = select("a") + from = table("T") + groupBy = groupBy { + strategy = GroupBy.Strategy.FULL + keys += groupByKey(v("x"), id("i")) + keys += groupByKey(v("y"), id("j")) + } + } + }, + expect("SELECT a FROM T GROUP BY x GROUP AS g") { + exprSFW { + select = select("a") + from = table("T") + groupBy = groupBy { + strategy = GroupBy.Strategy.FULL + keys += groupByKey(v("x")) + asAlias = id("g") + } + } + }, + expect("SELECT a FROM T GROUP BY x AS i GROUP AS g") { + exprSFW { + select = select("a") + from = table("T") + groupBy = groupBy { + strategy = GroupBy.Strategy.FULL + keys += groupByKey(v("x"), id("i")) + asAlias = id("g") + } + } + }, + expect("SELECT a FROM T GROUP BY x, y GROUP AS g") { + exprSFW { + select = select("a") + from = table("T") + groupBy = groupBy { + strategy = GroupBy.Strategy.FULL + keys += groupByKey(v("x")) + keys += groupByKey(v("y")) + asAlias = id("g") + } + } + }, + expect("SELECT a FROM T GROUP BY x AS i, y AS j GROUP AS g") { + exprSFW { + select = select("a") + from = table("T") + groupBy = groupBy { + strategy = GroupBy.Strategy.FULL + keys += groupByKey(v("x"), id("i")) + keys += groupByKey(v("y"), id("j")) + asAlias = id("g") + } + } + }, + expect("SELECT a FROM T GROUP PARTIAL BY x") { + exprSFW { + select = select("a") + from = table("T") + groupBy = groupBy { + strategy = GroupBy.Strategy.PARTIAL + keys += groupByKey(v("x")) + } + } + }, + ) + + @JvmStatic + private fun orderByClauseCases() = listOf( + expect("SELECT a FROM T ORDER BY x") { + exprSFW { + select = select("a") + from = table("T") + orderBy = orderBy { + sorts += sort(v("x"), null, null) + } + } + }, + expect("SELECT a FROM T ORDER BY x ASC") { + exprSFW { + select = select("a") + from = table("T") + orderBy = orderBy { + sorts += sort(v("x"), Sort.Dir.ASC, null) + } + } + }, + expect("SELECT a FROM T ORDER BY x DESC") { + exprSFW { + select = select("a") + from = table("T") + orderBy = orderBy { + sorts += sort(v("x"), Sort.Dir.DESC, null) + } + } + }, + expect("SELECT a FROM T ORDER BY x NULLS FIRST") { + exprSFW { + select = select("a") + from = table("T") + orderBy = orderBy { + sorts += sort(v("x"), null, Sort.Nulls.FIRST) + } + } + }, + expect("SELECT a FROM T ORDER BY x NULLS LAST") { + exprSFW { + select = select("a") + from = table("T") + orderBy = orderBy { + sorts += sort(v("x"), null, Sort.Nulls.LAST) + } + } + }, + expect("SELECT a FROM T ORDER BY x ASC NULLS FIRST") { + exprSFW { + select = select("a") + from = table("T") + orderBy = orderBy { + sorts += sort(v("x"), Sort.Dir.ASC, Sort.Nulls.FIRST) + } + } + }, + expect("SELECT a FROM T ORDER BY x ASC NULLS LAST") { + exprSFW { + select = select("a") + from = table("T") + orderBy = orderBy { + sorts += sort(v("x"), Sort.Dir.ASC, Sort.Nulls.LAST) + } + } + }, + expect("SELECT a FROM T ORDER BY x DESC NULLS FIRST") { + exprSFW { + select = select("a") + from = table("T") + orderBy = orderBy { + sorts += sort(v("x"), Sort.Dir.DESC, Sort.Nulls.FIRST) + } + } + }, + expect("SELECT a FROM T ORDER BY x DESC NULLS LAST") { + exprSFW { + select = select("a") + from = table("T") + orderBy = orderBy { + sorts += sort(v("x"), Sort.Dir.DESC, Sort.Nulls.LAST) + } + } + }, + expect("SELECT a FROM T ORDER BY x, y") { + exprSFW { + select = select("a") + from = table("T") + orderBy = orderBy { + sorts += sort(v("x"), null, null) + sorts += sort(v("y"), null, null) + } + } + }, + expect("SELECT a FROM T ORDER BY x ASC, y DESC") { + exprSFW { + select = select("a") + from = table("T") + orderBy = orderBy { + sorts += sort(v("x"), Sort.Dir.ASC, null) + sorts += sort(v("y"), Sort.Dir.DESC, null) + } + } + }, + expect("SELECT a FROM T ORDER BY x NULLS FIRST, y NULLS LAST") { + exprSFW { + select = select("a") + from = table("T") + orderBy = orderBy { + sorts += sort(v("x"), null, Sort.Nulls.FIRST) + sorts += sort(v("y"), null, Sort.Nulls.LAST) + } + } + }, + expect("SELECT a FROM T ORDER BY x ASC NULLS FIRST, y DESC NULLS LAST") { + exprSFW { + select = select("a") + from = table("T") + orderBy = orderBy { + sorts += sort(v("x"), Sort.Dir.ASC, Sort.Nulls.FIRST) + sorts += sort(v("y"), Sort.Dir.DESC, Sort.Nulls.LAST) + } + } + }, + ) + + @JvmStatic + fun unionClauseCases() = listOf( + expect("SELECT a FROM T UNION (SELECT b FROM S)") { + exprSFW { + select = select("a") + from = table("T") + setOp = exprSFWSetOp { + type = setOp(SetOp.Type.UNION, null) + operand = exprSFW { + select = select("b") + from = table("S") + } + } + } + }, + expect("SELECT a FROM T UNION ALL (SELECT b FROM S)") { + exprSFW { + select = select("a") + from = table("T") + setOp = exprSFWSetOp { + type = setOp(SetOp.Type.UNION, SetQuantifier.ALL) + operand = exprSFW { + select = select("b") + from = table("S") + } + } + } + }, + expect("SELECT a FROM T UNION DISTINCT (SELECT b FROM S)") { + exprSFW { + select = select("a") + from = table("T") + setOp = exprSFWSetOp { + type = setOp(SetOp.Type.UNION, SetQuantifier.DISTINCT) + operand = exprSFW { + select = select("b") + from = table("S") + } + } + } + }, + expect("SELECT a FROM T UNION (SELECT b FROM S) LIMIT 1") { + exprSFW { + select = select("a") + from = table("T") + setOp = exprSFWSetOp { + type = setOp(SetOp.Type.UNION, null) + operand = exprSFW { + select = select("b") + from = table("S") + } + } + limit = exprLit(int32Value(1)) + } + }, + expect("SELECT a FROM T UNION (SELECT b FROM S LIMIT 1)") { + exprSFW { + select = select("a") + from = table("T") + setOp = exprSFWSetOp { + type = setOp(SetOp.Type.UNION, null) + operand = exprSFW { + select = select("b") + from = table("S") + limit = exprLit(int32Value(1)) + } + } + } + }, + expect("SELECT a FROM T UNION (SELECT b FROM S) ORDER BY x") { + exprSFW { + select = select("a") + from = table("T") + setOp = exprSFWSetOp { + type = setOp(SetOp.Type.UNION, null) + operand = exprSFW { + select = select("b") + from = table("S") + } + } + orderBy = orderBy { + sorts += sort(v("x"), null, null) + } + } + }, + expect("SELECT a FROM T UNION (SELECT b FROM S ORDER BY x)") { + exprSFW { + select = select("a") + from = table("T") + setOp = exprSFWSetOp { + type = setOp(SetOp.Type.UNION, null) + operand = exprSFW { + select = select("b") + from = table("S") + orderBy = orderBy { + sorts += sort(v("x"), null, null) + } + } + } + } + }, + ) + + private fun expect(expected: String, block: AstBuilder.() -> AstNode): Case { + val i = ast(AstFactory.DEFAULT, block) + return Case.Success(i, expected) + } + + private fun fail(message: String, block: AstBuilder.() -> AstNode): Case { + val i = ast(AstFactory.DEFAULT, block) + return Case.Fail(i, message) + } + + // DSL shorthand + + private fun AstBuilder.v(symbol: String) = this.exprVar { + identifier = id(symbol) + scope = Expr.Var.Scope.DEFAULT + } + + private fun AstBuilder.id( + symbol: String, + case: Identifier.CaseSensitivity = Identifier.CaseSensitivity.INSENSITIVE, + ) = this.identifierSymbol(symbol, case) + + private fun AstBuilder.select(vararg s: String) = selectProject { + s.forEach { + items += selectProjectItemExpression(v(it)) + } + } + + private fun AstBuilder.table(symbol: String) = fromValue { + expr = v(symbol) + type = From.Value.Type.SCAN + } + } + + sealed class Case { + + abstract fun assert() + + class Success( + private val input: AstNode, + private val expected: String, + ) : Case() { + + override fun assert() { + val actual = input.sql(SqlLayout.ONELINE) + Assertions.assertEquals(expected, actual) + } + } + + class Fail( + private val input: AstNode, + private val message: String, + ) : Case() { + + override fun assert() { + assertFails(message) { + input.sql(SqlLayout.ONELINE) + } + } + } + } +} diff --git a/partiql-types/src/main/kotlin/org/partiql/value/io/PartiQLValueTextWriter.kt b/partiql-types/src/main/kotlin/org/partiql/value/io/PartiQLValueTextWriter.kt index 67e56db20..75d59c06f 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/io/PartiQLValueTextWriter.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/io/PartiQLValueTextWriter.kt @@ -47,7 +47,7 @@ import java.io.PrintStream * @property indent Indent prefix, default is 2-spaces */ @PartiQLValueExperimental -internal class PartiQLValueTextWriter( +public class PartiQLValueTextWriter( private val out: PrintStream, private val formatted: Boolean = true, private val indent: String = " ",