diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/Pipeline.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/Pipeline.kt index 5ddef56c1..3a1ce7f77 100644 --- a/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/Pipeline.kt +++ b/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/Pipeline.kt @@ -1,11 +1,11 @@ package org.partiql.cli.pipeline -import org.partiql.ast.Statement +import org.partiql.ast.v1.Statement import org.partiql.cli.ErrorCodeString import org.partiql.eval.Mode import org.partiql.eval.compiler.PartiQLCompiler -import org.partiql.parser.PartiQLParser -import org.partiql.parser.PartiQLParserBuilder +import org.partiql.parser.V1PartiQLParser +import org.partiql.parser.V1PartiQLParserBuilder import org.partiql.plan.Plan import org.partiql.planner.PartiQLPlanner import org.partiql.spi.Context @@ -13,10 +13,9 @@ import org.partiql.spi.catalog.Session import org.partiql.spi.errors.PErrorListenerException import org.partiql.spi.value.Datum import java.io.PrintStream -import kotlin.jvm.Throws internal class Pipeline private constructor( - private val parser: PartiQLParser, + private val parser: V1PartiQLParser, private val planner: PartiQLPlanner, private val compiler: PartiQLCompiler, private val ctx: Context, @@ -81,7 +80,7 @@ internal class Pipeline private constructor( private fun create(mode: Mode, out: PrintStream, config: Config): Pipeline { val listener = config.getErrorListener(out) val ctx = Context.of(listener) - val parser = PartiQLParserBuilder().build() + val parser = V1PartiQLParserBuilder().build() val planner = PartiQLPlanner.builder().build() val compiler = PartiQLCompiler.builder().build() return Pipeline(parser, planner, compiler, ctx, mode) diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt index 158350ebc..52e850da0 100644 --- a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt @@ -9,7 +9,7 @@ import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.MethodSource import org.partiql.eval.Mode import org.partiql.eval.compiler.PartiQLCompiler -import org.partiql.parser.PartiQLParser +import org.partiql.parser.V1PartiQLParser import org.partiql.plan.Plan import org.partiql.planner.PartiQLPlanner import org.partiql.plugins.memory.MemoryCatalog @@ -1307,7 +1307,7 @@ class PartiQLEvaluatorTest { ) { private val compiler = PartiQLCompiler.standard() - private val parser = PartiQLParser.standard() + private val parser = V1PartiQLParser.standard() private val planner = PartiQLPlanner.standard() /** @@ -1373,7 +1373,7 @@ class PartiQLEvaluatorTest { ) { private val compiler = PartiQLCompiler.standard() - private val parser = PartiQLParser.standard() + private val parser = V1PartiQLParser.standard() private val planner = PartiQLPlanner.standard() internal fun assert() { diff --git a/partiql-planner/api/partiql-planner.api b/partiql-planner/api/partiql-planner.api index 1495fe28c..47f81855e 100644 --- a/partiql-planner/api/partiql-planner.api +++ b/partiql-planner/api/partiql-planner.api @@ -1,8 +1,8 @@ public abstract interface class org/partiql/planner/PartiQLPlanner { public static final field Companion Lorg/partiql/planner/PartiQLPlanner$Companion; public static fun builder ()Lorg/partiql/planner/builder/PartiQLPlannerBuilder; - public abstract fun plan (Lorg/partiql/ast/Statement;Lorg/partiql/spi/catalog/Session;)Lorg/partiql/planner/PartiQLPlanner$Result; - public abstract fun plan (Lorg/partiql/ast/Statement;Lorg/partiql/spi/catalog/Session;Lorg/partiql/spi/Context;)Lorg/partiql/planner/PartiQLPlanner$Result; + public abstract fun plan (Lorg/partiql/ast/v1/Statement;Lorg/partiql/spi/catalog/Session;)Lorg/partiql/planner/PartiQLPlanner$Result; + public abstract fun plan (Lorg/partiql/ast/v1/Statement;Lorg/partiql/spi/catalog/Session;Lorg/partiql/spi/Context;)Lorg/partiql/planner/PartiQLPlanner$Result; public static fun standard ()Lorg/partiql/planner/PartiQLPlanner; } @@ -12,7 +12,7 @@ public final class org/partiql/planner/PartiQLPlanner$Companion { } public final class org/partiql/planner/PartiQLPlanner$DefaultImpls { - public static fun plan (Lorg/partiql/planner/PartiQLPlanner;Lorg/partiql/ast/Statement;Lorg/partiql/spi/catalog/Session;)Lorg/partiql/planner/PartiQLPlanner$Result; + public static fun plan (Lorg/partiql/planner/PartiQLPlanner;Lorg/partiql/ast/v1/Statement;Lorg/partiql/spi/catalog/Session;)Lorg/partiql/planner/PartiQLPlanner$Result; } public final class org/partiql/planner/PartiQLPlanner$Result { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt index b393eb0b8..4acacbf4f 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt @@ -1,6 +1,6 @@ package org.partiql.planner -import org.partiql.ast.Statement +import org.partiql.ast.v1.Statement import org.partiql.plan.Plan import org.partiql.planner.builder.PartiQLPlannerBuilder import org.partiql.spi.Context diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/SqlPlanner.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/SqlPlanner.kt index 34bd00937..3bbd30250 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/SqlPlanner.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/SqlPlanner.kt @@ -1,15 +1,15 @@ package org.partiql.planner.internal -import org.partiql.ast.Statement -import org.partiql.ast.normalize.normalize +import org.partiql.ast.v1.Statement import org.partiql.plan.Operation import org.partiql.plan.Plan import org.partiql.plan.builder.PlanFactory import org.partiql.plan.rex.Rex import org.partiql.planner.PartiQLPlanner import org.partiql.planner.PartiQLPlannerPass -import org.partiql.planner.internal.transforms.AstToPlan +import org.partiql.planner.internal.normalize.normalize import org.partiql.planner.internal.transforms.PlanTransform +import org.partiql.planner.internal.transforms.V1AstToPlan import org.partiql.planner.internal.typer.PlanTyper import org.partiql.spi.Context import org.partiql.spi.catalog.Session @@ -39,7 +39,7 @@ internal class SqlPlanner( val ast = statement.normalize() // 2. AST to Rel/Rex - val root = AstToPlan.apply(ast, env) + val root = V1AstToPlan.apply(ast, env) // 3. Resolve variables val typer = PlanTyper(env, ctx) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1AstToPlan.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1AstToPlan.kt new file mode 100644 index 000000000..bda6153cb --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1AstToPlan.kt @@ -0,0 +1,75 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package org.partiql.planner.internal.transforms + +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.AstVisitor +import org.partiql.ast.v1.Query +import org.partiql.ast.v1.expr.ExprQuerySet +import org.partiql.planner.internal.Env +import org.partiql.planner.internal.ir.statementQuery +import org.partiql.spi.catalog.Identifier +import org.partiql.ast.v1.Identifier as AstIdentifier +import org.partiql.ast.v1.IdentifierChain as AstIdentifierChain +import org.partiql.ast.v1.Statement as AstStatement +import org.partiql.planner.internal.ir.Statement as PlanStatement + +/** + * Simple translation from AST to an unresolved algebraic IR. + */ +internal object V1AstToPlan { + + // statement.toPlan() + @JvmStatic + fun apply(statement: AstStatement, env: Env): PlanStatement = statement.accept(ToPlanStatement, env) + + @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") + private object ToPlanStatement : AstVisitor { + + override fun defaultReturn(node: AstNode, env: Env) = throw IllegalArgumentException("Unsupported statement") + + override fun visitQuery(node: Query, env: Env): PlanStatement { + val rex = when (val expr = node.expr) { + is ExprQuerySet -> V1RelConverter.apply(expr, env) + else -> V1RexConverter.apply(expr, env) + } + return statementQuery(rex) + } + } + + // --- Helpers -------------------- + + fun convert(identifier: AstIdentifierChain): Identifier { + val parts = mutableListOf() + parts.add(part(identifier.root)) + var curStep = identifier.next + while (curStep != null) { + parts.add(part(curStep.root)) + curStep = curStep.next + } + return Identifier.of(parts) + } + + fun convert(identifier: AstIdentifier): Identifier { + return Identifier.of(part(identifier)) + } + + fun part(identifier: AstIdentifier): Identifier.Part = when (identifier.isDelimited) { + true -> Identifier.Part.delimited(identifier.symbol) + false -> Identifier.Part.regular(identifier.symbol) + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1NormalizeSelect.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1NormalizeSelect.kt new file mode 100644 index 000000000..4a95a4174 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1NormalizeSelect.kt @@ -0,0 +1,396 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.planner.internal.transforms + +import org.partiql.ast.v1.Ast.exprCall +import org.partiql.ast.v1.Ast.exprCase +import org.partiql.ast.v1.Ast.exprCaseBranch +import org.partiql.ast.v1.Ast.exprIsType +import org.partiql.ast.v1.Ast.exprLit +import org.partiql.ast.v1.Ast.exprQuerySet +import org.partiql.ast.v1.Ast.exprStruct +import org.partiql.ast.v1.Ast.exprStructField +import org.partiql.ast.v1.Ast.exprVarRef +import org.partiql.ast.v1.Ast.identifier +import org.partiql.ast.v1.Ast.identifierChain +import org.partiql.ast.v1.Ast.queryBodySFW +import org.partiql.ast.v1.Ast.queryBodySetOp +import org.partiql.ast.v1.Ast.selectItemExpr +import org.partiql.ast.v1.Ast.selectList +import org.partiql.ast.v1.Ast.selectValue +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.AstVisitor +import org.partiql.ast.v1.DataType +import org.partiql.ast.v1.From +import org.partiql.ast.v1.FromExpr +import org.partiql.ast.v1.FromJoin +import org.partiql.ast.v1.FromTableRef +import org.partiql.ast.v1.GroupBy +import org.partiql.ast.v1.QueryBody +import org.partiql.ast.v1.SelectItem +import org.partiql.ast.v1.SelectList +import org.partiql.ast.v1.SelectStar +import org.partiql.ast.v1.SelectValue +import org.partiql.ast.v1.expr.Expr +import org.partiql.ast.v1.expr.ExprCase +import org.partiql.ast.v1.expr.ExprLit +import org.partiql.ast.v1.expr.ExprQuerySet +import org.partiql.ast.v1.expr.ExprStruct +import org.partiql.ast.v1.expr.ExprVarRef +import org.partiql.ast.v1.expr.Scope +import org.partiql.planner.internal.helpers.toBinder +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.stringValue + +/** + * Converts SQL-style SELECT to PartiQL SELECT VALUE. + * - If there is a PROJECT ALL, we use the TUPLEUNION. + * - If there is NOT a PROJECT ALL, we use a literal struct. + * + * Here are some example of rewrites: + * + * ``` + * SELECT * + * FROM + * A AS x, + * B AS y AT i + * ``` + * gets rewritten to: + * ``` + * SELECT VALUE TUPLEUNION( + * CASE WHEN x IS STRUCT THEN x ELSE { '_1': x }, + * CASE WHEN y IS STRUCT THEN y ELSE { '_2': y }, + * { 'i': i } + * ) FROM A AS x, B AS y AT i + * ``` + * + * ``` + * SELECT x.*, x.a FROM A AS x + * ``` + * gets rewritten to: + * ``` + * SELECT VALUE TUPLEUNION( + * CASE WHEN x IS STRUCT THEN x ELSE { '_1': x }, + * { 'a': x.a } + * ) FROM A AS x + * ``` + * + * ``` + * SELECT x.a FROM A AS x + * ``` + * gets rewritten to: + * ``` + * SELECT VALUE { + * 'a': x.a + * } FROM A AS x + * ``` + * + * NOTE: This does NOT transform subqueries. It operates directly on an [QueryExpr.SFW] -- and that is it. Therefore: + * ``` + * SELECT + * (SELECT 1 FROM T AS "T") + * FROM R AS "R" + * ``` + * will be transformed to: + * ``` + * SELECT VALUE { + * '_1': (SELECT 1 FROM T AS "T") -- notice that SELECT 1 didn't get transformed. + * } FROM R AS "R" + * ``` + * + * Requires [NormalizeFromSource]. + */ +internal object V1NormalizeSelect { + + internal fun normalize(node: ExprQuerySet): ExprQuerySet { + return when (val body = node.body) { + is QueryBody.SFW -> { + val sfw = Visitor.visitSFW(body, newCtx()) + exprQuerySet( + body = sfw, + orderBy = node.orderBy, + limit = node.limit, + offset = node.offset + ) + } + is QueryBody.SetOp -> { + val lhs = body.lhs.normalizeOrIdentity() + val rhs = body.rhs.normalizeOrIdentity() + exprQuerySet( + body = queryBodySetOp( + type = body.type, + isOuter = body.isOuter, + lhs = lhs, + rhs = rhs + ), + orderBy = node.orderBy, + limit = node.limit, + offset = node.offset + ) + } + else -> TODO() // TODO ALAN + } + } + + private fun Expr.normalizeOrIdentity(): Expr { + return when (this) { + is ExprQuerySet -> normalize(this) + else -> this + } + } + + /** + * Closure for incrementing a derived binding counter + */ + private fun newCtx(): () -> Int = run { + var i = 1; + { i++ } + } + + /** + * The type parameter () -> Int + */ + private object Visitor : AstVisitor Int> { + + /** + * This is used to give projections a name. For example: + * ``` + * SELECT t.* FROM t AS t + * ``` + * + * Will get converted into: + * ``` + * SELECT VALUE TUPLEUNION( + * CASE + * WHEN t IS STRUCT THEN t + * ELSE { '_1': t } + * END + * ) + * FROM t AS t + * ``` + * + * In order to produce the struct's key in `{ '_1': t }` above, we use [col] to produce the column name + * given the ordinal. + */ + private val col = { index: Int -> "_${index + 1}" } + + internal fun visitSFW(node: QueryBody.SFW, ctx: () -> Int): QueryBody.SFW { + val sfw = super.visitQueryBodySFW(node, ctx) as QueryBody.SFW + return when (val select = sfw.select) { + is SelectStar -> { + val selectValue = when (val group = sfw.groupBy) { + null -> visitSelectAll(select, sfw.from) + else -> visitSelectAll(select, group) + } + queryBodySFW( + select = selectValue, + exclude = sfw.exclude, + from = sfw.from, + let = sfw.let, + where = sfw.where, + groupBy = sfw.groupBy, + having = sfw.having, + ) + } + else -> sfw + } + } + + override fun visitQueryBodySFW(node: QueryBody.SFW, ctx: () -> Int): QueryBody.SFW { + return node + } + + override fun visitSelectList(node: SelectList, ctx: () -> Int): SelectValue { + + // Visit items, adding a binder if necessary + var diff = false + val visitedItems = ArrayList(node.items.size) + node.items.forEach { n -> + val item = visitSelectItem(n, ctx) as SelectItem + if (item !== n) diff = true + visitedItems.add(item) + } + val visitedNode = if (diff) selectList(visitedItems, node.setq) else node + + // Rewrite selection + return when (node.items.any { it is SelectItem.Star }) { + false -> visitSelectProjectWithoutProjectAll(visitedNode) + true -> visitSelectProjectWithProjectAll(visitedNode) + } + } + + override fun visitSelectItemExpr(node: SelectItem.Expr, ctx: () -> Int): SelectItem.Expr { + val expr = visitExpr(node.expr, newCtx()) as Expr + val alias = when (node.asAlias) { + null -> expr.toBinder(ctx) + else -> node.asAlias + } + return if (expr != node.expr || alias != node.asAlias) { + selectItemExpr(expr, alias) + } else { + node + } + } + + // Helpers + + /** + * We need to call this from [visitExprSFW] and not override [visitSelectStar] because we need access to the + * [From] aliases. + * + * Note: We assume that [select] and [from] have already been visited. + */ + private fun visitSelectAll(select: SelectStar, from: From): SelectValue { + val tupleUnionArgs = from.tableRefs.flatMap { it.aliases() }.flatMapIndexed { i, binding -> + val asAlias = binding.first + val atAlias = binding.second + val atAliasItem = atAlias?.simple()?.let { + val alias = it.asAlias ?: error("The AT alias should be present. This wasn't normalized.") + buildSimpleStruct(it.expr, alias.symbol) + } + listOfNotNull( + buildCaseWhenStruct(asAlias.star(i).expr, i), + atAliasItem, + ) + } + return selectValue( + constructor = exprCall( + function = identifierChain(identifier("TUPLEUNION", isDelimited = true), next = null), + args = tupleUnionArgs, + setq = null // setq = null for scalar fn + ), + setq = select.setq + ) + } + + /** + * We need to call this from [visitExprSFW] and not override [visitSelectStar] because we need access to the + * [GroupBy] aliases. + * + * Note: We assume that [select] and [group] have already been visited. + */ + private fun visitSelectAll(select: SelectStar, group: GroupBy): SelectValue { + val groupAs = group.asAlias?.let { structField(it.symbol, varLocal(it.symbol)) } + val fields = group.keys.map { key -> + val alias = key.asAlias ?: error("Expected a GROUP BY alias.") + structField(alias.symbol, varLocal(alias.symbol)) + } + listOfNotNull(groupAs) + val constructor = exprStruct(fields) + return selectValue( + constructor = constructor, + setq = select.setq + ) + } + + private fun visitSelectProjectWithProjectAll(node: SelectList): SelectValue { + val tupleUnionArgs = node.items.mapIndexed { index, item -> + when (item) { + is SelectItem.Star -> buildCaseWhenStruct(item.expr, index) + is SelectItem.Expr -> buildSimpleStruct( + item.expr, + item.asAlias?.symbol + ?: error("The alias should've been here. This AST is not normalized.") + ) + else -> TODO() // TODO ALAN + } + } + return selectValue( + setq = node.setq, + constructor = exprCall( + function = identifierChain(identifier("TUPLEUNION", isDelimited = true), next = null), + args = tupleUnionArgs, + setq = null // setq = null for scalar fn + ) + ) + } + + @OptIn(PartiQLValueExperimental::class) + private fun visitSelectProjectWithoutProjectAll(node: SelectList): SelectValue { + val structFields = node.items.map { item -> + val itemExpr = item as? SelectItem.Expr ?: error("Expected the projection to be an expression.") + exprStructField( + name = exprLit(stringValue(itemExpr.asAlias?.symbol!!)), + value = item.expr + ) + } + return selectValue( + setq = node.setq, + constructor = exprStruct( + fields = structFields + ) + ) + } + + private fun buildCaseWhenStruct(expr: Expr, index: Int): ExprCase = exprCase( + expr = null, + branches = listOf( + exprCaseBranch( + condition = exprIsType(expr, DataType.STRUCT(), not = false), + expr = expr + ) + ), + defaultExpr = buildSimpleStruct(expr, col(index)) + ) + + @OptIn(PartiQLValueExperimental::class) + private fun buildSimpleStruct(expr: Expr, name: String): ExprStruct = exprStruct( + fields = listOf( + exprStructField( + name = exprLit(stringValue(name)), + value = expr + ) + ) + ) + + @OptIn(PartiQLValueExperimental::class) + private fun structField(name: String, expr: Expr): ExprStruct.Field = exprStructField( + name = ExprLit(stringValue(name)), + value = expr + ) + + private fun varLocal(name: String): ExprVarRef = exprVarRef( + identifierChain = identifierChain(identifier(name, isDelimited = true), next = null), + scope = Scope.LOCAL() + ) + + private fun FromTableRef.aliases(): List> = when (this) { + is FromJoin -> lhs.aliases() + rhs.aliases() + is FromExpr -> { + val asAlias = asAlias?.symbol ?: error("AST not normalized, missing asAlias on FROM source.") + val atAlias = atAlias?.symbol + listOf(Pair(asAlias, atAlias)) + } + else -> TODO() // TODO ALAN + } + + // t -> t.* AS _i + private fun String.star(i: Int): SelectItem.Expr { + val expr = exprVarRef(identifierChain(id(this), next = null), Scope.DEFAULT()) + val alias = expr.toBinder(i) + return selectItemExpr(expr, alias) + } + + // t -> t AS t + private fun String.simple(): SelectItem.Expr { + val expr = exprVarRef(identifierChain(id(this), next = null), Scope.DEFAULT()) + val alias = id(this) + return selectItemExpr(expr, alias) + } + + private fun id(symbol: String) = identifier(symbol, isDelimited = false) + + override fun defaultReturn(node: AstNode, ctx: () -> Int) = node + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RelConverter.kt new file mode 100644 index 000000000..ef7925428 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RelConverter.kt @@ -0,0 +1,756 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package org.partiql.planner.internal.transforms + +import org.partiql.ast.v1.Ast.exprLit +import org.partiql.ast.v1.Ast.exprVarRef +import org.partiql.ast.v1.Ast.identifier +import org.partiql.ast.v1.Ast.identifierChain +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.AstVisitor +import org.partiql.ast.v1.Exclude +import org.partiql.ast.v1.ExcludeStep +import org.partiql.ast.v1.From +import org.partiql.ast.v1.FromExpr +import org.partiql.ast.v1.FromJoin +import org.partiql.ast.v1.FromType +import org.partiql.ast.v1.GroupBy +import org.partiql.ast.v1.GroupByStrategy +import org.partiql.ast.v1.IdentifierChain +import org.partiql.ast.v1.JoinType +import org.partiql.ast.v1.Nulls +import org.partiql.ast.v1.Order +import org.partiql.ast.v1.OrderBy +import org.partiql.ast.v1.QueryBody +import org.partiql.ast.v1.SelectItem +import org.partiql.ast.v1.SelectList +import org.partiql.ast.v1.SelectPivot +import org.partiql.ast.v1.SelectStar +import org.partiql.ast.v1.SelectValue +import org.partiql.ast.v1.SetOpType +import org.partiql.ast.v1.SetQuantifier +import org.partiql.ast.v1.expr.Expr +import org.partiql.ast.v1.expr.ExprCall +import org.partiql.ast.v1.expr.ExprQuerySet +import org.partiql.ast.v1.expr.Scope +import org.partiql.planner.internal.Env +import org.partiql.planner.internal.helpers.toBinder +import org.partiql.planner.internal.ir.Rel +import org.partiql.planner.internal.ir.Rex +import org.partiql.planner.internal.ir.rel +import org.partiql.planner.internal.ir.relBinding +import org.partiql.planner.internal.ir.relOpAggregate +import org.partiql.planner.internal.ir.relOpAggregateCallUnresolved +import org.partiql.planner.internal.ir.relOpDistinct +import org.partiql.planner.internal.ir.relOpErr +import org.partiql.planner.internal.ir.relOpExclude +import org.partiql.planner.internal.ir.relOpExcludePath +import org.partiql.planner.internal.ir.relOpExcludeStep +import org.partiql.planner.internal.ir.relOpExcludeTypeCollIndex +import org.partiql.planner.internal.ir.relOpExcludeTypeCollWildcard +import org.partiql.planner.internal.ir.relOpExcludeTypeStructKey +import org.partiql.planner.internal.ir.relOpExcludeTypeStructSymbol +import org.partiql.planner.internal.ir.relOpExcludeTypeStructWildcard +import org.partiql.planner.internal.ir.relOpFilter +import org.partiql.planner.internal.ir.relOpJoin +import org.partiql.planner.internal.ir.relOpLimit +import org.partiql.planner.internal.ir.relOpOffset +import org.partiql.planner.internal.ir.relOpProject +import org.partiql.planner.internal.ir.relOpScan +import org.partiql.planner.internal.ir.relOpScanIndexed +import org.partiql.planner.internal.ir.relOpSort +import org.partiql.planner.internal.ir.relOpSortSpec +import org.partiql.planner.internal.ir.relOpUnpivot +import org.partiql.planner.internal.ir.relType +import org.partiql.planner.internal.ir.rex +import org.partiql.planner.internal.ir.rexOpLit +import org.partiql.planner.internal.ir.rexOpPivot +import org.partiql.planner.internal.ir.rexOpSelect +import org.partiql.planner.internal.ir.rexOpStruct +import org.partiql.planner.internal.ir.rexOpStructField +import org.partiql.planner.internal.ir.rexOpVarLocal +import org.partiql.planner.internal.typer.CompilerType +import org.partiql.types.PType +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.boolValue +import org.partiql.value.int32Value +import org.partiql.value.stringValue + +/** + * Lexically scoped state for use in translating an individual SELECT statement. + */ +internal object V1RelConverter { + + // IGNORE — so we don't have to non-null assert on operator inputs + internal val nil = rel(relType(emptyList(), emptySet()), relOpErr("nil")) + + /** + * Here we convert an SFW to composed [Rel]s, then apply the appropriate relation-value projection to get a [Rex]. + */ + internal fun apply(qSet: ExprQuerySet, env: Env): Rex { + val newQSet = V1NormalizeSelect.normalize(qSet) + val rex = when (val body = newQSet.body) { + is QueryBody.SFW -> { + val rel = newQSet.accept(ToRel(env), nil) + when (val projection = body.select) { + // PIVOT ... FROM + is SelectPivot -> { + val key = projection.key.toRex(env) + val value = projection.value.toRex(env) + val type = (STRUCT) + val op = rexOpPivot(key, value, rel) + rex(type, op) + } + // SELECT VALUE ... FROM + is SelectValue -> { + assert(rel.type.schema.size == 1) { + "Expected SELECT VALUE's input to have a single binding. " + + "However, it contained: ${rel.type.schema.map { it.name }}." + } + val constructor = rex(ANY, rexOpVarLocal(0, 0)) + val op = rexOpSelect(constructor, rel) + val type = when (rel.type.props.contains(Rel.Prop.ORDERED)) { + true -> (LIST) + else -> (BAG) + } + rex(type, op) + } + // SELECT * FROM + is SelectStar -> { + throw IllegalArgumentException("AST not normalized") + } + // SELECT ... FROM + is SelectList -> { + throw IllegalArgumentException("AST not normalized") + } + + else -> TODO() // TODO ALAN + } + } + is QueryBody.SetOp -> { + val rel = newQSet.accept(ToRel(env), nil) + val constructor = rex(ANY, rexOpVarLocal(0, 0)) + val op = rexOpSelect(constructor, rel) + val type = when (rel.type.props.contains(Rel.Prop.ORDERED)) { + true -> (LIST) + else -> (BAG) + } + rex(type, op) + } + + else -> TODO() // TODO ALAN + } + return rex + } + + /** + * Syntax sugar for converting an [Expr] tree to a [Rex] tree. + */ + private fun Expr.toRex(env: Env): Rex = V1RexConverter.apply(this, env) + + @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE", "LocalVariableName") + internal class ToRel(private val env: Env) : AstVisitor { + + override fun defaultReturn(node: AstNode, input: Rel): Rel = + throw IllegalArgumentException("unsupported rel $node") + + /** + * Translate SFW AST node to a pipeline of [Rel] operators; skip any SELECT VALUE or PIVOT projection. + */ + + override fun visitExprQuerySet(node: ExprQuerySet, ctx: Rel): Rel { + val body = node.body + val orderBy = node.orderBy + val limit = node.limit + val offset = node.offset + when (body) { + is QueryBody.SFW -> { + var sel = body + var rel = visitFrom(sel.from, nil) + rel = convertWhere(rel, sel.where) + // kotlin does not have destructuring reassignment + val (_sel, _rel) = convertAgg(rel, sel, sel.groupBy) + sel = _sel + rel = _rel + // Plan.create (possibly rewritten) sel node + rel = convertHaving(rel, sel.having) + rel = convertOrderBy(rel, orderBy) + // offset should precede limit + rel = convertOffset(rel, offset) + rel = convertLimit(rel, limit) + rel = convertExclude(rel, sel.exclude) + // append SQL projection if present + rel = when (val projection = sel.select) { + is SelectValue -> { + val project = visitSelectValue(projection, rel) + visitSetQuantifier(projection.setq, project) + } + is SelectStar, is SelectList -> { + error("AST not normalized, found ${projection.javaClass.simpleName}") + } + is SelectPivot -> rel // Skip PIVOT + else -> TODO() // TODO ALAN + } + return rel + } + is QueryBody.SetOp -> { + var rel = convertSetOp(body) + rel = convertOrderBy(rel, orderBy) + // offset should precede limit + rel = convertOffset(rel, offset) + rel = convertLimit(rel, limit) + return rel + } + else -> TODO() // TODO ALAN + } + } + + /** + * Given a [setQuantifier], this will return a [Rel] of [Rel.Op.Distinct] wrapping the [input]. + * If [setQuantifier] is null or ALL, this will return the [input]. + */ + private fun visitSetQuantifier(setQuantifier: SetQuantifier?, input: Rel): Rel { + return when (setQuantifier?.code()) { + SetQuantifier.DISTINCT -> rel(input.type, relOpDistinct(input)) + SetQuantifier.ALL, null -> input + else -> TODO() // TODO ALAN + } + } + + override fun visitSelectList(node: SelectList, input: Rel): Rel { + // this ignores aggregations + val schema = mutableListOf() + val props = input.type.props + val projections = mutableListOf() + node.items.forEach { + val (binding, projection) = convertSelectItem(it) + schema.add(binding) + projections.add(projection) + } + val type = relType(schema, props) + val op = relOpProject(input, projections) + return rel(type, op) + } + + override fun visitSelectValue(node: SelectValue, input: Rel): Rel { + val name = node.constructor.toBinder(1).symbol + val rex = V1RexConverter.apply(node.constructor, env) + val schema = listOf(relBinding(name, rex.type)) + val props = input.type.props + val type = relType(schema, props) + val op = relOpProject(input, projections = listOf(rex)) + return rel(type, op) + } + + @OptIn(PartiQLValueExperimental::class) + override fun visitFrom(node: From, ctx: Rel): Rel { + val tableRefs = node.tableRefs.map { visitTableRef(it, ctx) } + return tableRefs.drop(1).fold(tableRefs.first()) { acc, tRef -> + val joinType = Rel.Op.Join.Type.INNER + val condition = rex(BOOL, rexOpLit(boolValue(true))) + val schema = acc.type.schema + tRef.type.schema + val props = emptySet() + val type = relType(schema, props) + val op = relOpJoin(acc, tRef, condition, joinType) + rel(type, op) + } + } + + override fun visitFromExpr(node: FromExpr, nil: Rel): Rel { + val rex = V1RexConverter.applyRel(node.expr, env) + val binding = when (val a = node.asAlias) { + null -> error("AST not normalized, missing AS alias on $node") + else -> relBinding( + name = a.symbol, + type = rex.type + ) + } + return when (node.fromType.code()) { + FromType.SCAN -> { + when (val i = node.atAlias) { + null -> convertScan(rex, binding) + else -> { + val index = relBinding( + name = i.symbol, + type = (INT) + ) + convertScanIndexed(rex, binding, index) + } + } + } + FromType.UNPIVOT -> { + val atAlias = when (val at = node.atAlias) { + null -> error("AST not normalized, missing AT alias on UNPIVOT $node") + else -> relBinding( + name = at.symbol, + type = (STRING) + ) + } + convertUnpivot(rex, k = atAlias, v = binding) + } + else -> TODO() // TODO ALAN + } + } + + /** + * Appends [Rel.Op.Join] where the left and right sides are converted FROM sources + * + * TODO compute basic schema + */ + @OptIn(PartiQLValueExperimental::class) + override fun visitFromJoin(node: FromJoin, nil: Rel): Rel { + val lhs = visitTableRef(node.lhs, nil) + val rhs = visitTableRef(node.rhs, nil) + val schema = lhs.type.schema + rhs.type.schema // Note: This gets more specific in PlanTyper. It is only used to find binding names here. + val props = emptySet() + val condition = node.condition?.let { V1RexConverter.apply(it, env) } ?: rex(BOOL, rexOpLit(boolValue(true))) + val joinType = when (node.joinType?.code()) { + JoinType.LEFT_OUTER, JoinType.LEFT, JoinType.LEFT_CROSS -> Rel.Op.Join.Type.LEFT + JoinType.RIGHT_OUTER, JoinType.RIGHT -> Rel.Op.Join.Type.RIGHT + JoinType.FULL_OUTER, JoinType.FULL -> Rel.Op.Join.Type.FULL + JoinType.INNER, + JoinType.CROSS -> Rel.Op.Join.Type.INNER // Cross Joins are just INNER JOIN ON TRUE + null -> Rel.Op.Join.Type.INNER // a JOIN b ON a.id = b.id <--> a INNER JOIN b ON a.id = b.id + else -> TODO() // TODO ALAN + } + val type = relType(schema, props) + val op = relOpJoin(lhs, rhs, condition, joinType) + return rel(type, op) + } + + // Helpers + private fun convertScan(rex: Rex, binding: Rel.Binding): Rel { + val schema = listOf(binding) + val props = emptySet() + val type = relType(schema, props) + val op = relOpScan(rex) + return rel(type, op) + } + + private fun convertScanIndexed(rex: Rex, binding: Rel.Binding, index: Rel.Binding): Rel { + val schema = listOf(binding, index) + val props = emptySet() + val type = relType(schema, props) + val op = relOpScanIndexed(rex) + return rel(type, op) + } + + /** + * Output schema of an UNPIVOT is < k, v > + * + * @param rex + * @param k + * @param v + */ + private fun convertUnpivot(rex: Rex, k: Rel.Binding, v: Rel.Binding): Rel { + val schema = listOf(k, v) + val props = emptySet() + val type = relType(schema, props) + val op = relOpUnpivot(rex) + return rel(type, op) + } + + private fun convertSelectItem(item: SelectItem) = when (item) { + is SelectItem.Star -> convertSelectItemStar(item) + is SelectItem.Expr -> convertSelectItemExpr(item) + else -> TODO() // TODO ALAN + } + + private fun convertSelectItemStar(item: SelectItem.Star): Pair { + throw IllegalArgumentException("AST not normalized") + } + + private fun convertSelectItemExpr(item: SelectItem.Expr): Pair { + val name = when (val a = item.asAlias) { + null -> error("AST not normalized, missing AS alias on select item $item") + else -> a.symbol + } + val rex = V1RexConverter.apply(item.expr, env) + val binding = relBinding(name, rex.type) + return binding to rex + } + + /** + * Append [Rel.Op.Filter] only if a WHERE condition exists + */ + private fun convertWhere(input: Rel, expr: Expr?): Rel { + if (expr == null) { + return input + } + val type = input.type + val predicate = expr.toRex(env) + val op = relOpFilter(input, predicate) + return rel(type, op) + } + + /** + * Append [Rel.Op.Aggregate] only if SELECT contains aggregate expressions. + * + * TODO Set quantifiers + * TODO Group As + * + * @return Pair is returned where + * 1. Ast.Expr.SFW has every Ast.Expr.CallAgg replaced by a synthetic Ast.Expr.Var + * 2. Rel which has the appropriate Rex.Agg calls and groups + */ + @OptIn(PartiQLValueExperimental::class) + private fun convertAgg(input: Rel, select: QueryBody.SFW, groupBy: GroupBy?): Pair { + // Rewrite and extract all aggregations in the SELECT clause + val (sel, aggregations) = AggregationTransform.apply(select) + + // No aggregation planning required for GROUP BY + if (aggregations.isEmpty() && groupBy == null) { + return Pair(select, input) + } + + // Build the schema -> (calls... groups...) + val schema = mutableListOf() + val props = emptySet() + + // Build the rel operator + var strategy = Rel.Op.Aggregate.Strategy.FULL + val calls = aggregations.mapIndexed { i, expr -> + val binding = relBinding( + name = syntheticAgg(i), + type = (ANY), + ) + schema.add(binding) + val args = expr.args.map { arg -> arg.toRex(env) } + val id = V1AstToPlan.convert(expr.function) + if (id.hasQualifier()) { + error("Qualified aggregation calls are not supported.") + } + // lowercase normalize all calls + val name = id.getIdentifier().getText().lowercase() + if (name == "count" && expr.args.isEmpty()) { + relOpAggregateCallUnresolved( + name, + org.partiql.planner.internal.ir.SetQuantifier.ALL, + args = listOf(exprLit(int32Value(1)).toRex(env)) + ) + } else { + val setq = when (expr.setq?.code()) { + null -> org.partiql.planner.internal.ir.SetQuantifier.ALL + SetQuantifier.ALL -> org.partiql.planner.internal.ir.SetQuantifier.ALL + SetQuantifier.DISTINCT -> org.partiql.planner.internal.ir.SetQuantifier.DISTINCT + else -> TODO() // TODO ALAN + } + relOpAggregateCallUnresolved(name, setq, args) + } + }.toMutableList() + + // Add GROUP_AS aggregation + groupBy?.let { gb -> + gb.asAlias?.let { groupAs -> + val binding = relBinding(groupAs.symbol, ANY) + schema.add(binding) + val fields = input.type.schema.mapIndexed { bindingIndex, currBinding -> + rexOpStructField( + k = rex(STRING, rexOpLit(stringValue(currBinding.name))), + v = rex(ANY, rexOpVarLocal(0, bindingIndex)) + ) + } + val arg = listOf(rex(ANY, rexOpStruct(fields))) + calls.add(relOpAggregateCallUnresolved("group_as", org.partiql.planner.internal.ir.SetQuantifier.ALL, arg)) + } + } + var groups = emptyList() + if (groupBy != null) { + groups = groupBy.keys.map { + if (it.asAlias == null) { + error("not normalized, group key $it missing unique name") + } + val binding = relBinding( + name = it.asAlias!!.symbol, + type = (ANY) + ) + schema.add(binding) + it.expr.toRex(env) + } + strategy = when (groupBy.strategy.code()) { + GroupByStrategy.FULL -> Rel.Op.Aggregate.Strategy.FULL + GroupByStrategy.PARTIAL -> Rel.Op.Aggregate.Strategy.PARTIAL + else -> TODO() // TODO ALAN + } + } + val type = relType(schema, props) + val op = relOpAggregate(input, strategy, calls, groups) + val rel = rel(type, op) + return Pair(sel, rel) + } + + /** + * Append [Rel.Op.Filter] only if a HAVING condition exists + * + * Notes: + * - This currently does not support aggregation expressions in the WHERE condition + */ + private fun convertHaving(input: Rel, expr: Expr?): Rel { + if (expr == null) { + return input + } + val type = input.type + val predicate = expr.toRex(env) + val op = relOpFilter(input, predicate) + return rel(type, op) + } + + private fun visitIfQuerySet(expr: Expr): Rel { + return when (expr) { + is ExprQuerySet -> visit(expr, nil) + else -> { + val rex = V1RexConverter.applyRel(expr, env) + val op = relOpScan(rex) + val type = Rel.Type(listOf(Rel.Binding("_1", ANY)), props = emptySet()) + return rel(type, op) + } + } + } + + /** + * Append SQL set operator if present + */ + private fun convertSetOp(setExpr: QueryBody.SetOp): Rel { + val lhs = visitIfQuerySet(setExpr.lhs) + val rhs = visitIfQuerySet(setExpr.rhs) + val type = Rel.Type(listOf(Rel.Binding("_0", ANY)), props = emptySet()) + val quantifier = when (setExpr.type.setq?.code()) { + SetQuantifier.ALL -> org.partiql.planner.internal.ir.SetQuantifier.ALL + null, SetQuantifier.DISTINCT -> org.partiql.planner.internal.ir.SetQuantifier.DISTINCT + else -> TODO() // TODO ALAN + } + val outer = setExpr.isOuter + val op = when (setExpr.type.setOpType.code()) { + SetOpType.UNION -> Rel.Op.Union(quantifier, outer, lhs, rhs) + SetOpType.EXCEPT -> Rel.Op.Except(quantifier, outer, lhs, rhs) + SetOpType.INTERSECT -> Rel.Op.Intersect(quantifier, outer, lhs, rhs) + else -> TODO() // TODO ALAN + } + return rel(type, op) + } + + /** + * Append [Rel.Op.Sort] only if an ORDER BY clause is present + */ + private fun convertOrderBy(input: Rel, orderBy: OrderBy?): Rel { + if (orderBy == null) { + return input + } + val type = input.type.copy(props = setOf(Rel.Prop.ORDERED)) + val specs = orderBy.sorts.map { + val rex = it.expr.toRex(env) + val order = when (it.order?.code()) { + Order.DESC -> when (it.nulls?.code()) { + Nulls.LAST -> Rel.Op.Sort.Order.DESC_NULLS_LAST + Nulls.FIRST, null -> Rel.Op.Sort.Order.DESC_NULLS_FIRST + else -> TODO() // TODO ALAN + } + else -> when (it.nulls?.code()) { + Nulls.FIRST -> Rel.Op.Sort.Order.ASC_NULLS_FIRST + Nulls.LAST, null -> Rel.Op.Sort.Order.ASC_NULLS_LAST + else -> TODO() // TODO ALAN + } + } + relOpSortSpec(rex, order) + } + val op = relOpSort(input, specs) + return rel(type, op) + } + + /** + * Append [Rel.Op.Limit] if there is a LIMIT + */ + private fun convertLimit(input: Rel, limit: Expr?): Rel { + if (limit == null) { + return input + } + val type = input.type + val rex = V1RexConverter.apply(limit, env) + val op = relOpLimit(input, rex) + return rel(type, op) + } + + /** + * Append [Rel.Op.Offset] if there is an OFFSET + */ + private fun convertOffset(input: Rel, offset: Expr?): Rel { + if (offset == null) { + return input + } + val type = input.type + val rex = V1RexConverter.apply(offset, env) + val op = relOpOffset(input, rex) + return rel(type, op) + } + + private fun convertExclude(input: Rel, exclude: Exclude?): Rel { + if (exclude == null) { + return input + } + val type = input.type // PlanTyper handles typing the exclusion and removing redundant exclude paths + val paths = exclude.excludePaths + .groupBy(keySelector = { it.root }, valueTransform = { it.excludeSteps }) + .map { (root, exclusions) -> + val rootVar = (root.toRex(env)).op as Rex.Op.Var + val steps = exclusionsToSteps(exclusions) + relOpExcludePath(rootVar, steps) + } + val op = relOpExclude(input, paths) + return rel(type, op) + } + + private fun exclusionsToSteps(exclusions: List>): List { + if (exclusions.any { it.isEmpty() }) { + // if there exists a path with no further steps, can remove the longer paths + // e.g. t.a.b, t.a.b.c, t.a.b.d[*].*.e -> can keep just t.a.b + return emptyList() + } + return exclusions + .groupBy(keySelector = { it.first() }, valueTransform = { it.drop(1) }) + .map { (head, steps) -> + val type = stepToExcludeType(head) + val substeps = exclusionsToSteps(steps) + relOpExcludeStep(type, substeps) + } + } + + private fun stepToExcludeType(step: ExcludeStep): Rel.Op.Exclude.Type { + return when (step) { + is ExcludeStep.StructField -> { + when (step.symbol.isDelimited) { + false -> relOpExcludeTypeStructSymbol(step.symbol.symbol) + true -> relOpExcludeTypeStructKey(step.symbol.symbol) + } + } + is ExcludeStep.CollIndex -> relOpExcludeTypeCollIndex(step.index) + is ExcludeStep.StructWildcard -> relOpExcludeTypeStructWildcard() + is ExcludeStep.CollWildcard -> relOpExcludeTypeCollWildcard() + else -> TODO() // TODO ALAN + } + } + + // /** + // * Converts a GROUP AS X clause to a binding of the form: + // * ``` + // * { 'X': group_as({ 'a_0': e_0, ..., 'a_n': e_n }) } + // * ``` + // * + // * Notes: + // * - This was included to be consistent with the existing PartiqlAst and PartiqlLogical representations, + // * but perhaps we don't want to represent GROUP AS with an agg function. + // */ + // private fun convertGroupAs(name: String, from: From): Binding { + // val fields = from.bindings().map { n -> + // Plan.field( + // name = Plan.rexLit(ionString(n), STRING), + // value = Plan.rexId(n, Case.SENSITIVE, Rex.Id.Qualifier.UNQUALIFIED, type = STRUCT) + // ) + // } + // return Plan.binding( + // name = name, + // value = Plan.rexAgg( + // id = "group_as", + // args = listOf(Plan.rexTuple(fields, STRUCT)), + // modifier = Rex.Agg.Modifier.ALL, + // type = STRUCT + // ) + // ) + // } + } + + /** + * Rewrites a SELECT node replacing (and extracting) each aggregation `i` with a synthetic field name `$agg_i`. + */ + private object AggregationTransform : AstVisitor { + // currently hard-coded + @JvmStatic + private val aggregates = setOf("count", "avg", "sum", "min", "max", "any", "some", "every") + + private data class Context( + val aggregations: MutableList, + val keys: List + ) + + fun apply(node: QueryBody.SFW): Pair> { + val aggs = mutableListOf() + val keys = node.groupBy?.keys ?: emptyList() + val context = Context(aggs, keys) + val select = super.visitQueryBodySFW(node, context) as QueryBody.SFW + return Pair(select, aggs) + } + + override fun visitSelectValue(node: SelectValue, ctx: Context): AstNode { + val visited = super.visitSelectValue(node, ctx) + val substitutions = ctx.keys.associate { + it.expr to exprVarRef(identifierChain(identifier(it.asAlias!!.symbol, isDelimited = true), next = null), Scope.DEFAULT()) + } + return V1SubstitutionVisitor.visit(visited, substitutions) + } + + // only rewrite top-level SFW + override fun visitQueryBodySFW(node: QueryBody.SFW, ctx: Context): AstNode = node + + override fun visitExprCall(node: ExprCall, ctx: Context) = + // TODO replace w/ proper function resolution to determine whether a function call is a scalar or aggregate. + // may require further modification of SPI interfaces to support + when (node.function.isAggregateCall()) { + true -> { + val id = identifierChain( + identifier( + symbol = syntheticAgg(ctx.aggregations.size), + isDelimited = false + ), + next = null + ) + ctx.aggregations += node + exprVarRef(id, Scope.DEFAULT()) + } + else -> node + } + + private fun String.isAggregateCall(): Boolean { + return aggregates.contains(this) + } + + private fun IdentifierChain.isAggregateCall(): Boolean { + return when (next) { + null -> root.symbol.lowercase().isAggregateCall() + else -> { + var curId = next + var last = curId + while (curId != null) { + last = curId + curId = curId.next + } + last!!.root.symbol.lowercase().isAggregateCall() + } + } + } + + override fun defaultReturn(node: AstNode, ctx: Context) = node + } + + private fun syntheticAgg(i: Int) = "\$agg_$i" + + private val ANY: CompilerType = CompilerType(PType.dynamic()) + private val BOOL: CompilerType = CompilerType(PType.bool()) + private val STRING: CompilerType = CompilerType(PType.string()) + private val STRUCT: CompilerType = CompilerType(PType.struct()) + private val BAG: CompilerType = CompilerType(PType.bag()) + private val LIST: CompilerType = CompilerType(PType.array()) + private val INT: CompilerType = CompilerType(PType.numeric()) +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RexConverter.kt new file mode 100644 index 000000000..1188f136c --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RexConverter.kt @@ -0,0 +1,1119 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package org.partiql.planner.internal.transforms + +import com.amazon.ionelement.api.loadSingleElement +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.AstVisitor +import org.partiql.ast.v1.DataType +import org.partiql.ast.v1.QueryBody +import org.partiql.ast.v1.SelectList +import org.partiql.ast.v1.SelectStar +import org.partiql.ast.v1.expr.Expr +import org.partiql.ast.v1.expr.ExprAnd +import org.partiql.ast.v1.expr.ExprArray +import org.partiql.ast.v1.expr.ExprBag +import org.partiql.ast.v1.expr.ExprBetween +import org.partiql.ast.v1.expr.ExprCall +import org.partiql.ast.v1.expr.ExprCase +import org.partiql.ast.v1.expr.ExprCast +import org.partiql.ast.v1.expr.ExprCoalesce +import org.partiql.ast.v1.expr.ExprExtract +import org.partiql.ast.v1.expr.ExprInCollection +import org.partiql.ast.v1.expr.ExprIsType +import org.partiql.ast.v1.expr.ExprLike +import org.partiql.ast.v1.expr.ExprLit +import org.partiql.ast.v1.expr.ExprNot +import org.partiql.ast.v1.expr.ExprNullIf +import org.partiql.ast.v1.expr.ExprOperator +import org.partiql.ast.v1.expr.ExprOr +import org.partiql.ast.v1.expr.ExprOverlay +import org.partiql.ast.v1.expr.ExprPath +import org.partiql.ast.v1.expr.ExprPosition +import org.partiql.ast.v1.expr.ExprQuerySet +import org.partiql.ast.v1.expr.ExprSessionAttribute +import org.partiql.ast.v1.expr.ExprStruct +import org.partiql.ast.v1.expr.ExprSubstring +import org.partiql.ast.v1.expr.ExprTrim +import org.partiql.ast.v1.expr.ExprVarRef +import org.partiql.ast.v1.expr.ExprVariant +import org.partiql.ast.v1.expr.PathStep +import org.partiql.ast.v1.expr.Scope +import org.partiql.ast.v1.expr.TrimSpec +import org.partiql.errors.TypeCheckException +import org.partiql.planner.internal.Env +import org.partiql.planner.internal.ir.Rel +import org.partiql.planner.internal.ir.Rex +import org.partiql.planner.internal.ir.builder.plan +import org.partiql.planner.internal.ir.rel +import org.partiql.planner.internal.ir.relBinding +import org.partiql.planner.internal.ir.relOpJoin +import org.partiql.planner.internal.ir.relOpScan +import org.partiql.planner.internal.ir.relOpUnpivot +import org.partiql.planner.internal.ir.relType +import org.partiql.planner.internal.ir.rex +import org.partiql.planner.internal.ir.rexOpCallUnresolved +import org.partiql.planner.internal.ir.rexOpCastUnresolved +import org.partiql.planner.internal.ir.rexOpCoalesce +import org.partiql.planner.internal.ir.rexOpCollection +import org.partiql.planner.internal.ir.rexOpLit +import org.partiql.planner.internal.ir.rexOpNullif +import org.partiql.planner.internal.ir.rexOpPathIndex +import org.partiql.planner.internal.ir.rexOpPathKey +import org.partiql.planner.internal.ir.rexOpPathSymbol +import org.partiql.planner.internal.ir.rexOpSelect +import org.partiql.planner.internal.ir.rexOpStruct +import org.partiql.planner.internal.ir.rexOpStructField +import org.partiql.planner.internal.ir.rexOpSubquery +import org.partiql.planner.internal.ir.rexOpTupleUnion +import org.partiql.planner.internal.ir.rexOpVarLocal +import org.partiql.planner.internal.ir.rexOpVarUnresolved +import org.partiql.planner.internal.typer.CompilerType +import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType +import org.partiql.spi.catalog.Identifier +import org.partiql.types.PType +import org.partiql.value.MissingValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.StringValue +import org.partiql.value.boolValue +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.io.PartiQLValueIonReaderBuilder +import org.partiql.value.nullValue +import org.partiql.value.stringValue +import org.partiql.ast.v1.SetQuantifier as AstSetQuantifier + +/** + * Converts an AST expression node to a Plan Rex node; ignoring any typing. + */ +internal object V1RexConverter { + + internal fun apply(expr: Expr, context: Env): Rex = ToRex.visitExprCoerce(expr, context) + + internal fun applyRel(expr: Expr, context: Env): Rex = expr.accept(ToRex, context) + + @OptIn(PartiQLValueExperimental::class) + @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") + private object ToRex : AstVisitor { + + private val COLL_AGG_NAMES = setOf( + "coll_any", + "coll_avg", + "coll_count", + "coll_every", + "coll_max", + "coll_min", + "coll_some", + "coll_sum", + ) + + override fun defaultReturn(node: AstNode, context: Env): Rex = + throw IllegalArgumentException("unsupported rex $node") + + override fun visitExprLit(node: ExprLit, context: Env): Rex { + val type = CompilerType( + _delegate = node.value.type.toPType(), + isNullValue = node.value.isNull, + isMissingValue = node.value is MissingValue + ) + val op = rexOpLit(node.value) + return rex(type, op) + } + + /** + * TODO PartiQLValue will be replaced by Datum (i.e. IonDatum) is a subsequent PR. + */ + override fun visitExprVariant(node: ExprVariant, ctx: Env): Rex { + if (node.encoding != "ion") { + throw IllegalArgumentException("unsupported encoding ${node.encoding}") + } + val ion = loadSingleElement(node.value) + val value = PartiQLValueIonReaderBuilder.standard().build(ion).read() + val type = CompilerType(value.type.toPType()) + return rex(type, rexOpLit(value)) + } + + /** + * !! IMPORTANT !! + * + * This is the top-level visit for handling subquery coercion. The default behavior is to coerce to a scalar. + * In some situations, ie comparison to complex types we may make assertions on the desired type. + * + * It is recommended that every method (except for the exceptional cases) recurse the tree from visitExprCoerce. + * + * - RHS of comparison when LHS is an array or collection expression; and visa-versa + * - It is the collection expression of a FROM clause or JOIN + * - It is the RHS of an IN predicate + * - It is an argument of an OUTER set operator. + * + * @param node + * @param ctx + * @return + */ + internal fun visitExprCoerce(node: Expr, ctx: Env, coercion: Rex.Op.Subquery.Coercion = Rex.Op.Subquery.Coercion.SCALAR): Rex { + val rex = super.visitExpr(node, ctx) + return when (isSqlSelect(node)) { + true -> { + val select = rex.op as Rex.Op.Select + rex( + CompilerType(PType.dynamic()), + rexOpSubquery( + constructor = select.constructor, + rel = select.rel, + coercion = coercion + ) + ) + } + false -> rex + } + } + + override fun visitExprVarRef(node: ExprVarRef, context: Env): Rex { + val type = (ANY) + val identifier = V1AstToPlan.convert(node.identifierChain) + val scope = when (node.scope.code()) { + Scope.DEFAULT -> Rex.Op.Var.Scope.DEFAULT + Scope.LOCAL -> Rex.Op.Var.Scope.LOCAL + else -> TODO() // TODO ALAN + } + val op = rexOpVarUnresolved(identifier, scope) + return rex(type, op) + } + + private fun resolveUnaryOp(symbol: String, rhs: Expr, context: Env): Rex { + val type = (ANY) + // Args + val arg = visitExprCoerce(rhs, context) + val args = listOf(arg) + // Fn + val name = when (symbol) { + // TODO move hard-coded operator resolution into SPI + "+" -> "pos" + "-" -> "neg" + else -> error("unsupported unary op $symbol") + } + val id = Identifier.delimited(name) + val op = rexOpCallUnresolved(id, args) + return rex(type, op) + } + + private fun resolveBinaryOp(lhs: Expr, symbol: String, rhs: Expr, context: Env): Rex { + val type = (ANY) + val args = when (symbol) { + "<", ">", + "<=", ">=", + "=", "<>", "!=" -> { + when { + // Example: [1, 2] < (SELECT a, b FROM t) + isLiteralArray(lhs) && isSqlSelect(rhs) -> { + val l = visitExprCoerce(lhs, context) + val r = visitExprCoerce(rhs, context, Rex.Op.Subquery.Coercion.ROW) + listOf(l, r) + } + // Example: (SELECT a, b FROM t) < [1, 2] + isSqlSelect(lhs) && isLiteralArray(rhs) -> { + val l = visitExprCoerce(lhs, context, Rex.Op.Subquery.Coercion.ROW) + val r = visitExprCoerce(rhs, context) + listOf(l, r) + } + // Example: 1 < 2 + else -> { + val l = visitExprCoerce(lhs, context) + val r = visitExprCoerce(rhs, context) + listOf(l, r) + } + } + } + // Example: 1 + 2 + else -> { + val l = visitExprCoerce(lhs, context) + val r = visitExprCoerce(rhs, context) + listOf(l, r) + } + } + // Wrap if a NOT, if necessary + return when (symbol) { + "<>", "!=" -> { + val op = negate(call("eq", *args.toTypedArray())) + rex(type, op) + } + else -> { + val name = when (symbol) { + // TODO eventually move hard-coded operator resolution into SPI + "<" -> "lt" + ">" -> "gt" + "<=" -> "lte" + ">=" -> "gte" + "=" -> "eq" + "||" -> "concat" + "+" -> "plus" + "-" -> "minus" + "*" -> "times" + "/" -> "divide" + "%" -> "modulo" + "&" -> "bitwise_and" + else -> error("unsupported binary op $symbol") + } + val id = Identifier.delimited(name) + val op = rexOpCallUnresolved(id, args) + rex(type, op) + } + } + } + + override fun visitExprOperator(node: ExprOperator, ctx: Env): Rex { + val lhs = node.lhs + return if (lhs != null) { + resolveBinaryOp(lhs, node.symbol, node.rhs, ctx) + } else { + resolveUnaryOp(node.symbol, node.rhs, ctx) + } + } + + override fun visitExprNot(node: ExprNot, ctx: Env): Rex { + val type = (ANY) + // Args + val arg = visitExprCoerce(node.value, ctx) + val args = listOf(arg) + // Fn + val id = Identifier.delimited("not") + val op = rexOpCallUnresolved(id, args) + return rex(type, op) + } + + override fun visitExprAnd(node: ExprAnd, ctx: Env): Rex { + val type = (ANY) + val l = visitExprCoerce(node.lhs, ctx) + val r = visitExprCoerce(node.rhs, ctx) + val args = listOf(l, r) + + // Wrap if a NOT, if necessary + val id = Identifier.delimited("and") + val op = rexOpCallUnresolved(id, args) + return rex(type, op) + } + + override fun visitExprOr(node: ExprOr, ctx: Env): Rex { + val type = (ANY) + val l = visitExprCoerce(node.lhs, ctx) + val r = visitExprCoerce(node.rhs, ctx) + val args = listOf(l, r) + + // Wrap if a NOT, if necessary + val id = Identifier.delimited("or") + val op = rexOpCallUnresolved(id, args) + return rex(type, op) + } + + private fun isLiteralArray(node: Expr): Boolean = node is ExprArray + + private fun isSqlSelect(node: Expr): Boolean { + return if (node is ExprQuerySet) { + val body = node.body + body is QueryBody.SFW && (body.select is SelectList || body.select is SelectStar) + } else { + false + } + } + + override fun visitExprPath(node: ExprPath, context: Env): Rex { + // Args + val root = visitExprCoerce(node.root, context) + + // Attempt to create qualified identifier + val (newRoot, newSteps) = when (val op = root.op) { + is Rex.Op.Var.Unresolved -> { + // convert consecutive symbol path steps to the root identifier + var i = 0 + val parts = mutableListOf() + parts.addAll(op.identifier.getParts()) + val newSteps = mutableListOf() + var curStep = node.next + while (curStep != null) { + if (curStep !is PathStep.Field) { + break + } + parts.add(V1AstToPlan.part(curStep.field)) + newSteps.add(curStep) + i += 1 + curStep = curStep.next + } + val newRoot = rex(ANY, rexOpVarUnresolved(Identifier.of(parts), op.scope)) + newRoot to newSteps + } + else -> { + val allSteps = mutableListOf() + var curStep = node.next + while (curStep != null) { + allSteps.add(curStep) + curStep = curStep.next + } + root to allSteps + } + } + + if (newSteps.isEmpty()) { + return newRoot + } + + val fromList = mutableListOf() + + var varRefIndex = 0 // tracking var ref index + + val pathNavi = newSteps.fold(newRoot) { current, step -> + val path = when (step) { + is PathStep.Element -> { + val key = visitExprCoerce(step.element, context) + val op = when (val astKey = step.element) { + is ExprLit -> when (astKey.value) { + is StringValue -> rexOpPathKey(current, key) + else -> rexOpPathIndex(current, key) + } + is ExprCast -> when (astKey.asType.code() == DataType.STRING) { + true -> rexOpPathKey(current, key) + false -> rexOpPathIndex(current, key) + } + else -> rexOpPathIndex(current, key) + } + op + } + + is PathStep.Field -> { + when (step.field.isDelimited) { + true -> { + // case-sensitive path step becomes a key lookup + rexOpPathKey(current, rexString(step.field.symbol)) + } + false -> { + // case-insensitive path step becomes a symbol lookup + rexOpPathSymbol(current, step.field.symbol) + } + } + } + + // Unpivot and Wildcard steps trigger the rewrite + // According to spec Section 4.3 + // ew1p1...wnpn + // rewrite to: + // SELECT VALUE v_n.p_n + // FROM + // u_1 e as v_1 + // u_2 @v_1.p_1 as v_2 + // ... + // u_n @v_(n-1).p_(n-1) as v_n + // The From clause needs to be rewritten to + // Join <------------------- schema: [(k_1), v_1, (k_2), v_2, ..., (k_(n-1)) v_(n-1)] + // / \ + // ... un @v_(n-1).p_(n-1) <-- stack: [global, typeEnv: [outer: [global], schema: [(k_1), v_1, (k_2), v_2, ..., (k_(n-1)) v_(n-1)]]] + // Join <----------------------- schema: [(k_1), v_1, (k_2), v_2, (k_3), v_3] + // / \ + // u_2 @v_1.p_1 as v2 <------- stack: [global, typeEnv: [outer: [global], schema: [(k_1), v_1, (k_2), v_2]]] + // JOIN <---------------------------- schema: [(k_1), v_1, (k_2), v_2] + // / \ + // u_1 e as v_1 < ----\----------------------- stack: [global] + // u_2 @v_1.p_1 as v2 <------ stack: [global, typeEnv: [outer: [global], schema: [(k_1), v_1]]] + // while doing the traversal, instead of passing the stack, + // each join will produce its own schema and pass the schema as a type Env. + // The (k_i) indicate the possible key binding produced by unpivot. + // We calculate the var ref on the fly. + is PathStep.AllFields -> { + // Unpivot produces two binding, in this context we want the value, + // which always going to be the second binding + val op = rexOpVarLocal(1, varRefIndex + 1) + varRefIndex += 2 + val index = fromList.size + fromList.add(relFromUnpivot(current, index)) + op + } + is PathStep.AllElements -> { + // Scan produce only one binding + val op = rexOpVarLocal(1, varRefIndex) + varRefIndex += 1 + val index = fromList.size + fromList.add(relFromDefault(current, index)) + op + } + else -> TODO() // TODO ALAN + } + rex(ANY, path) + } + + if (fromList.size == 0) return pathNavi + val fromNode = fromList.reduce { acc, scan -> + val schema = acc.type.schema + scan.type.schema + val props = emptySet() + val type = relType(schema, props) + rel(type, relOpJoin(acc, scan, rex(BOOL, rexOpLit(boolValue(true))), Rel.Op.Join.Type.INNER)) + } + + // compute the ref used by select construct + // always going to be the last binding + val selectRef = fromNode.type.schema.size - 1 + + val constructor = when (val op = pathNavi.op) { + is Rex.Op.Path.Index -> rex(pathNavi.type, rexOpPathIndex(rex(op.root.type, rexOpVarLocal(0, selectRef)), op.key)) + is Rex.Op.Path.Key -> rex(pathNavi.type, rexOpPathKey(rex(op.root.type, rexOpVarLocal(0, selectRef)), op.key)) + is Rex.Op.Path.Symbol -> rex(pathNavi.type, rexOpPathSymbol(rex(op.root.type, rexOpVarLocal(0, selectRef)), op.key)) + is Rex.Op.Var.Local -> rex(pathNavi.type, rexOpVarLocal(0, selectRef)) + else -> throw IllegalStateException() + } + val op = rexOpSelect(constructor, fromNode) + return rex(ANY, op) + } + + /** + * Construct Rel(Scan([path])). + * + * The constructed rel would produce one binding: _v$[index] + */ + private fun relFromDefault(path: Rex, index: Int): Rel { + val schema = listOf( + relBinding( + name = "_v$index", // fresh variable + type = path.type + ) + ) + val props = emptySet() + val relType = relType(schema, props) + return rel(relType, relOpScan(path)) + } + + /** + * Construct Rel(Unpivot([path])). + * + * The constructed rel would produce two bindings: _k$[index] and _v$[index] + */ + private fun relFromUnpivot(path: Rex, index: Int): Rel { + val schema = listOf( + relBinding( + name = "_k$index", // fresh variable + type = STRING + ), + relBinding( + name = "_v$index", // fresh variable + type = path.type + ) + ) + val props = emptySet() + val relType = relType(schema, props) + return rel(relType, relOpUnpivot(path)) + } + + private fun rexString(str: String) = rex(STRING, rexOpLit(stringValue(str))) + + override fun visitExprCall(node: ExprCall, context: Env): Rex { + val type = (ANY) + // Fn + val id = V1AstToPlan.convert(node.function) + if (id.hasQualifier()) { + error("Qualified function calls are not currently supported.") + } + if (id.matches("TUPLEUNION")) { + return visitExprCallTupleUnion(node, context) + } + if (id.matches("EXISTS", ignoreCase = true)) { + return visitExprCallExists(node, context) + } + // Args + val args = node.args.map { visitExprCoerce(it, context) } + + // Check if function is actually coll_ + if (isCollAgg(node)) { + return callToCollAgg(id, node.setq, args) + } + + if (node.setq != null) { + error("Currently, only COLL_ may use set quantifiers.") + } + val op = rexOpCallUnresolved(id, args) + return rex(type, op) + } + + /** + * @return whether call is `COLL_`. + */ + private fun isCollAgg(node: ExprCall): Boolean { + val fn = node.function + val id = if (fn.next == null) { + // is not a qualified identifier chain + node.function.root + } else { + return false + } + return COLL_AGG_NAMES.contains(id.symbol.lowercase()) + } + + /** + * Converts COLL_ to the relevant function calls. For example: + * - `COLL_SUM(x)` becomes `coll_sum_all(x)` + * - `COLL_SUM(ALL x)` becomes `coll_sum_all(x)` + * - `COLL_SUM(DISTINCT x)` becomes `coll_sum_distinct(x)` + * + * It is assumed that the [id] has already been vetted by [isCollAgg]. + */ + private fun callToCollAgg(id: Identifier, setQuantifier: AstSetQuantifier?, args: List): Rex { + if (id.hasQualifier()) { + error("Qualified function calls are not currently supported.") + } + if (args.size != 1) { + error("Aggregate calls currently only support single arguments. Received ${args.size} arguments.") + } + val postfix = when (setQuantifier?.code()) { + AstSetQuantifier.DISTINCT -> "_distinct" + AstSetQuantifier.ALL -> "_all" + null -> "_all" + else -> TODO() // TODO ALAN + } + val newId = Identifier.regular(id.getIdentifier().getText() + postfix) + val op = Rex.Op.Call.Unresolved(newId, listOf(args[0])) + return Rex(ANY, op) + } + + private fun visitExprCallTupleUnion(node: ExprCall, context: Env): Rex { + val type = (STRUCT) + val args = node.args.map { visitExprCoerce(it, context) }.toMutableList() + val op = rexOpTupleUnion(args) + return rex(type, op) + } + + /** + * Assume that the node's identifier refers to EXISTS. + * TODO: This could be better suited as a dedicated node in the future. + */ + private fun visitExprCallExists(node: ExprCall, context: Env): Rex { + val type = (BOOL) + if (node.args.size != 1) { + error("EXISTS requires a single argument.") + } + val arg = visitExpr(node.args[0], context) + val op = rexOpCallUnresolved(V1AstToPlan.convert(node.function), listOf(arg)) + return rex(type, op) + } + + override fun visitExprCase(node: ExprCase, context: Env) = plan { + val type = (ANY) + val rex = when (node.expr) { + null -> null + else -> visitExprCoerce(node.expr!!, context) // match `rex + } + + // Converts AST CASE (x) WHEN y THEN z --> Plan CASE WHEN x = y THEN z + val id = Identifier.delimited("eq") + val createBranch: (Rex, Rex) -> Rex.Op.Case.Branch = { condition: Rex, result: Rex -> + val updatedCondition = when (rex) { + null -> condition + else -> rex(type, rexOpCallUnresolved(id, listOf(rex, condition))) + } + rexOpCaseBranch(updatedCondition, result) + } + + val branches = node.branches.map { + val branchCondition = visitExprCoerce(it.condition, context) + val branchRex = visitExprCoerce(it.expr, context) + createBranch(branchCondition, branchRex) + }.toMutableList() + + val defaultRex = when (val default = node.defaultExpr) { + null -> rex(type = ANY, op = rexOpLit(value = nullValue())) + else -> visitExprCoerce(default, context) + } + val op = rexOpCase(branches = branches, default = defaultRex) + rex(type, op) + } + + override fun visitExprArray(node: ExprArray, ctx: Env): Rex { + val values = node.values.map { visitExprCoerce(it, ctx) } + val op = rexOpCollection(values) + return rex(LIST, op) + } + + override fun visitExprBag(node: ExprBag, ctx: Env): Rex { + val values = node.values.map { visitExprCoerce(it, ctx) } + val op = rexOpCollection(values) + return rex(BAG, op) + } + + override fun visitExprStruct(node: ExprStruct, context: Env): Rex { + val type = (STRUCT) + val fields = node.fields.map { + val k = visitExprCoerce(it.name, context) + val v = visitExprCoerce(it.value, context) + rexOpStructField(k, v) + } + val op = rexOpStruct(fields) + return rex(type, op) + } + + // SPECIAL FORMS + + /** + * NOT? LIKE ( ESCAPE )? + */ + override fun visitExprLike(node: ExprLike, ctx: Env): Rex { + val type = BOOL + // Args + val arg0 = visitExprCoerce(node.value, ctx) + val arg1 = visitExprCoerce(node.pattern, ctx) + val arg2 = node.escape?.let { visitExprCoerce(it, ctx) } + // Call Variants + var call = when (arg2) { + null -> call("like", arg0, arg1) + else -> call("like_escape", arg0, arg1, arg2) + } + // NOT? + if (node.not == true) { + call = negate(call) + } + return rex(type, call) + } + + /** + * NOT? BETWEEN AND + */ + override fun visitExprBetween(node: ExprBetween, ctx: Env): Rex = plan { + val type = BOOL + // Args + val arg0 = visitExprCoerce(node.value, ctx) + val arg1 = visitExprCoerce(node.from, ctx) + val arg2 = visitExprCoerce(node.to, ctx) + // Call + var call = call("between", arg0, arg1, arg2) + // NOT? + if (node.not == true) { + call = negate(call) + } + rex(type, call) + } + + /** + * NOT? IN + * + * SQL Spec 1999 section 8.4 + * RVC IN IPV is equivalent to RVC = ANY IPV -> Quantified Comparison Predicate + * Which means: + * Let the expression be T in C, where C is [a1, ..., an] + * T in C is true iff T = a_x is true for any a_x in [a1, ...., an] + * T in C is false iff T = a_x is false for every a_x in [a1, ....., an ] or cardinality of the collection is 0. + * Otherwise, T in C is unknown. + * + */ + override fun visitExprInCollection(node: ExprInCollection, ctx: Env): Rex { + val type = BOOL + // Args + val arg0 = visitExprCoerce(node.lhs, ctx) + val arg1 = visitExpr(node.rhs, ctx) // !! don't insert scalar subquery coercions + + // Call + var call = call("in_collection", arg0, arg1) + // NOT? + if (node.not == true) { + call = negate(call) + } + return rex(type, call) + } + + /** + * IS ? + */ + override fun visitExprIsType(node: ExprIsType, ctx: Env): Rex { + val type = BOOL + // arg + val arg0 = visitExprCoerce(node.value, ctx) + val targetType = node.type + var call = when (targetType.code()) { + // + DataType.NULL -> call("is_null", arg0) + DataType.MISSING -> call("is_missing", arg0) + // + // TODO CHAR_VARYING, CHARACTER_LARGE_OBJECT, CHAR_LARGE_OBJECT + DataType.CHARACTER, DataType.CHAR -> call("is_char", targetType.length.toRex(), arg0) + DataType.CHARACTER_VARYING, DataType.VARCHAR -> call("is_varchar", targetType.length.toRex(), arg0) + DataType.CLOB -> call("is_clob", arg0) + DataType.STRING -> call("is_string", targetType.length.toRex(), arg0) + DataType.SYMBOL -> call("is_symbol", arg0) + // + // TODO BINARY_LARGE_OBJECT + DataType.BLOB -> call("is_blob", arg0) + // + DataType.BIT -> call("is_bit", arg0) // TODO define in parser + DataType.BIT_VARYING -> call("is_bitVarying", arg0) // TODO define in parser + // - + DataType.NUMERIC -> call("is_numeric", targetType.precision.toRex(), targetType.scale.toRex(), arg0) + DataType.DEC, DataType.DECIMAL -> call("is_decimal", targetType.precision.toRex(), targetType.scale.toRex(), arg0) + DataType.BIGINT, DataType.INT8, DataType.INTEGER8 -> call("is_int64", arg0) + DataType.INT4, DataType.INTEGER4, DataType.INTEGER -> call("is_int32", arg0) + DataType.INT -> call("is_int", arg0) // TODO ALAN figure out if INT should map to INT4 + DataType.INT2, DataType.SMALLINT -> call("is_int16", arg0) + DataType.TINYINT -> call("is_int8", arg0) // TODO define in parser + // - + DataType.FLOAT -> call("is_float32", arg0) + DataType.REAL -> call("is_real", arg0) + DataType.DOUBLE_PRECISION -> call("is_float64", arg0) + // + DataType.BOOLEAN, DataType.BOOL -> call("is_bool", arg0) + // + DataType.DATE -> call("is_date", arg0) + // TODO: DO we want to seperate with time zone vs without time zone into two different type in the plan? + // leave the parameterized type out for now until the above is answered + DataType.TIME -> call("is_time", arg0) + DataType.TIME_WITH_TIME_ZONE -> call("is_timeWithTz", arg0) + DataType.TIMESTAMP -> call("is_timestamp", arg0) + DataType.TIMESTAMP_WITH_TIME_ZONE -> call("is_timestampWithTz", arg0) + // + DataType.INTERVAL -> call("is_interval", arg0) // TODO define in parser + // + DataType.STRUCT, DataType.TUPLE -> call("is_struct", arg0) + // + DataType.LIST -> call("is_list", arg0) + DataType.BAG -> call("is_bag", arg0) + DataType.SEXP -> call("is_sexp", arg0) + // + DataType.USER_DEFINED -> call("is_custom", arg0) + // TODO ALAN other types that are in functions but aren't in AST yet +// DataType.BYTE_STRING -> call("is_byteString", arg0) +// DataType.ANY -> call("is_any", arg0) + else -> TODO() // TODO ALAN + } + + if (node.not == true) { + call = negate(call) + } + + return rex(type, call) + } + + override fun visitExprCoalesce(node: ExprCoalesce, ctx: Env): Rex { + val type = ANY + val args = node.args.map { arg -> + visitExprCoerce(arg, ctx) + } + val op = rexOpCoalesce(args) + return rex(type, op) + } + + override fun visitExprNullIf(node: ExprNullIf, ctx: Env): Rex { + val type = ANY + val v1 = visitExprCoerce(node.v1, ctx) + val v2 = visitExprCoerce(node.v2, ctx) + val op = rexOpNullif(v1, v2) + return rex(type, op) + } + + /** + * SUBSTRING( (FROM (FOR )?)? ) + */ + override fun visitExprSubstring(node: ExprSubstring, ctx: Env): Rex { + val type = ANY + // Args + val arg0 = visitExprCoerce(node.value, ctx) + val arg1 = node.start?.let { visitExprCoerce(it, ctx) } ?: rex(INT, rexOpLit(int64Value(1))) + val arg2 = node.length?.let { visitExprCoerce(it, ctx) } + // Call Variants + val call = when (arg2) { + null -> call("substring", arg0, arg1) + else -> call("substring", arg0, arg1, arg2) + } + return rex(type, call) + } + + /** + * POSITION( IN ) + */ + override fun visitExprPosition(node: ExprPosition, ctx: Env): Rex { + val type = ANY + // Args + val arg0 = visitExprCoerce(node.lhs, ctx) + val arg1 = visitExprCoerce(node.rhs, ctx) + // Call + val call = call("position", arg0, arg1) + return rex(type, call) + } + + /** + * TRIM([LEADING|TRAILING|BOTH]? ( FROM)? ) + */ + override fun visitExprTrim(node: ExprTrim, ctx: Env): Rex { + val type = STRING + // Args + val arg0 = visitExprCoerce(node.value, ctx) + val arg1 = node.chars?.let { visitExprCoerce(it, ctx) } + // Call Variants + val call = when (node.trimSpec?.code()) { + TrimSpec.LEADING -> when (arg1) { + null -> call("trim_leading", arg0) + else -> call("trim_leading_chars", arg0, arg1) + } + TrimSpec.TRAILING -> when (arg1) { + null -> call("trim_trailing", arg0) + else -> call("trim_trailing_chars", arg0, arg1) + } + // TODO: We may want to add a trim_both for trim(BOTH FROM arg) + else -> when (arg1) { + null -> call("trim", arg0) + else -> call("trim_chars", arg0, arg1) + } + } + return rex(type, call) + } + + /** + * SQL Spec 1999: Section 6.18 + * + * ::= + * OVERLAY + * PLACING + * FROM + * [ FOR ] + * + * The is equivalent to: + * + * SUBSTRING ( CV FROM 1 FOR SP - 1 ) || RS || SUBSTRING ( CV FROM SP + SL ) + * + * Where CV is the first , + * SP is the + * RS is the second , + * SL is the if specified, otherwise it is char_length(RS). + */ + override fun visitExprOverlay(node: ExprOverlay, ctx: Env): Rex { + val cv = visitExprCoerce(node.value, ctx) + val sp = visitExprCoerce(node.placing, ctx) + val rs = visitExprCoerce(node.from, ctx) + val sl = node.forLength?.let { visitExprCoerce(it, ctx) } ?: rex(ANY, call("char_length", rs)) + val p1 = rex( + ANY, + call( + "substring", + cv, + rex(INT4, rexOpLit(int32Value(1))), + rex(ANY, call("minus", sp, rex(INT4, rexOpLit(int32Value(1))))) + ) + ) + val p2 = rex(ANY, call("concat", p1, rs)) + return rex( + ANY, + call( + "concat", + p2, + rex(ANY, call("substring", cv, rex(ANY, call("plus", sp, sl)))) + ) + ) + } + + override fun visitExprExtract(node: ExprExtract, ctx: Env): Rex { + val call = call("extract_${node.field.name().lowercase()}", visitExprCoerce(node.source, ctx)) + return rex(ANY, call) + } + + override fun visitExprCast(node: ExprCast, ctx: Env): Rex { + val type = visitType(node.asType) + val arg = visitExprCoerce(node.value, ctx) + return rex(ANY, rexOpCastUnresolved(type, arg)) + } + + private fun visitType(type: DataType): CompilerType { + return when (type.code()) { + // + DataType.NULL -> error("Casting to NULL is not supported.") + DataType.MISSING -> error("Casting to MISSING is not supported.") + // + // TODO CHAR_VARYING, CHARACTER_LARGE_OBJECT, CHAR_LARGE_OBJECT + DataType.CHARACTER, DataType.CHAR -> { + val length = type.length ?: 1 + assertGtZeroAndCreate(PType.Kind.CHAR, "length", length, PType::character) + } + DataType.CHARACTER_VARYING, DataType.VARCHAR -> { + val length = type.length ?: 1 + assertGtZeroAndCreate(PType.Kind.VARCHAR, "length", length, PType::varchar) + } + DataType.CLOB -> assertGtZeroAndCreate(PType.Kind.CLOB, "length", type.length ?: Int.MAX_VALUE, PType::clob) + DataType.STRING -> PType.string() + DataType.SYMBOL -> PType.symbol() + // + // TODO BINARY_LARGE_OBJECT + DataType.BLOB -> assertGtZeroAndCreate(PType.Kind.BLOB, "length", type.length ?: Int.MAX_VALUE, PType::blob) + // + DataType.BIT -> error("BIT is not supported yet.") + DataType.BIT_VARYING -> error("BIT VARYING is not supported yet.") + // - + DataType.NUMERIC -> { + val p = type.precision + val s = type.scale + when { + p == null && s == null -> PType.decimal() + p != null && s != null -> { + assertParamCompToZero(PType.Kind.NUMERIC, "precision", p, false) + assertParamCompToZero(PType.Kind.NUMERIC, "scale", s, true) + if (s > p) { + throw TypeCheckException("Numeric scale cannot be greater than precision.") + } + PType.decimal(type.precision!!, type.scale!!) + } + p != null && s == null -> { + assertParamCompToZero(PType.Kind.NUMERIC, "precision", p, false) + PType.decimal(p, 0) + } + else -> error("Precision can never be null while scale is specified.") + } + } + DataType.DEC, DataType.DECIMAL -> { + val p = type.precision + val s = type.scale + when { + p == null && s == null -> PType.decimal() + p != null && s != null -> { + assertParamCompToZero(PType.Kind.DECIMAL, "precision", p, false) + assertParamCompToZero(PType.Kind.DECIMAL, "scale", s, true) + if (s > p) { + throw TypeCheckException("Decimal scale cannot be greater than precision.") + } + PType.decimal(p, s) + } + p != null && s == null -> { + assertParamCompToZero(PType.Kind.DECIMAL, "precision", p, false) + PType.decimal(p, 0) + } + else -> error("Precision can never be null while scale is specified.") + } + } + DataType.BIGINT, DataType.INT8, DataType.INTEGER8 -> PType.bigint() + DataType.INT4, DataType.INTEGER4, DataType.INTEGER -> PType.integer() + DataType.INT -> PType.numeric() // TODO ALAN figure out if INT should map to INT4 + DataType.INT2, DataType.SMALLINT -> PType.smallint() + DataType.TINYINT -> PType.tinyint() // TODO define in parser + // - + DataType.FLOAT -> PType.real() + DataType.REAL -> PType.real() + DataType.DOUBLE_PRECISION -> PType.doublePrecision() + // + DataType.BOOL -> PType.bool() + // + DataType.DATE -> PType.date() + DataType.TIME -> assertGtEqZeroAndCreate(PType.Kind.TIME, "precision", type.precision ?: 0, PType::time) + DataType.TIME_WITH_TIME_ZONE -> assertGtEqZeroAndCreate(PType.Kind.TIMEZ, "precision", type.precision ?: 0, PType::timez) + DataType.TIMESTAMP -> assertGtEqZeroAndCreate(PType.Kind.TIMESTAMP, "precision", type.precision ?: 6, PType::timestamp) + DataType.TIMESTAMP_WITH_TIME_ZONE -> assertGtEqZeroAndCreate(PType.Kind.TIMESTAMPZ, "precision", type.precision ?: 6, PType::timestampz) + // + DataType.INTERVAL -> error("INTERVAL is not supported yet.") + // + DataType.STRUCT -> PType.struct() + DataType.TUPLE -> PType.struct() + // + DataType.LIST -> PType.array() + DataType.BAG -> PType.bag() + DataType.SEXP -> PType.sexp() + // + DataType.USER_DEFINED -> TODO("Custom type not supported ") + // TODO ALAN other types that are in previous RexConverter but not in AST yet +// DataType.ByteString -> error("BINARY is not supported yet.") +// DataType.Any -> PType.dynamic() + else -> TODO() // TODO ALAN + }.toCType() + } + + private fun assertGtZeroAndCreate(type: PType.Kind, param: String, value: Int, create: (Int) -> PType): PType { + assertParamCompToZero(type, param, value, false) + return create.invoke(value) + } + + private fun assertGtEqZeroAndCreate(type: PType.Kind, param: String, value: Int, create: (Int) -> PType): PType { + assertParamCompToZero(type, param, value, true) + return create.invoke(value) + } + + /** + * @param allowZero when FALSE, this asserts that [value] > 0. If TRUE, this asserts that [value] >= 0. + */ + private fun assertParamCompToZero(type: PType.Kind, param: String, value: Int, allowZero: Boolean) { + val (result, compString) = when (allowZero) { + true -> (value >= 0) to "greater than" + false -> (value > 0) to "greater than or equal to" + } + if (!result) { + throw TypeCheckException("$type $param must be an integer value $compString 0.") + } + } + + // TODO ALAN add add the same validation here or in parser. Ensure the DATE_ADD/DATE_DIFF fn names map to the + // same header functions. +// override fun visitExprDateAdd(node: Expr.DateAdd, ctx: Env): Rex { +// val type = TIMESTAMP +// // Args +// val arg0 = visitExprCoerce(node.lhs, ctx) +// val arg1 = visitExprCoerce(node.rhs, ctx) +// // Call Variants +// val call = when (node.field) { +// DatetimeField.TIMEZONE_HOUR -> error("Invalid call DATE_ADD(TIMEZONE_HOUR, ...)") +// DatetimeField.TIMEZONE_MINUTE -> error("Invalid call DATE_ADD(TIMEZONE_MINUTE, ...)") +// else -> call("date_add_${node.field.name.lowercase()}", arg0, arg1) +// } +// return rex(type, call) +// } +// +// override fun visitExprDateDiff(node: Expr.DateDiff, ctx: Env): Rex { +// val type = TIMESTAMP +// // Args +// val arg0 = visitExprCoerce(node.lhs, ctx) +// val arg1 = visitExprCoerce(node.rhs, ctx) +// // Call Variants +// val call = when (node.field) { +// DatetimeField.TIMEZONE_HOUR -> error("Invalid call DATE_DIFF(TIMEZONE_HOUR, ...)") +// DatetimeField.TIMEZONE_MINUTE -> error("Invalid call DATE_DIFF(TIMEZONE_MINUTE, ...)") +// else -> call("date_diff_${node.field.name.lowercase()}", arg0, arg1) +// } +// return rex(type, call) +// } + + override fun visitExprSessionAttribute(node: ExprSessionAttribute, ctx: Env): Rex { + val type = ANY + val fn = node.sessionAttribute.name().lowercase() + val call = call(fn) + return rex(type, call) + } + + override fun visitExprQuerySet(node: ExprQuerySet, context: Env): Rex = V1RelConverter.apply(node, context) + + // Helpers + + private fun negate(call: Rex.Op): Rex.Op.Call { + val id = Identifier.delimited("not") + val arg = rex(BOOL, call) + return rexOpCallUnresolved(id, listOf(arg)) + } + + /** + * Create a [Rex.Op.Call.Static] node which has a hidden unresolved Function. + * The purpose of having such hidden function is to prevent usage of generated function name in query text. + */ + private fun call(name: String, vararg args: Rex): Rex.Op.Call { + val id = Identifier.regular(name) + return rexOpCallUnresolved(id, args.toList()) + } + + private fun Int?.toRex() = rex(INT4, rexOpLit(int32Value(this))) + + private val ANY: CompilerType = CompilerType(PType.dynamic()) + private val BOOL: CompilerType = CompilerType(PType.bool()) + private val STRING: CompilerType = CompilerType(PType.string()) + private val STRUCT: CompilerType = CompilerType(PType.struct()) + private val BAG: CompilerType = CompilerType(PType.bag()) + private val LIST: CompilerType = CompilerType(PType.array()) + private val SEXP: CompilerType = CompilerType(PType.sexp()) + private val INT: CompilerType = CompilerType(PType.numeric()) + private val INT4: CompilerType = CompilerType(PType.integer()) + private val TIMESTAMP: CompilerType = CompilerType(PType.timestamp(6)) + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1SubstitutionVisitor.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1SubstitutionVisitor.kt new file mode 100644 index 000000000..1f11a8278 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1SubstitutionVisitor.kt @@ -0,0 +1,17 @@ +package org.partiql.planner.internal.transforms + +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.AstVisitor +import org.partiql.ast.v1.expr.Expr + +internal object V1SubstitutionVisitor : AstVisitor> { + override fun defaultReturn(node: AstNode, ctx: Map<*, AstNode>) = node + + override fun visitExpr(node: Expr, ctx: Map<*, AstNode>): AstNode { + val visited = super.visitExpr(node, ctx) + if (ctx.containsKey(visited)) { + return ctx[visited]!! + } + return visited + } +} diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt index e29a0e493..f8ca7fbe8 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt @@ -5,7 +5,7 @@ import org.junit.jupiter.api.DynamicContainer.dynamicContainer import org.junit.jupiter.api.DynamicNode import org.junit.jupiter.api.DynamicTest import org.junit.jupiter.api.TestFactory -import org.partiql.parser.PartiQLParser +import org.partiql.parser.V1PartiQLParser import org.partiql.plan.Plan import org.partiql.planner.internal.TestCatalog import org.partiql.planner.test.PartiQLTest @@ -75,7 +75,7 @@ class PlanTest { ) .namespace("SCHEMA") .build() - val ast = PartiQLParser.standard().parse(test.statement).root + val ast = V1PartiQLParser.standard().parse(test.statement).root val planner = PartiQLPlanner.builder().signal(isSignalMode).build() planner.plan(ast, session) } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerPErrorReportingTests.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerPErrorReportingTests.kt index 40ce8c89c..8842518e6 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerPErrorReportingTests.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerPErrorReportingTests.kt @@ -2,8 +2,8 @@ package org.partiql.planner import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.MethodSource -import org.partiql.ast.Statement -import org.partiql.parser.PartiQLParserBuilder +import org.partiql.ast.v1.Statement +import org.partiql.parser.V1PartiQLParserBuilder import org.partiql.plan.Operation import org.partiql.planner.internal.typer.CompilerType import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType @@ -19,7 +19,6 @@ import org.partiql.types.PType import org.partiql.types.StaticType import org.partiql.types.StructType import org.partiql.types.TupleConstraint -import java.lang.AssertionError import kotlin.test.assertEquals internal class PlannerPErrorReportingTests { @@ -42,7 +41,7 @@ internal class PlannerPErrorReportingTests { .catalogs(catalog) .build() - private val parser = PartiQLParserBuilder().build() + private val parser = V1PartiQLParserBuilder().build() private val statement: ((String) -> Statement) = { query -> parser.parse(query).root diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/exclude/SubsumptionTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/exclude/SubsumptionTest.kt index f91cff9d6..b31f3ce80 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/exclude/SubsumptionTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/exclude/SubsumptionTest.kt @@ -7,7 +7,7 @@ import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource -import org.partiql.parser.PartiQLParser +import org.partiql.parser.V1PartiQLParser import org.partiql.plan.Exclusion import org.partiql.plan.Operation import org.partiql.plan.builder.PlanFactory @@ -26,7 +26,7 @@ class SubsumptionTest { companion object { private val planner = PartiQLPlanner.standard() - private val parser = PartiQLParser.standard() + private val parser = V1PartiQLParser.standard() private val catalog = MemoryCatalog.builder().name("default").build() } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt index f91349bd9..6d26fcaf2 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt @@ -2,7 +2,7 @@ package org.partiql.planner.internal.typer import org.junit.jupiter.api.DynamicContainer import org.junit.jupiter.api.DynamicTest -import org.partiql.parser.PartiQLParser +import org.partiql.parser.V1PartiQLParser import org.partiql.plan.Operation import org.partiql.planner.PartiQLPlanner import org.partiql.planner.test.PartiQLTest @@ -38,7 +38,7 @@ abstract class PartiQLTyperTestBase { companion object { - public val parser = PartiQLParser.standard() + public val parser = V1PartiQLParser.standard() public val planner = PartiQLPlanner.standard() internal val session: ((String, Catalog) -> Session) = { catalog, metadata -> diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt index b4253da0d..a5340499b 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt @@ -12,7 +12,7 @@ import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource import org.junit.jupiter.params.provider.MethodSource -import org.partiql.parser.PartiQLParser +import org.partiql.parser.V1PartiQLParser import org.partiql.planner.PartiQLPlanner import org.partiql.planner.internal.PErrors import org.partiql.planner.internal.TestCatalog @@ -125,7 +125,7 @@ internal class PlanTyperTestsPorted { companion object { - private val parser = PartiQLParser.standard() + private val parser = V1PartiQLParser.standard() private val planner = PartiQLPlanner.builder().signal().build() private fun assertProblemExists(problem: PError) = ProblemHandler { problems, _ -> diff --git a/test/partiql-randomized-tests/src/test/kotlin/org/partiql/lang/randomized/eval/Utils.kt b/test/partiql-randomized-tests/src/test/kotlin/org/partiql/lang/randomized/eval/Utils.kt index 4ca4287f5..865c7cebc 100644 --- a/test/partiql-randomized-tests/src/test/kotlin/org/partiql/lang/randomized/eval/Utils.kt +++ b/test/partiql-randomized-tests/src/test/kotlin/org/partiql/lang/randomized/eval/Utils.kt @@ -1,7 +1,7 @@ package org.partiql.lang.randomized.eval import org.partiql.eval.compiler.PartiQLCompiler -import org.partiql.parser.PartiQLParser +import org.partiql.parser.V1PartiQLParser import org.partiql.planner.PartiQLPlanner import org.partiql.spi.catalog.Catalog import org.partiql.spi.catalog.Session @@ -22,7 +22,7 @@ fun runEvaluatorTestCase( @OptIn(PartiQLValueExperimental::class) private fun execute(query: String): PartiQLValue { - val parser = PartiQLParser.builder().build() + val parser = V1PartiQLParser.builder().build() val planner = PartiQLPlanner.builder().build() val catalog = object : Catalog { override fun getName(): String = "default" diff --git a/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt b/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt index db58f2a8f..4f6fbb19d 100644 --- a/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt +++ b/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt @@ -10,7 +10,7 @@ import com.amazon.ionelement.api.toIonValue import org.partiql.eval.Mode import org.partiql.eval.Statement import org.partiql.eval.compiler.PartiQLCompiler -import org.partiql.parser.PartiQLParser +import org.partiql.parser.V1PartiQLParser import org.partiql.plan.Operation.Query import org.partiql.planner.PartiQLPlanner import org.partiql.plugins.memory.MemoryCatalog @@ -141,7 +141,7 @@ class EvalExecutor( companion object { val compiler = PartiQLCompiler.standard() - val parser = PartiQLParser.standard() + val parser = V1PartiQLParser.standard() val planner = PartiQLPlanner.standard() // TODO REPLACE WITH DATUM COMPARATOR val comparator = PartiQLValue.comparator()