diff --git a/partiql-ast/src/main/java/org/partiql/ast/v1/DataType.java b/partiql-ast/src/main/java/org/partiql/ast/v1/DataType.java index b3ee32edd..a62790214 100644 --- a/partiql-ast/src/main/java/org/partiql/ast/v1/DataType.java +++ b/partiql-ast/src/main/java/org/partiql/ast/v1/DataType.java @@ -10,6 +10,7 @@ @EqualsAndHashCode(callSuper = false) public class DataType extends AstEnum { public static final int UNKNOWN = 0; + // TODO remove `NULL` and `MISSING` variants from DataType // public static final int NULL = 1; public static final int MISSING = 2; diff --git a/partiql-ast/src/main/java/org/partiql/ast/v1/expr/ExprIsType.java b/partiql-ast/src/main/java/org/partiql/ast/v1/expr/ExprIsType.java index 8d3744b0b..55769525b 100644 --- a/partiql-ast/src/main/java/org/partiql/ast/v1/expr/ExprIsType.java +++ b/partiql-ast/src/main/java/org/partiql/ast/v1/expr/ExprIsType.java @@ -13,6 +13,7 @@ /** * TODO docs, equals, hashcode + * TODO also support IS NULL, IS MISSING, IS UNKNOWN, IS TRUE, IS FALSE */ @Builder(builderClassName = "Builder") @EqualsAndHashCode(callSuper = false) diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/Pipeline.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/Pipeline.kt index 739d48042..90930d9db 100644 --- a/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/Pipeline.kt +++ b/partiql-cli/src/main/kotlin/org/partiql/cli/pipeline/Pipeline.kt @@ -1,11 +1,11 @@ package org.partiql.cli.pipeline -import org.partiql.ast.Statement +import org.partiql.ast.v1.Statement import org.partiql.cli.ErrorCodeString import org.partiql.eval.Mode import org.partiql.eval.compiler.PartiQLCompiler -import org.partiql.parser.PartiQLParser -import org.partiql.parser.PartiQLParserBuilder +import org.partiql.parser.V1PartiQLParser +import org.partiql.parser.V1PartiQLParserBuilder import org.partiql.plan.Plan import org.partiql.planner.PartiQLPlanner import org.partiql.spi.Context @@ -13,10 +13,9 @@ import org.partiql.spi.catalog.Session import org.partiql.spi.errors.PErrorListenerException import org.partiql.spi.value.Datum import java.io.PrintStream -import kotlin.jvm.Throws internal class Pipeline private constructor( - private val parser: PartiQLParser, + private val parser: V1PartiQLParser, private val planner: PartiQLPlanner, private val compiler: PartiQLCompiler, private val ctx: Context, @@ -81,7 +80,7 @@ internal class Pipeline private constructor( private fun create(mode: Mode, out: PrintStream, config: Config): Pipeline { val listener = config.getErrorListener(out) val ctx = Context.of(listener) - val parser = PartiQLParserBuilder().build() + val parser = V1PartiQLParserBuilder().build() val planner = PartiQLPlanner.builder().build() val compiler = PartiQLCompiler.builder().build() return Pipeline(parser, planner, compiler, ctx, mode) diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt index 158350ebc..52e850da0 100644 --- a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt @@ -9,7 +9,7 @@ import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.MethodSource import org.partiql.eval.Mode import org.partiql.eval.compiler.PartiQLCompiler -import org.partiql.parser.PartiQLParser +import org.partiql.parser.V1PartiQLParser import org.partiql.plan.Plan import org.partiql.planner.PartiQLPlanner 
import org.partiql.plugins.memory.MemoryCatalog @@ -1307,7 +1307,7 @@ class PartiQLEvaluatorTest { ) { private val compiler = PartiQLCompiler.standard() - private val parser = PartiQLParser.standard() + private val parser = V1PartiQLParser.standard() private val planner = PartiQLPlanner.standard() /** @@ -1373,7 +1373,7 @@ class PartiQLEvaluatorTest { ) { private val compiler = PartiQLCompiler.standard() - private val parser = PartiQLParser.standard() + private val parser = V1PartiQLParser.standard() private val planner = PartiQLPlanner.standard() internal fun assert() { diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/V1PartiQLParserDefault.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/V1PartiQLParserDefault.kt index 7aab14ba1..0caa5228c 100644 --- a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/V1PartiQLParserDefault.kt +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/V1PartiQLParserDefault.kt @@ -1509,6 +1509,21 @@ internal class V1PartiQLParserDefault : V1PartiQLParser { exprPathStepAllFields(null) } + override fun visitValues(ctx: GeneratedParser.ValuesContext) = translate(ctx) { + val rows = visitOrEmpty(ctx.valueRow()) + exprBag(rows) + } + + override fun visitValueRow(ctx: GeneratedParser.ValueRowContext) = translate(ctx) { + val expressions = visitOrEmpty(ctx.expr()) + exprArray(expressions) + } + + override fun visitValueList(ctx: GeneratedParser.ValueListContext) = translate(ctx) { + val expressions = visitOrEmpty(ctx.expr()) + exprArray(expressions) + } + override fun visitExprGraphMatchMany(ctx: GeneratedParser.ExprGraphMatchManyContext) = translate(ctx) { val graph = visit(ctx.exprPrimary()) as Expr val pattern = visitGpmlPatternList(ctx.gpmlPatternList()) @@ -1617,10 +1632,11 @@ internal class V1PartiQLParserDefault : V1PartiQLParser { val lhs = visitExpr(ctx.expr(0)) val rhs = visitExpr(ctx.expr(1)) // TODO change to not use PartiQLValue -- https://github.com/partiql/partiql-lang-kotlin/issues/1589 - val fieldLit = exprLit(stringValue(ctx.dt.text.uppercase())) + val fieldLit = ctx.dt.text.lowercase() + // TODO error on invalid datetime fields like TIMEZONE_HOUR and TIMEZONE_MINUTE when { - ctx.DATE_ADD() != null -> exprCall(identifierChain(identifier("DATE_ADD", false), null), listOf(fieldLit, lhs, rhs), null) - ctx.DATE_DIFF() != null -> exprCall(identifierChain(identifier("DATE_DIFF", false), null), listOf(fieldLit, lhs, rhs), null) + ctx.DATE_ADD() != null -> exprCall(identifierChain(identifier("date_add_$fieldLit", false), null), listOf(lhs, rhs), null) + ctx.DATE_DIFF() != null -> exprCall(identifierChain(identifier("date_diff_$fieldLit", false), null), listOf(lhs, rhs), null) else -> throw error(ctx, "Expected DATE_ADD or DATE_DIFF") } } diff --git a/partiql-planner/api/partiql-planner.api b/partiql-planner/api/partiql-planner.api index 1495fe28c..47f81855e 100644 --- a/partiql-planner/api/partiql-planner.api +++ b/partiql-planner/api/partiql-planner.api @@ -1,8 +1,8 @@ public abstract interface class org/partiql/planner/PartiQLPlanner { public static final field Companion Lorg/partiql/planner/PartiQLPlanner$Companion; public static fun builder ()Lorg/partiql/planner/builder/PartiQLPlannerBuilder; - public abstract fun plan (Lorg/partiql/ast/Statement;Lorg/partiql/spi/catalog/Session;)Lorg/partiql/planner/PartiQLPlanner$Result; - public abstract fun plan (Lorg/partiql/ast/Statement;Lorg/partiql/spi/catalog/Session;Lorg/partiql/spi/Context;)Lorg/partiql/planner/PartiQLPlanner$Result; + public 
abstract fun plan (Lorg/partiql/ast/v1/Statement;Lorg/partiql/spi/catalog/Session;)Lorg/partiql/planner/PartiQLPlanner$Result; + public abstract fun plan (Lorg/partiql/ast/v1/Statement;Lorg/partiql/spi/catalog/Session;Lorg/partiql/spi/Context;)Lorg/partiql/planner/PartiQLPlanner$Result; public static fun standard ()Lorg/partiql/planner/PartiQLPlanner; } @@ -12,7 +12,7 @@ public final class org/partiql/planner/PartiQLPlanner$Companion { } public final class org/partiql/planner/PartiQLPlanner$DefaultImpls { - public static fun plan (Lorg/partiql/planner/PartiQLPlanner;Lorg/partiql/ast/Statement;Lorg/partiql/spi/catalog/Session;)Lorg/partiql/planner/PartiQLPlanner$Result; + public static fun plan (Lorg/partiql/planner/PartiQLPlanner;Lorg/partiql/ast/v1/Statement;Lorg/partiql/spi/catalog/Session;)Lorg/partiql/planner/PartiQLPlanner$Result; } public final class org/partiql/planner/PartiQLPlanner$Result { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt index b393eb0b8..4acacbf4f 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt @@ -1,6 +1,6 @@ package org.partiql.planner -import org.partiql.ast.Statement +import org.partiql.ast.v1.Statement import org.partiql.plan.Plan import org.partiql.planner.builder.PartiQLPlannerBuilder import org.partiql.spi.Context diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/SqlPlanner.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/SqlPlanner.kt index 34bd00937..3bbd30250 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/SqlPlanner.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/SqlPlanner.kt @@ -1,15 +1,15 @@ package org.partiql.planner.internal -import org.partiql.ast.Statement -import org.partiql.ast.normalize.normalize +import org.partiql.ast.v1.Statement import org.partiql.plan.Operation import org.partiql.plan.Plan import org.partiql.plan.builder.PlanFactory import org.partiql.plan.rex.Rex import org.partiql.planner.PartiQLPlanner import org.partiql.planner.PartiQLPlannerPass -import org.partiql.planner.internal.transforms.AstToPlan +import org.partiql.planner.internal.normalize.normalize import org.partiql.planner.internal.transforms.PlanTransform +import org.partiql.planner.internal.transforms.V1AstToPlan import org.partiql.planner.internal.typer.PlanTyper import org.partiql.spi.Context import org.partiql.spi.catalog.Session @@ -39,7 +39,7 @@ internal class SqlPlanner( val ast = statement.normalize() // 2. AST to Rel/Rex - val root = AstToPlan.apply(ast, env) + val root = V1AstToPlan.apply(ast, env) // 3. Resolve variables val typer = PlanTyper(env, ctx) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1AstToPlan.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1AstToPlan.kt new file mode 100644 index 000000000..4f25650a3 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1AstToPlan.kt @@ -0,0 +1,75 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. 
+ * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package org.partiql.planner.internal.transforms + +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.AstVisitor +import org.partiql.ast.v1.Query +import org.partiql.ast.v1.expr.ExprQuerySet +import org.partiql.planner.internal.Env +import org.partiql.planner.internal.ir.statementQuery +import org.partiql.spi.catalog.Identifier +import org.partiql.ast.v1.Identifier as AstIdentifier +import org.partiql.ast.v1.IdentifierChain as AstIdentifierChain +import org.partiql.ast.v1.Statement as AstStatement +import org.partiql.planner.internal.ir.Statement as PlanStatement + +/** + * Simple translation from AST to an unresolved algebraic IR. + */ +internal object V1AstToPlan { + + // statement.toPlan() + @JvmStatic + fun apply(statement: AstStatement, env: Env): PlanStatement = statement.accept(ToPlanStatement, env) + + @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") + private object ToPlanStatement : AstVisitor() { + + override fun defaultReturn(node: AstNode, env: Env) = throw IllegalArgumentException("Unsupported statement") + + override fun visitQuery(node: Query, env: Env): PlanStatement { + val rex = when (val expr = node.expr) { + is ExprQuerySet -> V1RelConverter.apply(expr, env) + else -> V1RexConverter.apply(expr, env) + } + return statementQuery(rex) + } + } + + // --- Helpers -------------------- + + fun convert(identifier: AstIdentifierChain): Identifier { + val parts = mutableListOf() + parts.add(part(identifier.root)) + var curStep = identifier.next + while (curStep != null) { + parts.add(part(curStep.root)) + curStep = curStep.next + } + return Identifier.of(parts) + } + + fun convert(identifier: AstIdentifier): Identifier { + return Identifier.of(part(identifier)) + } + + fun part(identifier: AstIdentifier): Identifier.Part = when (identifier.isDelimited) { + true -> Identifier.Part.delimited(identifier.symbol) + false -> Identifier.Part.regular(identifier.symbol) + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RelConverter.kt new file mode 100644 index 000000000..00010c296 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RelConverter.kt @@ -0,0 +1,755 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ * + */ + +package org.partiql.planner.internal.transforms + +import org.partiql.ast.v1.Ast.exprLit +import org.partiql.ast.v1.Ast.exprVarRef +import org.partiql.ast.v1.Ast.identifier +import org.partiql.ast.v1.Ast.identifierChain +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.AstRewriter +import org.partiql.ast.v1.AstVisitor +import org.partiql.ast.v1.Exclude +import org.partiql.ast.v1.ExcludeStep +import org.partiql.ast.v1.From +import org.partiql.ast.v1.FromExpr +import org.partiql.ast.v1.FromJoin +import org.partiql.ast.v1.FromType +import org.partiql.ast.v1.GroupBy +import org.partiql.ast.v1.GroupByStrategy +import org.partiql.ast.v1.IdentifierChain +import org.partiql.ast.v1.JoinType +import org.partiql.ast.v1.Nulls +import org.partiql.ast.v1.Order +import org.partiql.ast.v1.OrderBy +import org.partiql.ast.v1.QueryBody +import org.partiql.ast.v1.SelectItem +import org.partiql.ast.v1.SelectList +import org.partiql.ast.v1.SelectPivot +import org.partiql.ast.v1.SelectStar +import org.partiql.ast.v1.SelectValue +import org.partiql.ast.v1.SetOpType +import org.partiql.ast.v1.SetQuantifier +import org.partiql.ast.v1.expr.Expr +import org.partiql.ast.v1.expr.ExprCall +import org.partiql.ast.v1.expr.ExprQuerySet +import org.partiql.ast.v1.expr.Scope +import org.partiql.planner.internal.Env +import org.partiql.planner.internal.helpers.toBinder +import org.partiql.planner.internal.ir.Rel +import org.partiql.planner.internal.ir.Rex +import org.partiql.planner.internal.ir.rel +import org.partiql.planner.internal.ir.relBinding +import org.partiql.planner.internal.ir.relOpAggregate +import org.partiql.planner.internal.ir.relOpAggregateCallUnresolved +import org.partiql.planner.internal.ir.relOpDistinct +import org.partiql.planner.internal.ir.relOpErr +import org.partiql.planner.internal.ir.relOpExclude +import org.partiql.planner.internal.ir.relOpExcludePath +import org.partiql.planner.internal.ir.relOpExcludeStep +import org.partiql.planner.internal.ir.relOpExcludeTypeCollIndex +import org.partiql.planner.internal.ir.relOpExcludeTypeCollWildcard +import org.partiql.planner.internal.ir.relOpExcludeTypeStructKey +import org.partiql.planner.internal.ir.relOpExcludeTypeStructSymbol +import org.partiql.planner.internal.ir.relOpExcludeTypeStructWildcard +import org.partiql.planner.internal.ir.relOpFilter +import org.partiql.planner.internal.ir.relOpJoin +import org.partiql.planner.internal.ir.relOpLimit +import org.partiql.planner.internal.ir.relOpOffset +import org.partiql.planner.internal.ir.relOpProject +import org.partiql.planner.internal.ir.relOpScan +import org.partiql.planner.internal.ir.relOpScanIndexed +import org.partiql.planner.internal.ir.relOpSort +import org.partiql.planner.internal.ir.relOpSortSpec +import org.partiql.planner.internal.ir.relOpUnpivot +import org.partiql.planner.internal.ir.relType +import org.partiql.planner.internal.ir.rex +import org.partiql.planner.internal.ir.rexOpLit +import org.partiql.planner.internal.ir.rexOpPivot +import org.partiql.planner.internal.ir.rexOpSelect +import org.partiql.planner.internal.ir.rexOpStruct +import org.partiql.planner.internal.ir.rexOpStructField +import org.partiql.planner.internal.ir.rexOpVarLocal +import org.partiql.planner.internal.typer.CompilerType +import org.partiql.types.PType +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.boolValue +import org.partiql.value.int32Value +import org.partiql.value.stringValue + +/** + * Lexically scoped state for use in translating an individual SELECT 
statement. + */ +internal object V1RelConverter { + + // IGNORE — so we don't have to non-null assert on operator inputs + internal val nil = rel(relType(emptyList(), emptySet()), relOpErr("nil")) + + /** + * Here we convert an SFW to composed [Rel]s, then apply the appropriate relation-value projection to get a [Rex]. + */ + internal fun apply(qSet: ExprQuerySet, env: Env): Rex { + val newQSet = V1NormalizeSelect.normalize(qSet) + val rex = when (val body = newQSet.body) { + is QueryBody.SFW -> { + val rel = newQSet.accept(ToRel(env), nil) + when (val projection = body.select) { + // PIVOT ... FROM + is SelectPivot -> { + val key = projection.key.toRex(env) + val value = projection.value.toRex(env) + val type = (STRUCT) + val op = rexOpPivot(key, value, rel) + rex(type, op) + } + // SELECT VALUE ... FROM + is SelectValue -> { + assert(rel.type.schema.size == 1) { + "Expected SELECT VALUE's input to have a single binding. " + + "However, it contained: ${rel.type.schema.map { it.name }}." + } + val constructor = rex(ANY, rexOpVarLocal(0, 0)) + val op = rexOpSelect(constructor, rel) + val type = when (rel.type.props.contains(Rel.Prop.ORDERED)) { + true -> (LIST) + else -> (BAG) + } + rex(type, op) + } + // SELECT * FROM + is SelectStar -> { + throw IllegalArgumentException("AST not normalized") + } + // SELECT ... FROM + is SelectList -> { + throw IllegalArgumentException("AST not normalized") + } + else -> error("Unexpected Select type: $projection") + } + } + is QueryBody.SetOp -> { + val rel = newQSet.accept(ToRel(env), nil) + val constructor = rex(ANY, rexOpVarLocal(0, 0)) + val op = rexOpSelect(constructor, rel) + val type = when (rel.type.props.contains(Rel.Prop.ORDERED)) { + true -> (LIST) + else -> (BAG) + } + rex(type, op) + } + else -> error("Unexpected QueryBody type: ${newQSet.body}") + } + return rex + } + + /** + * Syntax sugar for converting an [Expr] tree to a [Rex] tree. + */ + private fun Expr.toRex(env: Env): Rex = V1RexConverter.apply(this, env) + + @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE", "LocalVariableName") + internal class ToRel(private val env: Env) : AstVisitor() { + + override fun defaultReturn(node: AstNode, input: Rel): Rel = + throw IllegalArgumentException("unsupported rel $node") + + /** + * Translate SFW AST node to a pipeline of [Rel] operators; skip any SELECT VALUE or PIVOT projection. 
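For orientation, the clause-to-operator ordering that ToRel.visitExprQuerySet applies below (FROM, then WHERE, aggregation, HAVING, ORDER BY, OFFSET, LIMIT, EXCLUDE, and finally the projection) can be sketched over plain Kotlin collections. This is only an illustrative, self-contained sketch; Row and evaluateSfw are hypothetical names and are not part of this patch or the planner IR.

// Illustrative sketch only: mirrors the clause ordering used by ToRel.visitExprQuerySet,
// using plain lists instead of Rel operators. GROUP BY, HAVING and EXCLUDE are omitted.
typealias Row = Map<String, Any?>

fun evaluateSfw(
    from: List<Row>,                          // FROM     -> Rel.Op.Scan / joins, already materialized here
    where: (Row) -> Boolean = { true },
    orderBy: Comparator<Row>? = null,
    offset: Int = 0,
    limit: Int = Int.MAX_VALUE,
    project: (Row) -> Row = { it },
): List<Row> {
    var rows = from.filter(where)                         // WHERE    -> Rel.Op.Filter
    if (orderBy != null) rows = rows.sortedWith(orderBy)  // ORDER BY -> Rel.Op.Sort
    rows = rows.drop(offset)                              // OFFSET   -> Rel.Op.Offset (applied before LIMIT)
    rows = rows.take(limit)                               // LIMIT    -> Rel.Op.Limit
    return rows.map(project)                              // SELECT   -> Rel.Op.Project / Rex.Op.Select
}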
+ */ + + override fun visitExprQuerySet(node: ExprQuerySet, ctx: Rel): Rel { + val body = node.body + val orderBy = node.orderBy + val limit = node.limit + val offset = node.offset + when (body) { + is QueryBody.SFW -> { + var sel = body + var rel = visitFrom(sel.from, nil) + rel = convertWhere(rel, sel.where) + // kotlin does not have destructuring reassignment + val (_sel, _rel) = convertAgg(rel, sel, sel.groupBy) + sel = _sel + rel = _rel + // Plan.create (possibly rewritten) sel node + rel = convertHaving(rel, sel.having) + rel = convertOrderBy(rel, orderBy) + // offset should precede limit + rel = convertOffset(rel, offset) + rel = convertLimit(rel, limit) + rel = convertExclude(rel, sel.exclude) + // append SQL projection if present + rel = when (val projection = sel.select) { + is SelectValue -> { + val project = visitSelectValue(projection, rel) + visitSetQuantifier(projection.setq, project) + } + is SelectStar, is SelectList -> { + error("AST not normalized, found ${projection.javaClass.simpleName}") + } + is SelectPivot -> rel // Skip PIVOT + else -> error("Unexpected Select type: $projection") + } + return rel + } + is QueryBody.SetOp -> { + var rel = convertSetOp(body) + rel = convertOrderBy(rel, orderBy) + // offset should precede limit + rel = convertOffset(rel, offset) + rel = convertLimit(rel, limit) + return rel + } + else -> error("Unexpected QueryBody type: $body") + } + } + + /** + * Given a [setQuantifier], this will return a [Rel] of [Rel.Op.Distinct] wrapping the [input]. + * If [setQuantifier] is null or ALL, this will return the [input]. + */ + private fun visitSetQuantifier(setQuantifier: SetQuantifier?, input: Rel): Rel { + return when (setQuantifier?.code()) { + SetQuantifier.DISTINCT -> rel(input.type, relOpDistinct(input)) + SetQuantifier.ALL, null -> input + else -> error("Unexpected SetQuantifier type: $setQuantifier") + } + } + + override fun visitSelectList(node: SelectList, input: Rel): Rel { + // this ignores aggregations + val schema = mutableListOf() + val props = input.type.props + val projections = mutableListOf() + node.items.forEach { + val (binding, projection) = convertSelectItem(it) + schema.add(binding) + projections.add(projection) + } + val type = relType(schema, props) + val op = relOpProject(input, projections) + return rel(type, op) + } + + override fun visitSelectValue(node: SelectValue, input: Rel): Rel { + val name = node.constructor.toBinder(1).symbol + val rex = V1RexConverter.apply(node.constructor, env) + val schema = listOf(relBinding(name, rex.type)) + val props = input.type.props + val type = relType(schema, props) + val op = relOpProject(input, projections = listOf(rex)) + return rel(type, op) + } + + @OptIn(PartiQLValueExperimental::class) + override fun visitFrom(node: From, ctx: Rel): Rel { + val tableRefs = node.tableRefs.map { visitFromTableRef(it, ctx) } + return tableRefs.drop(1).fold(tableRefs.first()) { acc, tRef -> + val joinType = Rel.Op.Join.Type.INNER + val condition = rex(BOOL, rexOpLit(boolValue(true))) + val schema = acc.type.schema + tRef.type.schema + val props = emptySet() + val type = relType(schema, props) + val op = relOpJoin(acc, tRef, condition, joinType) + rel(type, op) + } + } + + override fun visitFromExpr(node: FromExpr, nil: Rel): Rel { + val rex = V1RexConverter.applyRel(node.expr, env) + val binding = when (val a = node.asAlias) { + null -> error("AST not normalized, missing AS alias on $node") + else -> relBinding( + name = a.symbol, + type = rex.type + ) + } + return when (node.fromType.code()) { 
+ FromType.SCAN -> { + when (val i = node.atAlias) { + null -> convertScan(rex, binding) + else -> { + val index = relBinding( + name = i.symbol, + type = (INT) + ) + convertScanIndexed(rex, binding, index) + } + } + } + FromType.UNPIVOT -> { + val atAlias = when (val at = node.atAlias) { + null -> error("AST not normalized, missing AT alias on UNPIVOT $node") + else -> relBinding( + name = at.symbol, + type = (STRING) + ) + } + convertUnpivot(rex, k = atAlias, v = binding) + } + else -> error("Unexpected FromType type: ${node.fromType}") + } + } + + /** + * Appends [Rel.Op.Join] where the left and right sides are converted FROM sources + * + * TODO compute basic schema + */ + @OptIn(PartiQLValueExperimental::class) + override fun visitFromJoin(node: FromJoin, nil: Rel): Rel { + val lhs = visitFromTableRef(node.lhs, nil) + val rhs = visitFromTableRef(node.rhs, nil) + val schema = lhs.type.schema + rhs.type.schema // Note: This gets more specific in PlanTyper. It is only used to find binding names here. + val props = emptySet() + val condition = node.condition?.let { V1RexConverter.apply(it, env) } ?: rex(BOOL, rexOpLit(boolValue(true))) + val joinType = when (node.joinType?.code()) { + JoinType.LEFT_OUTER, JoinType.LEFT, JoinType.LEFT_CROSS -> Rel.Op.Join.Type.LEFT + JoinType.RIGHT_OUTER, JoinType.RIGHT -> Rel.Op.Join.Type.RIGHT + JoinType.FULL_OUTER, JoinType.FULL -> Rel.Op.Join.Type.FULL + JoinType.INNER, + JoinType.CROSS -> Rel.Op.Join.Type.INNER // Cross Joins are just INNER JOIN ON TRUE + null -> Rel.Op.Join.Type.INNER // a JOIN b ON a.id = b.id <--> a INNER JOIN b ON a.id = b.id + else -> error("Unexpected JoinType type: ${node.joinType}") + } + val type = relType(schema, props) + val op = relOpJoin(lhs, rhs, condition, joinType) + return rel(type, op) + } + + // Helpers + private fun convertScan(rex: Rex, binding: Rel.Binding): Rel { + val schema = listOf(binding) + val props = emptySet() + val type = relType(schema, props) + val op = relOpScan(rex) + return rel(type, op) + } + + private fun convertScanIndexed(rex: Rex, binding: Rel.Binding, index: Rel.Binding): Rel { + val schema = listOf(binding, index) + val props = emptySet() + val type = relType(schema, props) + val op = relOpScanIndexed(rex) + return rel(type, op) + } + + /** + * Output schema of an UNPIVOT is < k, v > + * + * @param rex + * @param k + * @param v + */ + private fun convertUnpivot(rex: Rex, k: Rel.Binding, v: Rel.Binding): Rel { + val schema = listOf(k, v) + val props = emptySet() + val type = relType(schema, props) + val op = relOpUnpivot(rex) + return rel(type, op) + } + + private fun convertSelectItem(item: SelectItem) = when (item) { + is SelectItem.Star -> convertSelectItemStar(item) + is SelectItem.Expr -> convertSelectItemExpr(item) + else -> error("Unexpected SelectItem type: $item") + } + + private fun convertSelectItemStar(item: SelectItem.Star): Pair { + throw IllegalArgumentException("AST not normalized") + } + + private fun convertSelectItemExpr(item: SelectItem.Expr): Pair { + val name = when (val a = item.asAlias) { + null -> error("AST not normalized, missing AS alias on select item $item") + else -> a.symbol + } + val rex = V1RexConverter.apply(item.expr, env) + val binding = relBinding(name, rex.type) + return binding to rex + } + + /** + * Append [Rel.Op.Filter] only if a WHERE condition exists + */ + private fun convertWhere(input: Rel, expr: Expr?): Rel { + if (expr == null) { + return input + } + val type = input.type + val predicate = expr.toRex(env) + val op = relOpFilter(input, 
predicate) + return rel(type, op) + } + + /** + * Append [Rel.Op.Aggregate] only if SELECT contains aggregate expressions. + * + * TODO Set quantifiers + * TODO Group As + * + * @return Pair is returned where + * 1. Ast.Expr.SFW has every Ast.Expr.CallAgg replaced by a synthetic Ast.Expr.Var + * 2. Rel which has the appropriate Rex.Agg calls and groups + */ + @OptIn(PartiQLValueExperimental::class) + private fun convertAgg(input: Rel, select: QueryBody.SFW, groupBy: GroupBy?): Pair { + // Rewrite and extract all aggregations in the SELECT clause + val (sel, aggregations) = AggregationTransform.apply(select) + + // No aggregation planning required for GROUP BY + if (aggregations.isEmpty() && groupBy == null) { + return Pair(select, input) + } + + // Build the schema -> (calls... groups...) + val schema = mutableListOf() + val props = emptySet() + + // Build the rel operator + var strategy = Rel.Op.Aggregate.Strategy.FULL + val calls = aggregations.mapIndexed { i, expr -> + val binding = relBinding( + name = syntheticAgg(i), + type = (ANY), + ) + schema.add(binding) + val args = expr.args.map { arg -> arg.toRex(env) } + val id = V1AstToPlan.convert(expr.function) + if (id.hasQualifier()) { + error("Qualified aggregation calls are not supported.") + } + // lowercase normalize all calls + val name = id.getIdentifier().getText().lowercase() + if (name == "count" && expr.args.isEmpty()) { + relOpAggregateCallUnresolved( + name, + org.partiql.planner.internal.ir.SetQuantifier.ALL, + args = listOf(exprLit(int32Value(1)).toRex(env)) + ) + } else { + val setq = when (expr.setq?.code()) { + null -> org.partiql.planner.internal.ir.SetQuantifier.ALL + SetQuantifier.ALL -> org.partiql.planner.internal.ir.SetQuantifier.ALL + SetQuantifier.DISTINCT -> org.partiql.planner.internal.ir.SetQuantifier.DISTINCT + else -> error("Unexpected SetQuantifier type: ${expr.setq}") + } + relOpAggregateCallUnresolved(name, setq, args) + } + }.toMutableList() + + // Add GROUP_AS aggregation + groupBy?.let { gb -> + gb.asAlias?.let { groupAs -> + val binding = relBinding(groupAs.symbol, ANY) + schema.add(binding) + val fields = input.type.schema.mapIndexed { bindingIndex, currBinding -> + rexOpStructField( + k = rex(STRING, rexOpLit(stringValue(currBinding.name))), + v = rex(ANY, rexOpVarLocal(0, bindingIndex)) + ) + } + val arg = listOf(rex(ANY, rexOpStruct(fields))) + calls.add(relOpAggregateCallUnresolved("group_as", org.partiql.planner.internal.ir.SetQuantifier.ALL, arg)) + } + } + var groups = emptyList() + if (groupBy != null) { + groups = groupBy.keys.map { + if (it.asAlias == null) { + error("not normalized, group key $it missing unique name") + } + val binding = relBinding( + name = it.asAlias!!.symbol, + type = (ANY) + ) + schema.add(binding) + it.expr.toRex(env) + } + strategy = when (groupBy.strategy.code()) { + GroupByStrategy.FULL -> Rel.Op.Aggregate.Strategy.FULL + GroupByStrategy.PARTIAL -> Rel.Op.Aggregate.Strategy.PARTIAL + else -> error("Unexpected GroupByStrategy type: ${groupBy.strategy}") + } + } + val type = relType(schema, props) + val op = relOpAggregate(input, strategy, calls, groups) + val rel = rel(type, op) + return Pair(sel, rel) + } + + /** + * Append [Rel.Op.Filter] only if a HAVING condition exists + * + * Notes: + * - This currently does not support aggregation expressions in the WHERE condition + */ + private fun convertHaving(input: Rel, expr: Expr?): Rel { + if (expr == null) { + return input + } + val type = input.type + val predicate = expr.toRex(env) + val op = relOpFilter(input, 
predicate) + return rel(type, op) + } + + private fun visitIfQuerySet(expr: Expr): Rel { + return when (expr) { + is ExprQuerySet -> visit(expr, nil) + else -> { + val rex = V1RexConverter.applyRel(expr, env) + val op = relOpScan(rex) + val type = Rel.Type(listOf(Rel.Binding("_1", ANY)), props = emptySet()) + return rel(type, op) + } + } + } + + /** + * Append SQL set operator if present + */ + private fun convertSetOp(setExpr: QueryBody.SetOp): Rel { + val lhs = visitIfQuerySet(setExpr.lhs) + val rhs = visitIfQuerySet(setExpr.rhs) + val type = Rel.Type(listOf(Rel.Binding("_0", ANY)), props = emptySet()) + val quantifier = when (setExpr.type.setq?.code()) { + SetQuantifier.ALL -> org.partiql.planner.internal.ir.SetQuantifier.ALL + null, SetQuantifier.DISTINCT -> org.partiql.planner.internal.ir.SetQuantifier.DISTINCT + else -> error("Unexpected SetQuantifier type: ${setExpr.type.setq}") + } + val outer = setExpr.isOuter + val op = when (setExpr.type.setOpType.code()) { + SetOpType.UNION -> Rel.Op.Union(quantifier, outer, lhs, rhs) + SetOpType.EXCEPT -> Rel.Op.Except(quantifier, outer, lhs, rhs) + SetOpType.INTERSECT -> Rel.Op.Intersect(quantifier, outer, lhs, rhs) + else -> error("Unexpected SetOpType type: ${setExpr.type.setOpType}") + } + return rel(type, op) + } + + /** + * Append [Rel.Op.Sort] only if an ORDER BY clause is present + */ + private fun convertOrderBy(input: Rel, orderBy: OrderBy?): Rel { + if (orderBy == null) { + return input + } + val type = input.type.copy(props = setOf(Rel.Prop.ORDERED)) + val specs = orderBy.sorts.map { + val rex = it.expr.toRex(env) + val order = when (it.order?.code()) { + Order.DESC -> when (it.nulls?.code()) { + Nulls.LAST -> Rel.Op.Sort.Order.DESC_NULLS_LAST + Nulls.FIRST, null -> Rel.Op.Sort.Order.DESC_NULLS_FIRST + else -> error("Unexpected Nulls type: ${it.nulls}") + } + else -> when (it.nulls?.code()) { + Nulls.FIRST -> Rel.Op.Sort.Order.ASC_NULLS_FIRST + Nulls.LAST, null -> Rel.Op.Sort.Order.ASC_NULLS_LAST + else -> error("Unexpected Nulls type: ${it.nulls}") + } + } + relOpSortSpec(rex, order) + } + val op = relOpSort(input, specs) + return rel(type, op) + } + + /** + * Append [Rel.Op.Limit] if there is a LIMIT + */ + private fun convertLimit(input: Rel, limit: Expr?): Rel { + if (limit == null) { + return input + } + val type = input.type + val rex = V1RexConverter.apply(limit, env) + val op = relOpLimit(input, rex) + return rel(type, op) + } + + /** + * Append [Rel.Op.Offset] if there is an OFFSET + */ + private fun convertOffset(input: Rel, offset: Expr?): Rel { + if (offset == null) { + return input + } + val type = input.type + val rex = V1RexConverter.apply(offset, env) + val op = relOpOffset(input, rex) + return rel(type, op) + } + + private fun convertExclude(input: Rel, exclude: Exclude?): Rel { + if (exclude == null) { + return input + } + val type = input.type // PlanTyper handles typing the exclusion and removing redundant exclude paths + val paths = exclude.excludePaths + .groupBy(keySelector = { it.root }, valueTransform = { it.excludeSteps }) + .map { (root, exclusions) -> + val rootVar = (root.toRex(env)).op as Rex.Op.Var + val steps = exclusionsToSteps(exclusions) + relOpExcludePath(rootVar, steps) + } + val op = relOpExclude(input, paths) + return rel(type, op) + } + + private fun exclusionsToSteps(exclusions: List>): List { + if (exclusions.any { it.isEmpty() }) { + // if there exists a path with no further steps, can remove the longer paths + // e.g. 
t.a.b, t.a.b.c, t.a.b.d[*].*.e -> can keep just t.a.b + return emptyList() + } + return exclusions + .groupBy(keySelector = { it.first() }, valueTransform = { it.drop(1) }) + .map { (head, steps) -> + val type = stepToExcludeType(head) + val substeps = exclusionsToSteps(steps) + relOpExcludeStep(type, substeps) + } + } + + private fun stepToExcludeType(step: ExcludeStep): Rel.Op.Exclude.Type { + return when (step) { + is ExcludeStep.StructField -> { + when (step.symbol.isDelimited) { + false -> relOpExcludeTypeStructSymbol(step.symbol.symbol) + true -> relOpExcludeTypeStructKey(step.symbol.symbol) + } + } + is ExcludeStep.CollIndex -> relOpExcludeTypeCollIndex(step.index) + is ExcludeStep.StructWildcard -> relOpExcludeTypeStructWildcard() + is ExcludeStep.CollWildcard -> relOpExcludeTypeCollWildcard() + else -> error("Unexpected ExcludeStep type: $step") + } + } + + // /** + // * Converts a GROUP AS X clause to a binding of the form: + // * ``` + // * { 'X': group_as({ 'a_0': e_0, ..., 'a_n': e_n }) } + // * ``` + // * + // * Notes: + // * - This was included to be consistent with the existing PartiqlAst and PartiqlLogical representations, + // * but perhaps we don't want to represent GROUP AS with an agg function. + // */ + // private fun convertGroupAs(name: String, from: From): Binding { + // val fields = from.bindings().map { n -> + // Plan.field( + // name = Plan.rexLit(ionString(n), STRING), + // value = Plan.rexId(n, Case.SENSITIVE, Rex.Id.Qualifier.UNQUALIFIED, type = STRUCT) + // ) + // } + // return Plan.binding( + // name = name, + // value = Plan.rexAgg( + // id = "group_as", + // args = listOf(Plan.rexTuple(fields, STRUCT)), + // modifier = Rex.Agg.Modifier.ALL, + // type = STRUCT + // ) + // ) + // } + } + + /** + * Rewrites a SELECT node replacing (and extracting) each aggregation `i` with a synthetic field name `$agg_i`. + */ + private object AggregationTransform : AstRewriter() { + // currently hard-coded + @JvmStatic + private val aggregates = setOf("count", "avg", "sum", "min", "max", "any", "some", "every") + + private data class Context( + val aggregations: MutableList, + val keys: List + ) + + fun apply(node: QueryBody.SFW): Pair> { + val aggs = mutableListOf() + val keys = node.groupBy?.keys ?: emptyList() + val context = Context(aggs, keys) + val select = super.visitQueryBodySFW(node, context) as QueryBody.SFW + return Pair(select, aggs) + } + + override fun visitSelectValue(node: SelectValue, ctx: Context): AstNode { + val visited = super.visitSelectValue(node, ctx) + val substitutions = ctx.keys.associate { + it.expr to exprVarRef(identifierChain(identifier(it.asAlias!!.symbol, isDelimited = true), next = null), Scope.DEFAULT()) + } + return V1SubstitutionVisitor.visit(visited, substitutions) + } + + // only rewrite top-level SFW + override fun visitQueryBodySFW(node: QueryBody.SFW, ctx: Context): AstNode = node + + override fun visitExprCall(node: ExprCall, ctx: Context) = + // TODO replace w/ proper function resolution to determine whether a function call is a scalar or aggregate. 
+ // may require further modification of SPI interfaces to support + when (node.function.isAggregateCall()) { + true -> { + val id = identifierChain( + identifier( + symbol = syntheticAgg(ctx.aggregations.size), + isDelimited = false + ), + next = null + ) + ctx.aggregations += node + exprVarRef(id, Scope.DEFAULT()) + } + else -> node + } + + private fun String.isAggregateCall(): Boolean { + return aggregates.contains(this) + } + + private fun IdentifierChain.isAggregateCall(): Boolean { + return when (next) { + null -> root.symbol.lowercase().isAggregateCall() + else -> { + var curId = next + var last = curId + while (curId != null) { + last = curId + curId = curId.next + } + last!!.root.symbol.lowercase().isAggregateCall() + } + } + } + + override fun defaultReturn(node: AstNode, ctx: Context) = node + } + + private fun syntheticAgg(i: Int) = "\$agg_$i" + + private val ANY: CompilerType = CompilerType(PType.dynamic()) + private val BOOL: CompilerType = CompilerType(PType.bool()) + private val STRING: CompilerType = CompilerType(PType.string()) + private val STRUCT: CompilerType = CompilerType(PType.struct()) + private val BAG: CompilerType = CompilerType(PType.bag()) + private val LIST: CompilerType = CompilerType(PType.array()) + private val INT: CompilerType = CompilerType(PType.numeric()) +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RexConverter.kt new file mode 100644 index 000000000..cd203329a --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1RexConverter.kt @@ -0,0 +1,1077 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ * + */ + +package org.partiql.planner.internal.transforms + +import com.amazon.ionelement.api.loadSingleElement +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.AstVisitor +import org.partiql.ast.v1.DataType +import org.partiql.ast.v1.QueryBody +import org.partiql.ast.v1.SelectList +import org.partiql.ast.v1.SelectStar +import org.partiql.ast.v1.expr.Expr +import org.partiql.ast.v1.expr.ExprAnd +import org.partiql.ast.v1.expr.ExprArray +import org.partiql.ast.v1.expr.ExprBag +import org.partiql.ast.v1.expr.ExprBetween +import org.partiql.ast.v1.expr.ExprCall +import org.partiql.ast.v1.expr.ExprCase +import org.partiql.ast.v1.expr.ExprCast +import org.partiql.ast.v1.expr.ExprCoalesce +import org.partiql.ast.v1.expr.ExprExtract +import org.partiql.ast.v1.expr.ExprInCollection +import org.partiql.ast.v1.expr.ExprIsType +import org.partiql.ast.v1.expr.ExprLike +import org.partiql.ast.v1.expr.ExprLit +import org.partiql.ast.v1.expr.ExprNot +import org.partiql.ast.v1.expr.ExprNullIf +import org.partiql.ast.v1.expr.ExprOperator +import org.partiql.ast.v1.expr.ExprOr +import org.partiql.ast.v1.expr.ExprOverlay +import org.partiql.ast.v1.expr.ExprPath +import org.partiql.ast.v1.expr.ExprPosition +import org.partiql.ast.v1.expr.ExprQuerySet +import org.partiql.ast.v1.expr.ExprSessionAttribute +import org.partiql.ast.v1.expr.ExprStruct +import org.partiql.ast.v1.expr.ExprSubstring +import org.partiql.ast.v1.expr.ExprTrim +import org.partiql.ast.v1.expr.ExprVarRef +import org.partiql.ast.v1.expr.ExprVariant +import org.partiql.ast.v1.expr.PathStep +import org.partiql.ast.v1.expr.Scope +import org.partiql.ast.v1.expr.TrimSpec +import org.partiql.errors.TypeCheckException +import org.partiql.planner.internal.Env +import org.partiql.planner.internal.ir.Rel +import org.partiql.planner.internal.ir.Rex +import org.partiql.planner.internal.ir.builder.plan +import org.partiql.planner.internal.ir.rel +import org.partiql.planner.internal.ir.relBinding +import org.partiql.planner.internal.ir.relOpJoin +import org.partiql.planner.internal.ir.relOpScan +import org.partiql.planner.internal.ir.relOpUnpivot +import org.partiql.planner.internal.ir.relType +import org.partiql.planner.internal.ir.rex +import org.partiql.planner.internal.ir.rexOpCallUnresolved +import org.partiql.planner.internal.ir.rexOpCastUnresolved +import org.partiql.planner.internal.ir.rexOpCoalesce +import org.partiql.planner.internal.ir.rexOpCollection +import org.partiql.planner.internal.ir.rexOpLit +import org.partiql.planner.internal.ir.rexOpNullif +import org.partiql.planner.internal.ir.rexOpPathIndex +import org.partiql.planner.internal.ir.rexOpPathKey +import org.partiql.planner.internal.ir.rexOpPathSymbol +import org.partiql.planner.internal.ir.rexOpSelect +import org.partiql.planner.internal.ir.rexOpStruct +import org.partiql.planner.internal.ir.rexOpStructField +import org.partiql.planner.internal.ir.rexOpSubquery +import org.partiql.planner.internal.ir.rexOpTupleUnion +import org.partiql.planner.internal.ir.rexOpVarLocal +import org.partiql.planner.internal.ir.rexOpVarUnresolved +import org.partiql.planner.internal.typer.CompilerType +import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType +import org.partiql.spi.catalog.Identifier +import org.partiql.types.PType +import org.partiql.value.MissingValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.StringValue +import org.partiql.value.boolValue +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import 
org.partiql.value.io.PartiQLValueIonReaderBuilder +import org.partiql.value.nullValue +import org.partiql.value.stringValue +import org.partiql.ast.v1.SetQuantifier as AstSetQuantifier + +/** + * Converts an AST expression node to a Plan Rex node; ignoring any typing. + */ +internal object V1RexConverter { + + internal fun apply(expr: Expr, context: Env): Rex = ToRex.visitExprCoerce(expr, context) + + internal fun applyRel(expr: Expr, context: Env): Rex = expr.accept(ToRex, context) + + @OptIn(PartiQLValueExperimental::class) + @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") + private object ToRex : AstVisitor<Rex, Env>() { + + private val COLL_AGG_NAMES = setOf( + "coll_any", + "coll_avg", + "coll_count", + "coll_every", + "coll_max", + "coll_min", + "coll_some", + "coll_sum", + ) + + override fun defaultReturn(node: AstNode, context: Env): Rex = + throw IllegalArgumentException("unsupported rex $node") + + override fun visitExprLit(node: ExprLit, context: Env): Rex { + val type = CompilerType( + _delegate = node.value.type.toPType(), + isNullValue = node.value.isNull, + isMissingValue = node.value is MissingValue + ) + val op = rexOpLit(node.value) + return rex(type, op) + } + + /** + * TODO PartiQLValue will be replaced by Datum (i.e. IonDatum) in a subsequent PR. + */ + override fun visitExprVariant(node: ExprVariant, ctx: Env): Rex { + if (node.encoding != "ion") { + throw IllegalArgumentException("unsupported encoding ${node.encoding}") + } + val ion = loadSingleElement(node.value) + val value = PartiQLValueIonReaderBuilder.standard().build(ion).read() + val type = CompilerType(value.type.toPType()) + return rex(type, rexOpLit(value)) + } + + /** + * !! IMPORTANT !! + * + * This is the top-level visit for handling subquery coercion. The default behavior is to coerce to a scalar. + * In some situations, i.e. comparison to complex types, we may make assertions on the desired type. + * + * It is recommended that every method (except for the exceptional cases) recurse the tree from visitExprCoerce. + * + * - RHS of comparison when LHS is an array or collection expression; and vice versa + * - It is the collection expression of a FROM clause or JOIN + * - It is the RHS of an IN predicate + * - It is an argument of an OUTER set operator. 
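The coercion rule that visitExprCoerce and resolveBinaryOp implement below can be summarized in a small standalone sketch: a SELECT subquery is coerced to a scalar by default, and to a row only when it is compared against a literal array (e.g. [1, 2] < (SELECT a, b FROM t)). SubqueryCoercion and coercionFor are hypothetical names used only for illustration, not part of this patch.

// Hypothetical sketch of the subquery coercion decision.
enum class SubqueryCoercion { SCALAR, ROW }

fun coercionFor(isSubquery: Boolean, comparedToLiteralArray: Boolean): SubqueryCoercion? =
    when {
        !isSubquery -> null                         // not a subquery: no coercion applies
        comparedToLiteralArray -> SubqueryCoercion.ROW
        else -> SubqueryCoercion.SCALAR             // default per SQL scalar subquery rules
    }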
+ * + * @param node + * @param ctx + * @return + */ + internal fun visitExprCoerce(node: Expr, ctx: Env, coercion: Rex.Op.Subquery.Coercion = Rex.Op.Subquery.Coercion.SCALAR): Rex { + val rex = node.accept(this, ctx) + return when (isSqlSelect(node)) { + true -> { + val select = rex.op as Rex.Op.Select + rex( + CompilerType(PType.dynamic()), + rexOpSubquery( + constructor = select.constructor, + rel = select.rel, + coercion = coercion + ) + ) + } + false -> rex + } + } + + override fun visitExprVarRef(node: ExprVarRef, context: Env): Rex { + val type = (ANY) + val identifier = V1AstToPlan.convert(node.identifierChain) + val scope = when (node.scope.code()) { + Scope.DEFAULT -> Rex.Op.Var.Scope.DEFAULT + Scope.LOCAL -> Rex.Op.Var.Scope.LOCAL + else -> error("Unexpected Scope type: ${node.scope}") + } + val op = rexOpVarUnresolved(identifier, scope) + return rex(type, op) + } + + private fun resolveUnaryOp(symbol: String, rhs: Expr, context: Env): Rex { + val type = (ANY) + // Args + val arg = visitExprCoerce(rhs, context) + val args = listOf(arg) + // Fn + val name = when (symbol) { + // TODO move hard-coded operator resolution into SPI + "+" -> "pos" + "-" -> "neg" + else -> error("unsupported unary op $symbol") + } + val id = Identifier.delimited(name) + val op = rexOpCallUnresolved(id, args) + return rex(type, op) + } + + private fun resolveBinaryOp(lhs: Expr, symbol: String, rhs: Expr, context: Env): Rex { + val type = (ANY) + val args = when (symbol) { + "<", ">", + "<=", ">=", + "=", "<>", "!=" -> { + when { + // Example: [1, 2] < (SELECT a, b FROM t) + isLiteralArray(lhs) && isSqlSelect(rhs) -> { + val l = visitExprCoerce(lhs, context) + val r = visitExprCoerce(rhs, context, Rex.Op.Subquery.Coercion.ROW) + listOf(l, r) + } + // Example: (SELECT a, b FROM t) < [1, 2] + isSqlSelect(lhs) && isLiteralArray(rhs) -> { + val l = visitExprCoerce(lhs, context, Rex.Op.Subquery.Coercion.ROW) + val r = visitExprCoerce(rhs, context) + listOf(l, r) + } + // Example: 1 < 2 + else -> { + val l = visitExprCoerce(lhs, context) + val r = visitExprCoerce(rhs, context) + listOf(l, r) + } + } + } + // Example: 1 + 2 + else -> { + val l = visitExprCoerce(lhs, context) + val r = visitExprCoerce(rhs, context) + listOf(l, r) + } + } + // Wrap if a NOT, if necessary + return when (symbol) { + "<>", "!=" -> { + val op = negate(call("eq", *args.toTypedArray())) + rex(type, op) + } + else -> { + val name = when (symbol) { + // TODO eventually move hard-coded operator resolution into SPI + "<" -> "lt" + ">" -> "gt" + "<=" -> "lte" + ">=" -> "gte" + "=" -> "eq" + "||" -> "concat" + "+" -> "plus" + "-" -> "minus" + "*" -> "times" + "/" -> "divide" + "%" -> "modulo" + "&" -> "bitwise_and" + else -> error("unsupported binary op $symbol") + } + val id = Identifier.delimited(name) + val op = rexOpCallUnresolved(id, args) + rex(type, op) + } + } + } + + override fun visitExprOperator(node: ExprOperator, ctx: Env): Rex { + val lhs = node.lhs + return if (lhs != null) { + resolveBinaryOp(lhs, node.symbol, node.rhs, ctx) + } else { + resolveUnaryOp(node.symbol, node.rhs, ctx) + } + } + + override fun visitExprNot(node: ExprNot, ctx: Env): Rex { + val type = (ANY) + // Args + val arg = visitExprCoerce(node.value, ctx) + val args = listOf(arg) + // Fn + val id = Identifier.delimited("not") + val op = rexOpCallUnresolved(id, args) + return rex(type, op) + } + + override fun visitExprAnd(node: ExprAnd, ctx: Env): Rex { + val type = (ANY) + val l = visitExprCoerce(node.lhs, ctx) + val r = visitExprCoerce(node.rhs, ctx) + val 
args = listOf(l, r) + + // Wrap if a NOT, if necessary + val id = Identifier.delimited("and") + val op = rexOpCallUnresolved(id, args) + return rex(type, op) + } + + override fun visitExprOr(node: ExprOr, ctx: Env): Rex { + val type = (ANY) + val l = visitExprCoerce(node.lhs, ctx) + val r = visitExprCoerce(node.rhs, ctx) + val args = listOf(l, r) + + // Wrap if a NOT, if necessary + val id = Identifier.delimited("or") + val op = rexOpCallUnresolved(id, args) + return rex(type, op) + } + + private fun isLiteralArray(node: Expr): Boolean = node is ExprArray + + private fun isSqlSelect(node: Expr): Boolean { + return if (node is ExprQuerySet) { + val body = node.body + body is QueryBody.SFW && (body.select is SelectList || body.select is SelectStar) + } else { + false + } + } + + override fun visitExprPath(node: ExprPath, context: Env): Rex { + // Args + val root = visitExprCoerce(node.root, context) + + // Attempt to create qualified identifier + val (newRoot, nextStep) = when (val op = root.op) { + is Rex.Op.Var.Unresolved -> { + // convert consecutive symbol path steps to the root identifier + var i = 0 + val parts = mutableListOf() + parts.addAll(op.identifier.getParts()) + var curStep = node.next + while (curStep != null) { + if (curStep !is PathStep.Field) { + break + } + parts.add(V1AstToPlan.part(curStep.field)) + i += 1 + curStep = curStep.next + } + val newRoot = rex(ANY, rexOpVarUnresolved(Identifier.of(parts), op.scope)) + val newSteps = curStep + newRoot to newSteps + } + else -> { + root to node.next + } + } + + if (nextStep == null) { + return newRoot + } + + val fromList = mutableListOf() + + var varRefIndex = 0 // tracking var ref index + + var curStep = nextStep + var curPathNavi = newRoot + while (curStep != null) { + val path = when (curStep) { + is PathStep.Element -> { + val key = visitExprCoerce(curStep.element, context) + val op = when (val astKey = curStep.element) { + is ExprLit -> when (astKey.value) { + is StringValue -> rexOpPathKey(curPathNavi, key) + else -> rexOpPathIndex(curPathNavi, key) + } + is ExprCast -> when (astKey.asType.code() == DataType.STRING) { + true -> rexOpPathKey(curPathNavi, key) + false -> rexOpPathIndex(curPathNavi, key) + } + else -> rexOpPathIndex(curPathNavi, key) + } + op + } + + is PathStep.Field -> { + when (curStep.field.isDelimited) { + true -> { + // case-sensitive path step becomes a key lookup + rexOpPathKey(curPathNavi, rexString(curStep.field.symbol)) + } + false -> { + // case-insensitive path step becomes a symbol lookup + rexOpPathSymbol(curPathNavi, curStep.field.symbol) + } + } + } + + // Unpivot and Wildcard steps trigger the rewrite + // According to spec Section 4.3 + // ew1p1...wnpn + // rewrite to: + // SELECT VALUE v_n.p_n + // FROM + // u_1 e as v_1 + // u_2 @v_1.p_1 as v_2 + // ... + // u_n @v_(n-1).p_(n-1) as v_n + // The From clause needs to be rewritten to + // Join <------------------- schema: [(k_1), v_1, (k_2), v_2, ..., (k_(n-1)) v_(n-1)] + // / \ + // ... 
un @v_(n-1).p_(n-1) <-- stack: [global, typeEnv: [outer: [global], schema: [(k_1), v_1, (k_2), v_2, ..., (k_(n-1)) v_(n-1)]]] + // Join <----------------------- schema: [(k_1), v_1, (k_2), v_2, (k_3), v_3] + // / \ + // u_2 @v_1.p_1 as v2 <------- stack: [global, typeEnv: [outer: [global], schema: [(k_1), v_1, (k_2), v_2]]] + // JOIN <---------------------------- schema: [(k_1), v_1, (k_2), v_2] + // / \ + // u_1 e as v_1 < ----\----------------------- stack: [global] + // u_2 @v_1.p_1 as v2 <------ stack: [global, typeEnv: [outer: [global], schema: [(k_1), v_1]]] + // while doing the traversal, instead of passing the stack, + // each join will produce its own schema and pass the schema as a type Env. + // The (k_i) indicate the possible key binding produced by unpivot. + // We calculate the var ref on the fly. + is PathStep.AllFields -> { + // Unpivot produces two binding, in this context we want the value, + // which always going to be the second binding + val op = rexOpVarLocal(1, varRefIndex + 1) + varRefIndex += 2 + val index = fromList.size + fromList.add(relFromUnpivot(curPathNavi, index)) + op + } + is PathStep.AllElements -> { + // Scan produce only one binding + val op = rexOpVarLocal(1, varRefIndex) + varRefIndex += 1 + val index = fromList.size + fromList.add(relFromDefault(curPathNavi, index)) + op + } + else -> error("Unexpected PathStep type: $curStep") + } + curStep = curStep.next + curPathNavi = rex(ANY, path) + } + if (fromList.size == 0) return curPathNavi + val fromNode = fromList.reduce { acc, scan -> + val schema = acc.type.schema + scan.type.schema + val props = emptySet() + val type = relType(schema, props) + rel(type, relOpJoin(acc, scan, rex(BOOL, rexOpLit(boolValue(true))), Rel.Op.Join.Type.INNER)) + } + + // compute the ref used by select construct + // always going to be the last binding + val selectRef = fromNode.type.schema.size - 1 + + val constructor = when (val op = curPathNavi.op) { + is Rex.Op.Path.Index -> rex(curPathNavi.type, rexOpPathIndex(rex(op.root.type, rexOpVarLocal(0, selectRef)), op.key)) + is Rex.Op.Path.Key -> rex(curPathNavi.type, rexOpPathKey(rex(op.root.type, rexOpVarLocal(0, selectRef)), op.key)) + is Rex.Op.Path.Symbol -> rex(curPathNavi.type, rexOpPathSymbol(rex(op.root.type, rexOpVarLocal(0, selectRef)), op.key)) + is Rex.Op.Var.Local -> rex(curPathNavi.type, rexOpVarLocal(0, selectRef)) + else -> throw IllegalStateException() + } + val op = rexOpSelect(constructor, fromNode) + return rex(ANY, op) + } + + /** + * Construct Rel(Scan([path])). + * + * The constructed rel would produce one binding: _v$[index] + */ + private fun relFromDefault(path: Rex, index: Int): Rel { + val schema = listOf( + relBinding( + name = "_v$index", // fresh variable + type = path.type + ) + ) + val props = emptySet() + val relType = relType(schema, props) + return rel(relType, relOpScan(path)) + } + + /** + * Construct Rel(Unpivot([path])). 
+ * + * The constructed rel would produce two bindings: _k$[index] and _v$[index] + */ + private fun relFromUnpivot(path: Rex, index: Int): Rel { + val schema = listOf( + relBinding( + name = "_k$index", // fresh variable + type = STRING + ), + relBinding( + name = "_v$index", // fresh variable + type = path.type + ) + ) + val props = emptySet() + val relType = relType(schema, props) + return rel(relType, relOpUnpivot(path)) + } + + private fun rexString(str: String) = rex(STRING, rexOpLit(stringValue(str))) + + override fun visitExprCall(node: ExprCall, context: Env): Rex { + val type = (ANY) + // Fn + val id = V1AstToPlan.convert(node.function) + if (id.hasQualifier()) { + error("Qualified function calls are not currently supported.") + } + if (id.matches("TUPLEUNION")) { + return visitExprCallTupleUnion(node, context) + } + if (id.matches("EXISTS", ignoreCase = true)) { + return visitExprCallExists(node, context) + } + // Args + val args = node.args.map { visitExprCoerce(it, context) } + + // Check if function is actually coll_ + if (isCollAgg(node)) { + return callToCollAgg(id, node.setq, args) + } + + if (node.setq != null) { + error("Currently, only COLL_ may use set quantifiers.") + } + val op = rexOpCallUnresolved(id, args) + return rex(type, op) + } + + /** + * @return whether call is `COLL_`. + */ + private fun isCollAgg(node: ExprCall): Boolean { + val fn = node.function + val id = if (fn.next == null) { + // is not a qualified identifier chain + node.function.root + } else { + return false + } + return COLL_AGG_NAMES.contains(id.symbol.lowercase()) + } + + /** + * Converts COLL_ to the relevant function calls. For example: + * - `COLL_SUM(x)` becomes `coll_sum_all(x)` + * - `COLL_SUM(ALL x)` becomes `coll_sum_all(x)` + * - `COLL_SUM(DISTINCT x)` becomes `coll_sum_distinct(x)` + * + * It is assumed that the [id] has already been vetted by [isCollAgg]. + */ + private fun callToCollAgg(id: Identifier, setQuantifier: AstSetQuantifier?, args: List): Rex { + if (id.hasQualifier()) { + error("Qualified function calls are not currently supported.") + } + if (args.size != 1) { + error("Aggregate calls currently only support single arguments. Received ${args.size} arguments.") + } + val postfix = when (setQuantifier?.code()) { + AstSetQuantifier.DISTINCT -> "_distinct" + AstSetQuantifier.ALL -> "_all" + null -> "_all" + else -> error("Unexpected SetQuantifier type: $setQuantifier") + } + val newId = Identifier.regular(id.getIdentifier().getText() + postfix) + val op = Rex.Op.Call.Unresolved(newId, listOf(args[0])) + return Rex(ANY, op) + } + + private fun visitExprCallTupleUnion(node: ExprCall, context: Env): Rex { + val type = (STRUCT) + val args = node.args.map { visitExprCoerce(it, context) }.toMutableList() + val op = rexOpTupleUnion(args) + return rex(type, op) + } + + /** + * Assume that the node's identifier refers to EXISTS. + * TODO: This could be better suited as a dedicated node in the future. 
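The fresh bindings introduced by relFromDefault and relFromUnpivot above follow a simple naming scheme keyed on the wildcard step's position: a `.*` (AllFields) step contributes a key/value pair, a `[*]` (AllElements) step a single value. A hypothetical, self-contained sketch of that scheme (WildcardStep and wildcardBindings are illustrative names, not part of this patch):

// Illustrative only: which binding names each wildcard path step introduces.
sealed interface WildcardStep
object AllFields : WildcardStep    // `.*`  -> UNPIVOT source, binds a key and a value
object AllElements : WildcardStep  // `[*]` -> SCAN source, binds a value only

fun wildcardBindings(steps: List<WildcardStep>): List<String> =
    steps.mapIndexed { i, step ->
        when (step) {
            AllFields -> listOf("_k$i", "_v$i")
            AllElements -> listOf("_v$i")
        }
    }.flatten()

// wildcardBindings(listOf(AllFields, AllElements)) == listOf("_k0", "_v0", "_v1")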
+ */ + private fun visitExprCallExists(node: ExprCall, context: Env): Rex { + val type = (BOOL) + if (node.args.size != 1) { + error("EXISTS requires a single argument.") + } + val arg = visitExpr(node.args[0], context) + val op = rexOpCallUnresolved(V1AstToPlan.convert(node.function), listOf(arg)) + return rex(type, op) + } + + override fun visitExprCase(node: ExprCase, context: Env) = plan { + val type = (ANY) + val rex = when (node.expr) { + null -> null + else -> visitExprCoerce(node.expr!!, context) // match `rex + } + + // Converts AST CASE (x) WHEN y THEN z --> Plan CASE WHEN x = y THEN z + val id = Identifier.delimited("eq") + val createBranch: (Rex, Rex) -> Rex.Op.Case.Branch = { condition: Rex, result: Rex -> + val updatedCondition = when (rex) { + null -> condition + else -> rex(type, rexOpCallUnresolved(id, listOf(rex, condition))) + } + rexOpCaseBranch(updatedCondition, result) + } + + val branches = node.branches.map { + val branchCondition = visitExprCoerce(it.condition, context) + val branchRex = visitExprCoerce(it.expr, context) + createBranch(branchCondition, branchRex) + }.toMutableList() + + val defaultRex = when (val default = node.defaultExpr) { + null -> rex(type = ANY, op = rexOpLit(value = nullValue())) + else -> visitExprCoerce(default, context) + } + val op = rexOpCase(branches = branches, default = defaultRex) + rex(type, op) + } + + override fun visitExprArray(node: ExprArray, ctx: Env): Rex { + val values = node.values.map { visitExprCoerce(it, ctx) } + val op = rexOpCollection(values) + return rex(LIST, op) + } + + override fun visitExprBag(node: ExprBag, ctx: Env): Rex { + val values = node.values.map { visitExprCoerce(it, ctx) } + val op = rexOpCollection(values) + return rex(BAG, op) + } + + override fun visitExprStruct(node: ExprStruct, context: Env): Rex { + val type = (STRUCT) + val fields = node.fields.map { + val k = visitExprCoerce(it.name, context) + val v = visitExprCoerce(it.value, context) + rexOpStructField(k, v) + } + val op = rexOpStruct(fields) + return rex(type, op) + } + + // SPECIAL FORMS + + /** + * NOT? LIKE ( ESCAPE )? + */ + override fun visitExprLike(node: ExprLike, ctx: Env): Rex { + val type = BOOL + // Args + val arg0 = visitExprCoerce(node.value, ctx) + val arg1 = visitExprCoerce(node.pattern, ctx) + val arg2 = node.escape?.let { visitExprCoerce(it, ctx) } + // Call Variants + var call = when (arg2) { + null -> call("like", arg0, arg1) + else -> call("like_escape", arg0, arg1, arg2) + } + // NOT? + if (node.not == true) { + call = negate(call) + } + return rex(type, call) + } + + /** + * NOT? BETWEEN AND + */ + override fun visitExprBetween(node: ExprBetween, ctx: Env): Rex = plan { + val type = BOOL + // Args + val arg0 = visitExprCoerce(node.value, ctx) + val arg1 = visitExprCoerce(node.from, ctx) + val arg2 = visitExprCoerce(node.to, ctx) + // Call + var call = call("between", arg0, arg1, arg2) + // NOT? + if (node.not == true) { + call = negate(call) + } + rex(type, call) + } + + /** + * NOT? IN + * + * SQL Spec 1999 section 8.4 + * RVC IN IPV is equivalent to RVC = ANY IPV -> Quantified Comparison Predicate + * Which means: + * Let the expression be T in C, where C is [a1, ..., an] + * T in C is true iff T = a_x is true for any a_x in [a1, ...., an] + * T in C is false iff T = a_x is false for every a_x in [a1, ....., an ] or cardinality of the collection is 0. + * Otherwise, T in C is unknown. 
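+ * For example, `2 IN <<1, 2, 3>>` is true, `4 IN <<1, 2, 3>>` is false, and `4 IN <<1, NULL>>`
+ * is unknown (no element compares equal, but one comparison yields unknown); these cases follow
+ * directly from the rule above and are illustrative only.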
+ * + */ + override fun visitExprInCollection(node: ExprInCollection, ctx: Env): Rex { + val type = BOOL + // Args + val arg0 = visitExprCoerce(node.lhs, ctx) + val arg1 = node.rhs.accept(this, ctx) // !! don't insert scalar subquery coercions + + // Call + var call = call("in_collection", arg0, arg1) + // NOT? + if (node.not == true) { + call = negate(call) + } + return rex(type, call) + } + + /** + * IS ? + */ + override fun visitExprIsType(node: ExprIsType, ctx: Env): Rex { + val type = BOOL + // arg + val arg0 = visitExprCoerce(node.value, ctx) + val targetType = node.type + var call = when (targetType.code()) { + // + DataType.NULL -> call("is_null", arg0) + DataType.MISSING -> call("is_missing", arg0) + // + // TODO CHAR_VARYING, CHARACTER_LARGE_OBJECT, CHAR_LARGE_OBJECT + DataType.CHARACTER, DataType.CHAR -> call("is_char", targetType.length.toRex(), arg0) + DataType.CHARACTER_VARYING, DataType.VARCHAR -> call("is_varchar", targetType.length.toRex(), arg0) + DataType.CLOB -> call("is_clob", arg0) + DataType.STRING -> call("is_string", targetType.length.toRex(), arg0) + DataType.SYMBOL -> call("is_symbol", arg0) + // + // TODO BINARY_LARGE_OBJECT + DataType.BLOB -> call("is_blob", arg0) + // + DataType.BIT -> call("is_bit", arg0) // TODO define in parser + DataType.BIT_VARYING -> call("is_bitVarying", arg0) // TODO define in parser + // - + DataType.NUMERIC -> call("is_numeric", targetType.precision.toRex(), targetType.scale.toRex(), arg0) + DataType.DEC, DataType.DECIMAL -> call("is_decimal", targetType.precision.toRex(), targetType.scale.toRex(), arg0) + DataType.BIGINT, DataType.INT8, DataType.INTEGER8 -> call("is_int64", arg0) + DataType.INT4, DataType.INTEGER4, DataType.INTEGER -> call("is_int32", arg0) + DataType.INT -> call("is_int", arg0) + DataType.INT2, DataType.SMALLINT -> call("is_int16", arg0) + DataType.TINYINT -> call("is_int8", arg0) // TODO define in parser + // - + DataType.FLOAT -> call("is_float32", arg0) + DataType.REAL -> call("is_real", arg0) + DataType.DOUBLE_PRECISION -> call("is_float64", arg0) + // + DataType.BOOLEAN, DataType.BOOL -> call("is_bool", arg0) + // + DataType.DATE -> call("is_date", arg0) + // TODO: DO we want to seperate with time zone vs without time zone into two different type in the plan? + // leave the parameterized type out for now until the above is answered + DataType.TIME -> call("is_time", arg0) + DataType.TIME_WITH_TIME_ZONE -> call("is_timeWithTz", arg0) + DataType.TIMESTAMP -> call("is_timestamp", arg0) + DataType.TIMESTAMP_WITH_TIME_ZONE -> call("is_timestampWithTz", arg0) + // + DataType.INTERVAL -> call("is_interval", arg0) // TODO define in parser + // + DataType.STRUCT, DataType.TUPLE -> call("is_struct", arg0) + // + DataType.LIST -> call("is_list", arg0) + DataType.BAG -> call("is_bag", arg0) + DataType.SEXP -> call("is_sexp", arg0) + // + DataType.USER_DEFINED -> call("is_custom", arg0) + else -> error("Unexpected DataType type: $targetType") + } + + if (node.not == true) { + call = negate(call) + } + + return rex(type, call) + } + + override fun visitExprCoalesce(node: ExprCoalesce, ctx: Env): Rex { + val type = ANY + val args = node.args.map { arg -> + visitExprCoerce(arg, ctx) + } + val op = rexOpCoalesce(args) + return rex(type, op) + } + + override fun visitExprNullIf(node: ExprNullIf, ctx: Env): Rex { + val type = ANY + val v1 = visitExprCoerce(node.v1, ctx) + val v2 = visitExprCoerce(node.v2, ctx) + val op = rexOpNullif(v1, v2) + return rex(type, op) + } + + /** + * SUBSTRING( (FROM (FOR )?)? 
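+ * That is, SUBSTRING(<value> (FROM <start> (FOR <length>)?)?), with illustrative placeholders.
+ * For example, SUBSTRING('hello' FROM 2 FOR 3) lowers to substring('hello', 2, 3), and
+ * SUBSTRING('hello' FROM 2) to substring('hello', 2); a missing <start> defaults to 1 below.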
 */ + override fun visitExprSubstring(node: ExprSubstring, ctx: Env): Rex { + val type = ANY + // Args + val arg0 = visitExprCoerce(node.value, ctx) + val arg1 = node.start?.let { visitExprCoerce(it, ctx) } ?: rex(INT, rexOpLit(int64Value(1))) + val arg2 = node.length?.let { visitExprCoerce(it, ctx) } + // Call Variants + val call = when (arg2) { + null -> call("substring", arg0, arg1) + else -> call("substring", arg0, arg1, arg2) + } + return rex(type, call) + } + + /** + * POSITION(<substring> IN <string>) + */ + override fun visitExprPosition(node: ExprPosition, ctx: Env): Rex { + val type = ANY + // Args + val arg0 = visitExprCoerce(node.lhs, ctx) + val arg1 = visitExprCoerce(node.rhs, ctx) + // Call + val call = call("position", arg0, arg1) + return rex(type, call) + } + + /** + * TRIM([LEADING|TRAILING|BOTH]? (<chars> FROM)? <value>) + */ + override fun visitExprTrim(node: ExprTrim, ctx: Env): Rex { + val type = STRING + // Args + val arg0 = visitExprCoerce(node.value, ctx) + val arg1 = node.chars?.let { visitExprCoerce(it, ctx) } + // Call Variants + val call = when (node.trimSpec?.code()) { + TrimSpec.LEADING -> when (arg1) { + null -> call("trim_leading", arg0) + else -> call("trim_leading_chars", arg0, arg1) + } + TrimSpec.TRAILING -> when (arg1) { + null -> call("trim_trailing", arg0) + else -> call("trim_trailing_chars", arg0, arg1) + } + // TODO: We may want to add a trim_both for trim(BOTH FROM arg) + else -> when (arg1) { + null -> call("trim", arg0) + else -> call("trim_chars", arg0, arg1) + } + } + return rex(type, call) + } + + /** + * SQL Spec 1999: Section 6.18 + * + * <character overlay function> ::= + * OVERLAY <character value expression> + * PLACING <character value expression> + * FROM <start position> + * [ FOR <string length> ] + * + * The <character overlay function> is equivalent to: + * + * SUBSTRING ( CV FROM 1 FOR SP - 1 ) || RS || SUBSTRING ( CV FROM SP + SL ) + * + * Where CV is the first <character value expression>, + * SP is the <start position> + * RS is the second <character value expression>, + * SL is the <string length> if specified, otherwise it is char_length(RS).
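+ *
+ * For example, OVERLAY('hello' PLACING 'XX' FROM 2 FOR 3) is, per the equivalence above,
+ * substring('hello', 1, 1) || 'XX' || substring('hello', 5), i.e. 'hXXo' (illustrative only).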
+ */ + override fun visitExprOverlay(node: ExprOverlay, ctx: Env): Rex { + val cv = visitExprCoerce(node.value, ctx) + val sp = visitExprCoerce(node.from, ctx) + val rs = visitExprCoerce(node.placing, ctx) + val sl = node.forLength?.let { visitExprCoerce(it, ctx) } ?: rex(ANY, call("char_length", rs)) + val p1 = rex( + ANY, + call( + "substring", + cv, + rex(INT4, rexOpLit(int32Value(1))), + rex(ANY, call("minus", sp, rex(INT4, rexOpLit(int32Value(1))))) + ) + ) + val p2 = rex(ANY, call("concat", p1, rs)) + return rex( + ANY, + call( + "concat", + p2, + rex(ANY, call("substring", cv, rex(ANY, call("plus", sp, sl)))) + ) + ) + } + + override fun visitExprExtract(node: ExprExtract, ctx: Env): Rex { + val call = call("extract_${node.field.name().lowercase()}", visitExprCoerce(node.source, ctx)) + return rex(ANY, call) + } + + override fun visitExprCast(node: ExprCast, ctx: Env): Rex { + val type = visitType(node.asType) + val arg = visitExprCoerce(node.value, ctx) + return rex(ANY, rexOpCastUnresolved(type, arg)) + } + + private fun visitType(type: DataType): CompilerType { + return when (type.code()) { + // + DataType.NULL -> error("Casting to NULL is not supported.") + DataType.MISSING -> error("Casting to MISSING is not supported.") + // + // TODO CHAR_VARYING, CHARACTER_LARGE_OBJECT, CHAR_LARGE_OBJECT + DataType.CHARACTER, DataType.CHAR -> { + val length = type.length ?: 1 + assertGtZeroAndCreate(PType.Kind.CHAR, "length", length, PType::character) + } + DataType.CHARACTER_VARYING, DataType.VARCHAR -> { + val length = type.length ?: 1 + assertGtZeroAndCreate(PType.Kind.VARCHAR, "length", length, PType::varchar) + } + DataType.CLOB -> assertGtZeroAndCreate(PType.Kind.CLOB, "length", type.length ?: Int.MAX_VALUE, PType::clob) + DataType.STRING -> PType.string() + DataType.SYMBOL -> PType.symbol() + // + // TODO BINARY_LARGE_OBJECT + DataType.BLOB -> assertGtZeroAndCreate(PType.Kind.BLOB, "length", type.length ?: Int.MAX_VALUE, PType::blob) + // + DataType.BIT -> error("BIT is not supported yet.") + DataType.BIT_VARYING -> error("BIT VARYING is not supported yet.") + // - + DataType.NUMERIC -> { + val p = type.precision + val s = type.scale + when { + p == null && s == null -> PType.decimal() + p != null && s != null -> { + assertParamCompToZero(PType.Kind.NUMERIC, "precision", p, false) + assertParamCompToZero(PType.Kind.NUMERIC, "scale", s, true) + if (s > p) { + throw TypeCheckException("Numeric scale cannot be greater than precision.") + } + PType.decimal(type.precision!!, type.scale!!) 
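+ // e.g. NUMERIC(10, 2) lands in this branch and maps to PType.decimal(10, 2) (illustrative example)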
+ } + p != null && s == null -> { + assertParamCompToZero(PType.Kind.NUMERIC, "precision", p, false) + PType.decimal(p, 0) + } + else -> error("Precision can never be null while scale is specified.") + } + } + DataType.DEC, DataType.DECIMAL -> { + val p = type.precision + val s = type.scale + when { + p == null && s == null -> PType.decimal() + p != null && s != null -> { + assertParamCompToZero(PType.Kind.DECIMAL, "precision", p, false) + assertParamCompToZero(PType.Kind.DECIMAL, "scale", s, true) + if (s > p) { + throw TypeCheckException("Decimal scale cannot be greater than precision.") + } + PType.decimal(p, s) + } + p != null && s == null -> { + assertParamCompToZero(PType.Kind.DECIMAL, "precision", p, false) + PType.decimal(p, 0) + } + else -> error("Precision can never be null while scale is specified.") + } + } + DataType.BIGINT, DataType.INT8, DataType.INTEGER8 -> PType.bigint() + DataType.INT4, DataType.INTEGER4, DataType.INTEGER, DataType.INT -> PType.integer() + DataType.INT2, DataType.SMALLINT -> PType.smallint() + DataType.TINYINT -> PType.tinyint() // TODO define in parser + // - + DataType.FLOAT -> PType.real() + DataType.REAL -> PType.real() + DataType.DOUBLE_PRECISION -> PType.doublePrecision() + // + DataType.BOOL -> PType.bool() + // + DataType.DATE -> PType.date() + DataType.TIME -> assertGtEqZeroAndCreate(PType.Kind.TIME, "precision", type.precision ?: 0, PType::time) + DataType.TIME_WITH_TIME_ZONE -> assertGtEqZeroAndCreate(PType.Kind.TIMEZ, "precision", type.precision ?: 0, PType::timez) + DataType.TIMESTAMP -> assertGtEqZeroAndCreate(PType.Kind.TIMESTAMP, "precision", type.precision ?: 6, PType::timestamp) + DataType.TIMESTAMP_WITH_TIME_ZONE -> assertGtEqZeroAndCreate(PType.Kind.TIMESTAMPZ, "precision", type.precision ?: 6, PType::timestampz) + // + DataType.INTERVAL -> error("INTERVAL is not supported yet.") + // + DataType.STRUCT -> PType.struct() + DataType.TUPLE -> PType.struct() + // + DataType.LIST -> PType.array() + DataType.BAG -> PType.bag() + DataType.SEXP -> PType.sexp() + // + DataType.USER_DEFINED -> TODO("Custom type not supported ") + else -> error("Unsupported DataType type: $type") + }.toCType() + } + + private fun assertGtZeroAndCreate(type: PType.Kind, param: String, value: Int, create: (Int) -> PType): PType { + assertParamCompToZero(type, param, value, false) + return create.invoke(value) + } + + private fun assertGtEqZeroAndCreate(type: PType.Kind, param: String, value: Int, create: (Int) -> PType): PType { + assertParamCompToZero(type, param, value, true) + return create.invoke(value) + } + + /** + * @param allowZero when FALSE, this asserts that [value] > 0. If TRUE, this asserts that [value] >= 0. 
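+ *
+ * For example (based on the call sites above), VARCHAR(0) is rejected because its length must be
+ * strictly positive, while TIME(0) is accepted because a precision of zero is allowed.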
+ */ + private fun assertParamCompToZero(type: PType.Kind, param: String, value: Int, allowZero: Boolean) { + val (result, compString) = when (allowZero) { + true -> (value >= 0) to "greater than or equal to" + false -> (value > 0) to "greater than" + } + if (!result) { + throw TypeCheckException("$type $param must be an integer value $compString 0.") + } + } + + override fun visitExprSessionAttribute(node: ExprSessionAttribute, ctx: Env): Rex { + val type = ANY + val fn = node.sessionAttribute.name().lowercase() + val call = call(fn) + return rex(type, call) + } + + override fun visitExprQuerySet(node: ExprQuerySet, context: Env): Rex = V1RelConverter.apply(node, context) + + // Helpers + + private fun negate(call: Rex.Op): Rex.Op.Call { + val id = Identifier.delimited("not") + val arg = rex(BOOL, call) + return rexOpCallUnresolved(id, listOf(arg)) + } + + /** + * Create a [Rex.Op.Call.Static] node which has a hidden unresolved function. + * The hidden function prevents the generated function name from being referenced directly in query text. + */ + private fun call(name: String, vararg args: Rex): Rex.Op.Call { + val id = Identifier.regular(name) + return rexOpCallUnresolved(id, args.toList()) + } + + private fun Int?.toRex() = rex(INT4, rexOpLit(int32Value(this))) + + private val ANY: CompilerType = CompilerType(PType.dynamic()) + private val BOOL: CompilerType = CompilerType(PType.bool()) + private val STRING: CompilerType = CompilerType(PType.string()) + private val STRUCT: CompilerType = CompilerType(PType.struct()) + private val BAG: CompilerType = CompilerType(PType.bag()) + private val LIST: CompilerType = CompilerType(PType.array()) + private val SEXP: CompilerType = CompilerType(PType.sexp()) + private val INT: CompilerType = CompilerType(PType.numeric()) + private val INT4: CompilerType = CompilerType(PType.integer()) + private val TIMESTAMP: CompilerType = CompilerType(PType.timestamp(6)) + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1SubstitutionVisitor.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1SubstitutionVisitor.kt new file mode 100644 index 000000000..497d83f78 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/V1SubstitutionVisitor.kt @@ -0,0 +1,15 @@ +package org.partiql.planner.internal.transforms + +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.AstRewriter +import org.partiql.ast.v1.expr.Expr + +internal object V1SubstitutionVisitor : AstRewriter<Map<*, AstNode>>() { + override fun visitExpr(node: Expr, ctx: Map<*, AstNode>): AstNode { + val visited = super.visitExpr(node, ctx) + if (ctx.containsKey(visited)) { + return ctx[visited]!!
+ } + return visited + } +} diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt index e29a0e493..f8ca7fbe8 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt @@ -5,7 +5,7 @@ import org.junit.jupiter.api.DynamicContainer.dynamicContainer import org.junit.jupiter.api.DynamicNode import org.junit.jupiter.api.DynamicTest import org.junit.jupiter.api.TestFactory -import org.partiql.parser.PartiQLParser +import org.partiql.parser.V1PartiQLParser import org.partiql.plan.Plan import org.partiql.planner.internal.TestCatalog import org.partiql.planner.test.PartiQLTest @@ -75,7 +75,7 @@ class PlanTest { ) .namespace("SCHEMA") .build() - val ast = PartiQLParser.standard().parse(test.statement).root + val ast = V1PartiQLParser.standard().parse(test.statement).root val planner = PartiQLPlanner.builder().signal(isSignalMode).build() planner.plan(ast, session) } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerPErrorReportingTests.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerPErrorReportingTests.kt index 40ce8c89c..8842518e6 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerPErrorReportingTests.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerPErrorReportingTests.kt @@ -2,8 +2,8 @@ package org.partiql.planner import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.MethodSource -import org.partiql.ast.Statement -import org.partiql.parser.PartiQLParserBuilder +import org.partiql.ast.v1.Statement +import org.partiql.parser.V1PartiQLParserBuilder import org.partiql.plan.Operation import org.partiql.planner.internal.typer.CompilerType import org.partiql.planner.internal.typer.PlanTyper.Companion.toCType @@ -19,7 +19,6 @@ import org.partiql.types.PType import org.partiql.types.StaticType import org.partiql.types.StructType import org.partiql.types.TupleConstraint -import java.lang.AssertionError import kotlin.test.assertEquals internal class PlannerPErrorReportingTests { @@ -42,7 +41,7 @@ internal class PlannerPErrorReportingTests { .catalogs(catalog) .build() - private val parser = PartiQLParserBuilder().build() + private val parser = V1PartiQLParserBuilder().build() private val statement: ((String) -> Statement) = { query -> parser.parse(query).root diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/exclude/SubsumptionTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/exclude/SubsumptionTest.kt index f91cff9d6..b31f3ce80 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/exclude/SubsumptionTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/exclude/SubsumptionTest.kt @@ -7,7 +7,7 @@ import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource -import org.partiql.parser.PartiQLParser +import org.partiql.parser.V1PartiQLParser import org.partiql.plan.Exclusion import org.partiql.plan.Operation import org.partiql.plan.builder.PlanFactory @@ -26,7 +26,7 @@ class SubsumptionTest { companion object { private val planner = PartiQLPlanner.standard() - private val parser = PartiQLParser.standard() + private val parser = V1PartiQLParser.standard() private val catalog = 
MemoryCatalog.builder().name("default").build() } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt index f91349bd9..6d26fcaf2 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt @@ -2,7 +2,7 @@ package org.partiql.planner.internal.typer import org.junit.jupiter.api.DynamicContainer import org.junit.jupiter.api.DynamicTest -import org.partiql.parser.PartiQLParser +import org.partiql.parser.V1PartiQLParser import org.partiql.plan.Operation import org.partiql.planner.PartiQLPlanner import org.partiql.planner.test.PartiQLTest @@ -38,7 +38,7 @@ abstract class PartiQLTyperTestBase { companion object { - public val parser = PartiQLParser.standard() + public val parser = V1PartiQLParser.standard() public val planner = PartiQLPlanner.standard() internal val session: ((String, Catalog) -> Session) = { catalog, metadata -> diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt index b4253da0d..a5340499b 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt @@ -12,7 +12,7 @@ import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource import org.junit.jupiter.params.provider.MethodSource -import org.partiql.parser.PartiQLParser +import org.partiql.parser.V1PartiQLParser import org.partiql.planner.PartiQLPlanner import org.partiql.planner.internal.PErrors import org.partiql.planner.internal.TestCatalog @@ -125,7 +125,7 @@ internal class PlanTyperTestsPorted { companion object { - private val parser = PartiQLParser.standard() + private val parser = V1PartiQLParser.standard() private val planner = PartiQLPlanner.builder().signal().build() private fun assertProblemExists(problem: PError) = ProblemHandler { problems, _ -> diff --git a/partiql-planner/src/test/resources/outputs/basics/select.sql b/partiql-planner/src/test/resources/outputs/basics/select.sql index 97b2079cc..7a5ecd70e 100644 --- a/partiql-planner/src/test/resources/outputs/basics/select.sql +++ b/partiql-planner/src/test/resources/outputs/basics/select.sql @@ -41,10 +41,10 @@ SELECT CURRENT_DATE AS "CURRENT_DATE" FROM "default"."SCHEMA"."T" AS "T"; SELECT DATE_DIFF(DAY, CURRENT_DATE, CURRENT_DATE) AS "_1" FROM "default"."SCHEMA"."T" AS "T"; --#[select-14] -SELECT DATE_ADD(DAY, 5, CURRENT_DATE) AS "_1" FROM "default"."SCHEMA"."T" AS "T" +SELECT DATE_ADD(DAY, 5, CURRENT_DATE) AS "_1" FROM "default"."SCHEMA"."T" AS "T"; --#[select-15] -SELECT DATE_ADD(DAY, -5, CURRENT_DATE) AS "_1" FROM "default"."SCHEMA"."T" AS "T" +SELECT DATE_ADD(DAY, -5, CURRENT_DATE) AS "_1" FROM "default"."SCHEMA"."T" AS "T"; --#[select-16] -SELECT "t"['a'] AS "a" FROM "default"."SCHEMA"."T" AS "t"; \ No newline at end of file +SELECT "t"['a'] AS "a" FROM "default"."SCHEMA"."T" AS "t"; diff --git a/test/partiql-randomized-tests/src/test/kotlin/org/partiql/lang/randomized/eval/Utils.kt b/test/partiql-randomized-tests/src/test/kotlin/org/partiql/lang/randomized/eval/Utils.kt index 
4ca4287f5..865c7cebc 100644 --- a/test/partiql-randomized-tests/src/test/kotlin/org/partiql/lang/randomized/eval/Utils.kt +++ b/test/partiql-randomized-tests/src/test/kotlin/org/partiql/lang/randomized/eval/Utils.kt @@ -1,7 +1,7 @@ package org.partiql.lang.randomized.eval import org.partiql.eval.compiler.PartiQLCompiler -import org.partiql.parser.PartiQLParser +import org.partiql.parser.V1PartiQLParser import org.partiql.planner.PartiQLPlanner import org.partiql.spi.catalog.Catalog import org.partiql.spi.catalog.Session @@ -22,7 +22,7 @@ fun runEvaluatorTestCase( @OptIn(PartiQLValueExperimental::class) private fun execute(query: String): PartiQLValue { - val parser = PartiQLParser.builder().build() + val parser = V1PartiQLParser.builder().build() val planner = PartiQLPlanner.builder().build() val catalog = object : Catalog { override fun getName(): String = "default" diff --git a/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt b/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt index db58f2a8f..4f6fbb19d 100644 --- a/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt +++ b/test/partiql-tests-runner/src/test/kotlin/org/partiql/runner/executor/EvalExecutor.kt @@ -10,7 +10,7 @@ import com.amazon.ionelement.api.toIonValue import org.partiql.eval.Mode import org.partiql.eval.Statement import org.partiql.eval.compiler.PartiQLCompiler -import org.partiql.parser.PartiQLParser +import org.partiql.parser.V1PartiQLParser import org.partiql.plan.Operation.Query import org.partiql.planner.PartiQLPlanner import org.partiql.plugins.memory.MemoryCatalog @@ -141,7 +141,7 @@ class EvalExecutor( companion object { val compiler = PartiQLCompiler.standard() - val parser = PartiQLParser.standard() + val parser = V1PartiQLParser.standard() val planner = PartiQLPlanner.standard() // TODO REPLACE WITH DATUM COMPARATOR val comparator = PartiQLValue.comparator()
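
For reference, a minimal sketch of the v1 wiring these call sites now share. The class and method names are taken from this diff (V1PartiQLParser, PartiQLPlanner, PartiQLCompiler, parse(...).root, plan(ast, session)); the session/catalog construction is abbreviated and assumed from the builder calls visible above, not a verbatim API reference.

    // Sketch only: wiring the v1 parser into the existing planner/compiler.
    val parser = V1PartiQLParser.standard()
    val planner = PartiQLPlanner.standard()
    val compiler = PartiQLCompiler.standard()
    // `session` is built as in the tests above (e.g. from MemoryCatalog.builder().name("default").build()).
    val ast = parser.parse("SELECT a FROM t").root   // root is the v1 Statement
    val result = planner.plan(ast, session)          // plan the v1 AST against the session
    // The compiler then prepares and executes the resulting plan, as in EvalExecutor above.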