From e2b12e76136c526f754664edfb2ab5a4b3ce8bdb Mon Sep 17 00:00:00 2001 From: John Ed Quinn Date: Tue, 10 Oct 2023 12:20:55 -0700 Subject: [PATCH] Adds planning/typing support of SELECT * --- .../org/partiql/ast/normalize/Normalize.kt | 3 +- .../partiql/ast/normalize/NormalizeSelect.kt | 252 +++++++++++++ .../ast/normalize/NormalizeSelectList.kt | 53 --- .../PartiQLSchemaInferencerTests.kt | 357 ++++++++++++++++-- .../resources/catalogs/aws/ddb/persons.ion | 46 +++ .../src/main/resources/partiql_plan_0_1.ion | 23 +- .../partiql/planner/PartiQLPlannerDefault.kt | 16 +- .../planner/transforms/RelConverter.kt | 132 ++----- .../planner/transforms/RexConverter.kt | 27 +- .../partiql/planner/typer/ConstantFolder.kt | 239 ++++++++++++ .../org/partiql/planner/typer/PlanTyper.kt | 186 ++++++--- .../org/partiql/planner/typer/RexReplacer.kt | 42 +++ .../org/partiql/planner/test/Parsing.kt | 3 +- .../catalogs/default/pql/employer.ion | 30 ++ .../resources/catalogs/default/pql/person.ion | 33 ++ .../testFixtures/resources/tests/suite_00.ion | 89 ++++- 16 files changed, 1253 insertions(+), 278 deletions(-) create mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelect.kt delete mode 100644 partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectList.kt create mode 100644 partiql-lang/src/test/resources/catalogs/aws/ddb/persons.ion create mode 100644 partiql-planner/src/main/kotlin/org/partiql/planner/typer/ConstantFolder.kt create mode 100644 partiql-planner/src/main/kotlin/org/partiql/planner/typer/RexReplacer.kt create mode 100644 partiql-planner/src/testFixtures/resources/catalogs/default/pql/employer.ion create mode 100644 partiql-planner/src/testFixtures/resources/catalogs/default/pql/person.ion diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/Normalize.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/Normalize.kt index 66d8763487..6276c6b1c2 100644 --- 
a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/Normalize.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/Normalize.kt @@ -8,8 +8,7 @@ import org.partiql.ast.Statement public fun Statement.normalize(): Statement { // could be a fold, but this is nice for setting breakpoints var ast = this - ast = NormalizeSelectList.apply(ast) ast = NormalizeFromSource.apply(ast) - ast = NormalizeSelectStar.apply(ast) + ast = NormalizeSelect.apply(ast) return ast } diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelect.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelect.kt new file mode 100644 index 0000000000..ed3278e067 --- /dev/null +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelect.kt @@ -0,0 +1,252 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.ast.normalize + +import org.partiql.ast.AstNode +import org.partiql.ast.Expr +import org.partiql.ast.From +import org.partiql.ast.Identifier +import org.partiql.ast.Select +import org.partiql.ast.Statement +import org.partiql.ast.builder.AstBuilder +import org.partiql.ast.builder.ast +import org.partiql.ast.helpers.toBinder +import org.partiql.ast.util.AstRewriter +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.stringValue + +/** + * Converts SQL-style SELECT to PartiQL SELECT VALUE. + * - If there is a PROJECT ALL, we use the TUPLEUNION. 
+ * - If there is NOT a PROJECT ALL, we use a literal struct. + * + * Here are some examples of rewrites: + * + * ``` + * SELECT * + * FROM + * A AS x, + * B AS y AT i + * ``` + * gets rewritten to: + * ``` + * SELECT VALUE TUPLEUNION( + * CASE WHEN x IS STRUCT THEN x ELSE { '_1': x } END, + * CASE WHEN y IS STRUCT THEN y ELSE { '_2': y } END, + * { 'i': i } + * ) FROM A AS x, B AS y AT i + * ``` + * + * ``` + * SELECT x.*, x.a FROM A AS x + * ``` + * gets rewritten to: + * ``` + * SELECT VALUE TUPLEUNION( + * CASE WHEN x IS STRUCT THEN x ELSE { '_1': x } END, + * { 'a': x.a } + * ) FROM A AS x + * ``` + * + * ``` + * SELECT x.a FROM A AS x + * ``` + * gets rewritten to: + * ``` + * SELECT VALUE { + * 'a': x.a + * } FROM A AS x + * ``` + * + * TODO: GROUP BY + * TODO: LET + * + * Requires [NormalizeFromSource]. + */ +internal object NormalizeSelect : AstPass { + + override fun apply(statement: Statement): Statement = Visitor.visitStatement(statement, 0) as Statement + + private object Visitor : AstRewriter() { + + override fun visitExprSFW(node: Expr.SFW, ctx: Int) = ast { + val sfw = super.visitExprSFW(node, ctx) as Expr.SFW + when (val select = sfw.select) { + is Select.Star -> sfw.copy(select = visitSelectAll(select, sfw.from)) + else -> sfw + } + } + + override fun visitSelectProject(node: Select.Project, ctx: Int): AstNode = ast { + val visitedNode = super.visitSelectProject(node, ctx) as? 
Select.Project + ?: error("VisitSelectProject should have returned a Select.Project") + return@ast when (node.items.any { it is Select.Project.Item.All }) { + false -> visitSelectProjectWithoutProjectAll(visitedNode) + true -> visitSelectProjectWithProjectAll(visitedNode) + } + } + + override fun visitSelectProjectItemExpression(node: Select.Project.Item.Expression, ctx: Int) = ast { + val expr = visitExpr(node.expr, 0) as Expr + val alias = when (node.asAlias) { + null -> expr.toBinder(ctx) + else -> node.asAlias + } + if (expr != node.expr || alias != node.asAlias) { + selectProjectItemExpression(expr, alias) + } else { + node + } + } + + // Helpers + + /** + * We need to call this from [visitExprSFW] and not override [visitSelectStar] because we need access to the + * [From] aliases. + * + * Note: We assume that [select] and [from] have already been visited. + */ + private fun visitSelectAll(select: Select.Star, from: From): Select.Value = ast { + val tupleUnionArgs = from.aliases().flatMapIndexed { i, binding -> + val asAlias = binding.first + val atAlias = binding.second + val byAlias = binding.third + val atAliasItem = atAlias?.simple()?.let { + val alias = it.asAlias ?: error("The AT alias should be present. This wasn't normalized.") + buildSimpleStruct(it.expr, alias.symbol) + } + val byAliasItem = byAlias?.simple()?.let { + val alias = it.asAlias ?: error("The BY alias should be present. 
This wasn't normalized.") + buildSimpleStruct(it.expr, alias.symbol) + } + listOfNotNull( + buildCaseWhenStruct(asAlias.star(i).expr, i), + atAliasItem, + byAliasItem + ) + } + selectValue { + constructor = exprCall { + function = identifierSymbol("TUPLEUNION", Identifier.CaseSensitivity.SENSITIVE) + args.addAll(tupleUnionArgs) + } + setq = select.setq + } + } + + private fun visitSelectProjectWithProjectAll(node: Select.Project): AstNode = ast { + val tupleUnionArgs = node.items.mapIndexed { index, item -> + when (item) { + is Select.Project.Item.All -> buildCaseWhenStruct(item.expr, index) + is Select.Project.Item.Expression -> buildSimpleStruct( + item.expr, + item.asAlias?.symbol + ?: error("The alias should've been here. This AST is not normalized.") + ) + } + } + selectValue { + setq = node.setq + constructor = exprCall { + function = identifierSymbol("TUPLEUNION", Identifier.CaseSensitivity.SENSITIVE) + args.addAll(tupleUnionArgs) + } + } + } + + @OptIn(PartiQLValueExperimental::class) + private fun visitSelectProjectWithoutProjectAll(node: Select.Project): AstNode = ast { + val structFields = node.items.map { item -> + val itemExpr = item as? 
Select.Project.Item.Expression ?: error("Expected the projection to be an expression.") + exprStructField( + name = exprLit(stringValue(itemExpr.asAlias?.symbol!!)), + value = item.expr + ) + } + selectValue { + setq = node.setq + constructor = exprStruct { + fields.addAll(structFields) + } + } + } + + @OptIn(PartiQLValueExperimental::class) + private fun buildCaseWhenStruct(expr: Expr, index: Int): Expr.Case { + return ast { + exprCase { + branches.add( + exprCaseBranch( + condition = exprIsType(expr, typeStruct()), + expr = expr + ) + ) + default = buildSimpleStruct(expr, col(index)) + exprStruct { + fields.add( + exprStructField( + name = exprLit(stringValue(index.toString())), + value = expr + ) + ) + } + } + } + } + + @OptIn(PartiQLValueExperimental::class) + private fun buildSimpleStruct(expr: Expr, name: String): Expr.Struct { + return ast { + exprStruct { + fields.add( + exprStructField( + name = exprLit(stringValue(name)), + value = expr + ) + ) + } + } + } + + private fun From.aliases(): List> = when (this) { + is From.Join -> lhs.aliases() + rhs.aliases() + is From.Value -> { + val asAlias = asAlias?.symbol ?: error("AST not normalized, missing asAlias on FROM source.") + val atAlias = atAlias?.symbol + val byAlias = byAlias?.symbol + listOf(Triple(asAlias, atAlias, byAlias)) + } + } + + private val col = { index: Int -> "_${index + 1}" } + + // t -> t.* AS _i + private fun String.star(i: Int) = ast { + val expr = exprVar(id(this@star), Expr.Var.Scope.DEFAULT) + val alias = expr.toBinder(i) + selectProjectItemExpression(expr, alias) + } + + // t -> t AS t + private fun String.simple() = ast { + val expr = exprVar(id(this@simple), Expr.Var.Scope.DEFAULT) + val alias = id(this@simple) + selectProjectItemExpression(expr, alias) + } + + private fun AstBuilder.id(symbol: String) = identifierSymbol(symbol, Identifier.CaseSensitivity.INSENSITIVE) + } +} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectList.kt 
b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectList.kt deleted file mode 100644 index 238b77bafb..0000000000 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectList.kt +++ /dev/null @@ -1,53 +0,0 @@ -package org.partiql.ast.normalize - -import org.partiql.ast.Expr -import org.partiql.ast.Select -import org.partiql.ast.Statement -import org.partiql.ast.builder.ast -import org.partiql.ast.helpers.toBinder -import org.partiql.ast.util.AstRewriter - -/** - * Adds an `as` alias to every select-list item. - * - * - [org.partiql.ast.helpers.toBinder] - * - https://partiql.org/assets/PartiQL-Specification.pdf#page=28 - * - https://web.cecs.pdx.edu/~len/sql1999.pdf#page=287 - */ -internal object NormalizeSelectList : AstPass { - - override fun apply(statement: Statement) = Visitor.visitStatement(statement, 0) as Statement - - private object Visitor : AstRewriter() { - - override fun visitSelectProject(node: Select.Project, ctx: Int) = ast { - if (node.items.isEmpty()) { - return@ast node - } - var diff = false - val transformed = ArrayList(node.items.size) - node.items.forEachIndexed { i, n -> - val item = visitSelectProjectItem(n, i) as Select.Project.Item - if (item !== n) diff = true - transformed.add(item) - } - // We don't want to create a new list unless we have to, as to not trigger further rewrites up the tree. 
- if (diff) selectProject(transformed) else node - } - - override fun visitSelectProjectItemAll(node: Select.Project.Item.All, ctx: Int) = node.copy() - - override fun visitSelectProjectItemExpression(node: Select.Project.Item.Expression, ctx: Int) = ast { - val expr = visitExpr(node.expr, 0) as Expr - val alias = when (node.asAlias) { - null -> expr.toBinder(ctx) - else -> node.asAlias - } - if (expr != node.expr || alias != node.asAlias) { - selectProjectItemExpression(expr, alias) - } else { - node - } - } - } -} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt index 827fd006b3..17e8b48657 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt @@ -27,6 +27,7 @@ import org.partiql.plugins.local.LocalPlugin import org.partiql.types.AnyOfType import org.partiql.types.AnyType import org.partiql.types.BagType +import org.partiql.types.IntType import org.partiql.types.ListType import org.partiql.types.SexpType import org.partiql.types.StaticType @@ -36,6 +37,7 @@ import org.partiql.types.StaticType.Companion.MISSING import org.partiql.types.StaticType.Companion.NULL import org.partiql.types.StaticType.Companion.STRING import org.partiql.types.StaticType.Companion.unionOf +import org.partiql.types.StringType import org.partiql.types.StructType import org.partiql.types.TupleConstraint import java.time.Instant @@ -93,6 +95,11 @@ class PartiQLSchemaInferencerTests { @Execution(ExecutionMode.CONCURRENT) fun testOrderBy(tc: TestCase) = runTest(tc) + @ParameterizedTest + @MethodSource("tupleUnionCases") + @Execution(ExecutionMode.CONCURRENT) + fun testTupleUnion(tc: TestCase) = runTest(tc) + companion object { private val root = 
this::class.java.getResource("/catalogs")!!.toURI().toPath().pathString @@ -2119,12 +2126,277 @@ class PartiQLSchemaInferencerTests { query = "SELECT * FROM pets ORDER BY breed", expected = TABLE_AWS_DDB_PETS_LIST ), - SuccessTestCase( + ErrorTestCase( name = "ORDER BY str", catalog = CATALOG_AWS, catalogPath = listOf("ddb"), query = "SELECT * FROM pets ORDER BY unknown_col", - expected = TABLE_AWS_DDB_PETS_LIST + expected = TABLE_AWS_DDB_PETS_LIST, + problemHandler = assertProblemExists { + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UndefinedVariable("unknown_col", false) + ) + } + ), + ) + + @JvmStatic + fun tupleUnionCases() = listOf( + SuccessTestCase( + name = "Empty Tuple Union", + query = "TUPLEUNION()", + expected = StructType( + fields = emptyMap(), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true) + ) + ) + ), + SuccessTestCase( + name = "Tuple Union with Literal Struct", + query = "TUPLEUNION({ 'a': 1, 'b': 'hello' })", + expected = StructType( + fields = mapOf( + "a" to IntType(), + "b" to StringType() + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + ) + ), + ), + // TODO: Is this right? Given the lack of ordering, can we say that there are 2 fields (but we don't know + // whether one will definitely be accessed? 
+ SuccessTestCase( + name = "Tuple Union with Literal Struct AND Duplicates", + query = "TUPLEUNION({ 'a': 1, 'a': 'hello' })", + expected = StructType( + fields = listOf( + StructType.Field("a", unionOf(INT, STRING)), + StructType.Field("a", unionOf(INT, STRING)), + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(false), + ) + ), + ), + SuccessTestCase( + name = "Tuple Union with Nested Struct", + query = """ + SELECT VALUE TUPLEUNION( + t.a + ) FROM << + { 'a': { 'b': 1 } } + >> AS t + """, + expected = BagType( + StructType( + fields = listOf( + StructType.Field("b", INT), + ), + contentClosed = true, + // TODO: This shouldn't be ordered. However, this doesn't come from the TUPLEUNION. It is + // coming from the RexOpSelect. + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ), + ), + SuccessTestCase( + name = "Tuple Union with Heterogeneous Data", + query = """ + SELECT VALUE TUPLEUNION( + t.a + ) FROM << + { 'a': { 'b': 1 } }, + { 'a': 1 } + >> AS t + """, + expected = BagType( + unionOf( + MISSING, + StructType( + fields = listOf( + StructType.Field("b", INT), + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + ) + ) + ) + ), + ), + SuccessTestCase( + name = "Tuple Union with Heterogeneous Data (2)", + query = """ + SELECT VALUE TUPLEUNION( + t.a + ) FROM << + { 'a': { 'b': 1 } }, + { 'a': { 'b': 'hello' } }, + { 'a': NULL }, + { 'a': 4.5 }, + { } + >> AS t + """, + expected = BagType( + unionOf( + NULL, + MISSING, + StructType( + fields = listOf( + StructType.Field("b", INT), + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + ) + ), + StructType( + fields = listOf( + StructType.Field("b", STRING), + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + 
TupleConstraint.UniqueAttrs(true), + ) + ) + ) + ), + ), + SuccessTestCase( + name = "Tuple Union with Heterogeneous Data (3)", + query = """ + SELECT VALUE TUPLEUNION( + p.name + ) FROM aws.ddb.persons AS p + """, + expected = BagType( + unionOf( + MISSING, + StructType( + fields = listOf( + StructType.Field("first", STRING), + StructType.Field("last", STRING), + ), + contentClosed = false, + constraints = setOf( + TupleConstraint.Open(true), + TupleConstraint.UniqueAttrs(false), + ) + ), + StructType( + fields = listOf( + StructType.Field("full_name", STRING), + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ), + ) + ), + ), + SuccessTestCase( + name = "Complex Tuple Union with Heterogeneous Data", + query = """ + SELECT VALUE TUPLEUNION( + p.name, + p.name + ) FROM aws.ddb.persons AS p + """, + expected = BagType( + unionOf( + MISSING, + StructType( + fields = listOf( + StructType.Field("first", STRING), + StructType.Field("last", STRING), + ), + contentClosed = false, + constraints = setOf( + TupleConstraint.Open(true), + TupleConstraint.UniqueAttrs(false), + ) + ), + StructType( + fields = listOf( + StructType.Field("full_name", STRING), + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ), + StructType( + fields = listOf( + StructType.Field("first", STRING), + StructType.Field("last", STRING), + StructType.Field("first", STRING), + StructType.Field("last", STRING), + ), + contentClosed = false, + constraints = setOf( + TupleConstraint.Open(true), + TupleConstraint.UniqueAttrs(false), + ) + ), + StructType( + fields = listOf( + StructType.Field("first", STRING), + StructType.Field("last", STRING), + StructType.Field("full_name", STRING), + ), + contentClosed = false, + constraints = setOf( + TupleConstraint.Open(true), + TupleConstraint.UniqueAttrs(false), + ) 
+ ), + StructType( + fields = listOf( + StructType.Field("full_name", STRING), + StructType.Field("first", STRING), + StructType.Field("last", STRING), + ), + contentClosed = false, + constraints = setOf( + TupleConstraint.Open(true), + TupleConstraint.UniqueAttrs(false), + ) + ), + StructType( + fields = listOf( + StructType.Field("full_name", STRING), + StructType.Field("full_name", STRING), + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(false), + TupleConstraint.Ordered + ) + ), + ) + ), ), ) } @@ -2176,14 +2448,25 @@ class PartiQLSchemaInferencerTests { name = "Pets should not be accessible #1", query = "SELECT * FROM pets", expected = BagType( - StructType( - fields = mapOf("pets" to StaticType.ANY), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Ordered - ) + unionOf( + StructType( + fields = emptyMap(), + contentClosed = false, + constraints = setOf( + TupleConstraint.Open(true), + TupleConstraint.UniqueAttrs(false), + ) + ), + StructType( + fields = mapOf( + "_1" to StaticType.ANY + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + ) + ), ) ), problemHandler = assertProblemExists { @@ -2198,14 +2481,25 @@ class PartiQLSchemaInferencerTests { catalog = CATALOG_AWS, query = "SELECT * FROM pets", expected = BagType( - StructType( - fields = mapOf("pets" to StaticType.ANY), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Ordered - ) + unionOf( + StructType( + fields = emptyMap(), + contentClosed = false, + constraints = setOf( + TupleConstraint.Open(true), + TupleConstraint.UniqueAttrs(false), + ) + ), + StructType( + fields = mapOf( + "_1" to StaticType.ANY + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + 
TupleConstraint.UniqueAttrs(true), + ) + ), ) ), problemHandler = assertProblemExists { @@ -2254,14 +2548,25 @@ class PartiQLSchemaInferencerTests { name = "Test #7", query = "SELECT * FROM ddb.pets", expected = BagType( - StructType( - fields = mapOf("pets" to StaticType.ANY), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Ordered - ) + unionOf( + StructType( + fields = emptyMap(), + contentClosed = false, + constraints = setOf( + TupleConstraint.Open(true), + TupleConstraint.UniqueAttrs(false), + ) + ), + StructType( + fields = mapOf( + "_1" to StaticType.ANY + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + ) + ), ) ), problemHandler = assertProblemExists { diff --git a/partiql-lang/src/test/resources/catalogs/aws/ddb/persons.ion b/partiql-lang/src/test/resources/catalogs/aws/ddb/persons.ion new file mode 100644 index 0000000000..7b0d9ba582 --- /dev/null +++ b/partiql-lang/src/test/resources/catalogs/aws/ddb/persons.ion @@ -0,0 +1,46 @@ +{ + type: "bag", + items: { + type: "struct", + constraints: [ + closed, + unique, + ordered + ], + fields: [ + { + name: "name", + type: [ + "string", + { + type: "struct", + fields: [ + { + name: "first", + type: "string" + }, + { + name: "last", + type: "string" + } + ] + }, + { + type: "struct", + constraints: [ + closed, + unique, + ordered + ], + fields: [ + { + name: "full_name", + type: "string" + }, + ] + }, + ] + }, + ] + } +} diff --git a/partiql-plan/src/main/resources/partiql_plan_0_1.ion b/partiql-plan/src/main/resources/partiql_plan_0_1.ion index 920df2587d..0cdab738e6 100644 --- a/partiql-plan/src/main/resources/partiql_plan_0_1.ion +++ b/partiql-plan/src/main/resources/partiql_plan_0_1.ion @@ -159,28 +159,9 @@ rex::{ // CASE WHEN v3 IS TUPLE THEN v3 ELSE {'_2': v3} END // ) // - // tuple_union::{ - // args: [ - // spread('_1', v1), - // struct('a', e2), - 
// spread('_2', v3), - // ] - // } - // + // Tuple Union Function Signature: (Array) -> Struct tuple_union::{ - args: list::[arg], - _: [ - arg::[ - struct::{ - k: string, - v: rex, - }, - spread::{ - k: string, - v: rex, - }, - ], - ], + args: list::[rex], }, err::{ diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt index c7354bce65..90990dd592 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt @@ -6,6 +6,7 @@ import org.partiql.errors.ProblemCallback import org.partiql.plan.PartiQLVersion import org.partiql.plan.partiQLPlan import org.partiql.planner.transforms.AstToPlan +import org.partiql.planner.typer.ConstantFolder import org.partiql.planner.typer.PlanTyper import org.partiql.spi.Plugin @@ -34,17 +35,28 @@ internal class PartiQLPlannerDefault( // 3. Resolve variables val typer = PlanTyper(env, onProblem) + var planStmt = typer.resolve(root) + + // 4. Fold and re-resolve + planStmt = fold(planStmt, env, onProblem) + var plan = partiQLPlan( version = version, globals = env.globals, - statement = typer.resolve(root), + statement = planStmt, ) - // 4. Apply all passes + // 5. 
Apply all passes for (pass in passes) { plan = pass.apply(plan, onProblem) } return PartiQLPlanner.Result(plan, emptyList()) } + + private fun fold(stmt: org.partiql.plan.Statement, env: Env, onProblem: ProblemCallback): org.partiql.plan.Statement { + val foldedPlan = ConstantFolder.fold(stmt) + val typer = PlanTyper(env, onProblem) + return typer.resolve(foldedPlan) + } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt index 418762424e..153430f3ae 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt @@ -30,6 +30,7 @@ import org.partiql.ast.util.AstRewriter import org.partiql.ast.visitor.AstBaseVisitor import org.partiql.plan.Rel import org.partiql.plan.Rex +import org.partiql.plan.builder.plan import org.partiql.plan.fnUnresolved import org.partiql.plan.rel import org.partiql.plan.relBinding @@ -57,20 +58,10 @@ import org.partiql.plan.relOpUnpivot import org.partiql.plan.relType import org.partiql.plan.rex import org.partiql.plan.rexOpLit -import org.partiql.plan.rexOpPath -import org.partiql.plan.rexOpPivot -import org.partiql.plan.rexOpSelect -import org.partiql.plan.rexOpStruct -import org.partiql.plan.rexOpStructField -import org.partiql.plan.rexOpTupleUnion -import org.partiql.plan.rexOpTupleUnionArgSpread -import org.partiql.plan.rexOpTupleUnionArgStruct -import org.partiql.plan.rexOpVarResolved import org.partiql.planner.Env import org.partiql.types.StaticType import org.partiql.value.PartiQLValueExperimental import org.partiql.value.boolValue -import org.partiql.value.stringValue /** * Lexically scoped state for use in translating an individual SELECT statement. 
@@ -80,11 +71,14 @@ internal object RelConverter { // IGNORE — so we don't have to non-null assert on operator inputs private val nil = rel(relType(emptyList(), emptySet()), relOpErr("nil")) + private var uniqueNumber: Long = 0L + private fun getNextUniqueName(): String = "\$__${uniqueNumber++}" + /** * Here we convert an SFW to composed [Rel]s, then apply the appropriate relation-value projection to get a [Rex]. */ - internal fun apply(sfw: Expr.SFW, env: Env): Rex { - var rel = sfw.accept(ToRel(env), nil) + internal fun apply(sfw: Expr.SFW, env: Env): Rex = plan { + val rel = sfw.accept(ToRel(env), nil) val rex = when (val projection = sfw.select) { // PIVOT ... FROM is Select.Pivot -> { @@ -96,7 +90,11 @@ internal object RelConverter { } // SELECT VALUE ... FROM is Select.Value -> { - val constructor = projection.constructor.toRex(env) + val projectionOp = rel.op as? Rel.Op.Project ?: error("SELECT VALUE should have a PROJECT underneath") + assert(projectionOp.projections.size == 1) { + "Expected SELECT VALUE projection to have a single binding. However, it looked like ${projectionOp.projections}" + } + val constructor = rex(StaticType.ANY, rexOpVarResolved(0)) val op = rexOpSelect(constructor, rel) val type = when (rel.type.props.contains(Rel.Prop.ORDERED)) { true -> (StaticType.LIST) @@ -110,17 +108,10 @@ internal object RelConverter { } // SELECT ... FROM is Select.Project -> { - val (constructor, newRel) = deriveConstructor(rel) - rel = newRel - val op = rexOpSelect(constructor, rel) - val type = when (rel.type.props.contains(Rel.Prop.ORDERED)) { - true -> (StaticType.LIST) - else -> (StaticType.BAG) - } - rex(type, op) + throw IllegalArgumentException("AST not normalized") } } - return rex + rex } /** @@ -128,85 +119,6 @@ internal object RelConverter { */ private fun Expr.toRex(env: Env): Rex = RexConverter.apply(this, env) - /** - * Derives the appropriate SELECT VALUE constructor given the Rel projection - removing any UNPIVOT path steps. 
- */ - private fun deriveConstructor(rel: Rel): Pair { - val op = rel.op - if (op !is Rel.Op.Project) { - return defaultConstructor(rel.type.schema) to rel - } - val hasProjectAll = op.projections.any { it.isProjectAll() } - return when (hasProjectAll) { - true -> tupleUnionConstructor(op, rel.type) - else -> defaultConstructor(rel.type.schema) to rel - } - } - - /** - * Produces the default constructor to generalize a SQL SELECT to a SELECT VALUE. - * - * See https://partiql.org/dql/select.html#sql-select - */ - @OptIn(PartiQLValueExperimental::class) - private fun defaultConstructor(schema: List): Rex { - val fields = schema.mapIndexed { i, b -> - val k = rex(StaticType.STRING, rexOpLit(stringValue(b.name))) - val v = rex(b.type, rexOpVarResolved(i)) - rexOpStructField(k, v) - } - val op = rexOpStruct(fields) - return rex(StaticType.STRUCT, op) - } - - /** - * Produces the `TUPLEUNION` constructor defined in Section 6.3.2 `SQL's SELECT *`. - * - * See https://partiql.org/assets/PartiQL-Specification.pdf#page=28 - */ - private fun tupleUnionConstructor(op: Rel.Op.Project, type: Rel.Type): Pair { - val projections = mutableListOf() - val args = op.projections.mapIndexed { i, item -> - val binding = type.schema[i] - val k = binding.name - val v = rex(binding.type, rexOpVarResolved(i)) - when (item.isProjectAll()) { - true -> { - projections.add(item.removeUnpivot()) - rexOpTupleUnionArgSpread(k, v) - } - else -> { - projections.add(item) - rexOpTupleUnionArgStruct(k, v) - } - } - } - val constructor = rex(StaticType.STRUCT, rexOpTupleUnion(args)) - val rel = rel(type, relOpProject(op.input, projections)) - return constructor to rel - } - - private fun Rex.isProjectAll(): Boolean { - return (op is Rex.Op.Path && (op as Rex.Op.Path).steps.last() is Rex.Op.Path.Step.Unpivot) - } - - private fun Rex.removeUnpivot(): Rex { - val rex = this@removeUnpivot - val path = op - if (path is Rex.Op.Path) { - if (path.steps.last() is Rex.Op.Path.Step.Unpivot) { - val newRoot = 
path.root - val newSteps = path.steps.dropLast(1) - return when (newSteps.isEmpty()) { - true -> newRoot - else -> rex(rex.type, rexOpPath(newRoot, newSteps)) - } - } - } - // skip, should be unreachable - return rex - } - @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE", "LocalVariableName") private class ToRel(private val env: Env) : AstBaseVisitor() { @@ -235,6 +147,7 @@ internal object RelConverter { // append SQL projection if present rel = when (val projection = sel.select) { is Select.Project -> visitSelectProject(projection, rel) + is Select.Value -> visitSelectValue(projection, rel) is Select.Star -> error("AST not normalized, found project star") else -> rel // skip PIVOT and SELECT VALUE } @@ -256,7 +169,20 @@ internal object RelConverter { return rel(type, op) } - override fun visitFromValue(node: From.Value, nil: Rel): Rel { + /** + * Note: The name of the [Rel.Binding] doesn't actually matter. We make it unique to avoid any potential conflicts + * in the future. + */ + override fun visitSelectValue(node: Select.Value, input: Rel): Rel = plan { + val rex = RexConverter.apply(node.constructor, env) + val schema = listOf(relBinding(name = getNextUniqueName(), rex.type)) + val props = input.type.props + val type = relType(schema, props) + val op = relOpProject(input, projections = listOf(rex)) + rel(type, op) + } + + override fun visitFromValue(node: From.Value, nil: Rel) = plan { val rex = RexConverter.apply(node.expr, env) val binding = when (val a = node.asAlias) { null -> error("AST not normalized, missing AS alias on $node") @@ -265,7 +191,7 @@ internal object RelConverter { type = rex.type ) } - return when (node.type) { + when (node.type) { From.Value.Type.SCAN -> { when (val i = node.atAlias) { null -> convertScan(rex, binding) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt index a5ec405191..ca10558b7a 100644 --- 
a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt @@ -24,6 +24,7 @@ import org.partiql.ast.Type import org.partiql.ast.visitor.AstBaseVisitor import org.partiql.plan.Identifier import org.partiql.plan.Rex +import org.partiql.plan.builder.plan import org.partiql.plan.fnUnresolved import org.partiql.plan.identifierSymbol import org.partiql.plan.rex @@ -140,20 +141,30 @@ internal object RexConverter { override fun visitExprCall(node: Expr.Call, context: Env): Rex { val type = (StaticType.ANY) - // Args - val args = node.args.map { visitExpr(it, context) } // Fn val id = AstToPlan.convert(node.function) + if (id is Identifier.Symbol && id.symbol.equals("TUPLEUNION", ignoreCase = true)) { + return visitExprCallTupleUnion(node, context) + } val fn = fnUnresolved(id) + // Args + val args = node.args.map { visitExpr(it, context) } // Rex val op = rexOpCall(fn, args) return rex(type, op) } - override fun visitExprCase(node: Expr.Case, context: Env): Rex { + private fun visitExprCallTupleUnion(node: Expr.Call, context: Env) = plan { + val type = (StaticType.STRUCT) + val args = node.args.map { visitExpr(it, context) }.toMutableList() + val op = rexOpTupleUnion(args) + rex(type, op) + } + + override fun visitExprCase(node: Expr.Case, context: Env) = plan { val type = (StaticType.ANY) val rex = when (node.expr) { - null -> bool(true) // match `true` + null -> null else -> visitExpr(node.expr!!, context) // match `rex } @@ -161,8 +172,10 @@ internal object RexConverter { val id = identifierSymbol(Expr.Binary.Op.EQ.name.lowercase(), Identifier.CaseSensitivity.SENSITIVE) val fn = fnUnresolved(id) val createBranch: (Rex, Rex) -> Rex.Op.Case.Branch = { condition: Rex, result: Rex -> - val op = rexOpCall(fn.copy(), listOf(rex, condition)) - val updatedCondition = rex(type, op) + val updatedCondition = when (rex) { + null -> condition + else -> rex(type, 
rexOpCall(fn.copy(), listOf(rex, condition))) + } rexOpCaseBranch(updatedCondition, result) } @@ -178,7 +191,7 @@ internal object RexConverter { } branches += rexOpCaseBranch(bool(true), defaultRex) val op = rexOpCase(branches) - return rex(type, op) + rex(type, op) } override fun visitExprCollection(node: Expr.Collection, context: Env): Rex { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/ConstantFolder.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/ConstantFolder.kt new file mode 100644 index 0000000000..0350ff9392 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/ConstantFolder.kt @@ -0,0 +1,239 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. 
+ */ + +package org.partiql.planner.typer + +import org.partiql.plan.Fn +import org.partiql.plan.PartiQLPlan +import org.partiql.plan.Rex +import org.partiql.plan.Statement +import org.partiql.plan.rex +import org.partiql.plan.rexOpLit +import org.partiql.plan.rexOpPath +import org.partiql.plan.rexOpPathStepIndex +import org.partiql.plan.rexOpStruct +import org.partiql.plan.rexOpStructField +import org.partiql.plan.statementQuery +import org.partiql.plan.util.PlanRewriter +import org.partiql.types.AnyOfType +import org.partiql.types.AnyType +import org.partiql.types.StaticType +import org.partiql.types.StructType +import org.partiql.value.BoolValue +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.StringValue +import org.partiql.value.TextValue +import org.partiql.value.boolValue +import org.partiql.value.stringValue + +/** + * Constant folds [Rex]'s. + * + * It's usefulness during typing can be seen in the following example: + * ``` + * SELECT VALUE + * TUPLEUNION( + * CASE WHEN t IS STRUCT THEN t ELSE { '_1': t } + * ) + * FROM << { 'a': 1 } >> AS t + * ``` + * + * If the [PlanTyper] were to type the TUPLEUNION, it would see the CASE statement which has two possible output types: + * 1. STRUCT( a: INT ) + * 2. STRUCT( _1: STRUCT( a: INT ) ) + * + * Therefore, the TUPLEUNION would have these two potential outputs. Now, what happens when you have multiple arguments + * that are all union types? It gets tricky, so we can fold things to make it easier to type. After constant folding, + * the expression will look like: + * ``` + * SELECT VALUE { 'a': t.a } + * FROM << { 'a': 1} >> AS t + * ``` + */ +internal object ConstantFolder { + + /** + * Constant folds an input [PartiQLPlan.statement]. 
+ */ + internal fun fold(statement: Statement): Statement { + if (statement !is Statement.Query) { + throw IllegalArgumentException("ConstantFolder only supports Query statements") + } + val root = ConstantFolderImpl.visitRex(statement.root, statement.root.type) + return statementQuery(root) + } + + internal fun fold(rex: Rex): Rex { + return ConstantFolderImpl.visitRex(rex, rex.type) + } + + /** + * When visiting a [Rex.Op], please be sure to pass the associated [Rex.type] to the visitor. + */ + private object ConstantFolderImpl : PlanRewriter() { + + private fun fold(rex: Rex): Rex = visitRex(rex, rex.type) + + private fun default(op: Rex.Op, type: StaticType): Rex = rex(type, op) + + override fun visitRex(node: Rex, ctx: StaticType): Rex { + return visitRexOp(node.op, node.type) + } + + override fun visitRexOp(node: Rex.Op, ctx: StaticType): Rex { + return when (val folded = super.visitRexOp(node, ctx)) { + is Rex -> folded + is Rex.Op -> rex(ctx, folded) + else -> error("Expected to find Rex, but instead found $folded. We were visiting $node.") + } + } + + @OptIn(PartiQLValueExperimental::class) + override fun visitRexOpCase(node: Rex.Op.Case, ctx: StaticType): Rex { + val newBranches = node.branches.mapNotNull { branch -> + val conditionFolded = fold(branch.condition) + val conditionFoldedOp = conditionFolded.op as? Rex.Op.Lit ?: return@mapNotNull branch + val conditionBooleanLiteral = conditionFoldedOp.value as? BoolValue ?: return@mapNotNull branch + when (conditionBooleanLiteral.value) { + true -> branch.copy(conditionFolded, fold(branch.rex)) + false -> null + else -> branch + } + } + val firstBranch = newBranches.firstOrNull() ?: error("CASE_WHEN has NO branches.") + return when (isLiteralTrue(firstBranch.condition)) { + true -> firstBranch.rex + false -> default(node.copy(newBranches), ctx) + } + } + + @OptIn(PartiQLValueExperimental::class) + private fun isLiteralTrue(rex: Rex): Boolean { + val op = rex.op as? 
Rex.Op.Lit ?: return false + val value = op.value as? BoolValue ?: return false + return value.value ?: false + } + + @OptIn(PartiQLValueExperimental::class) + private fun getLiteral(rex: Rex): PartiQLValue? { + val op = rex.op as? Rex.Op.Lit ?: return null + return op.value + } + + @OptIn(PartiQLValueExperimental::class) + override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: StaticType): Rex { + // Gather Struct Fields + val args = node.args.map { fold(it) } + val fields = args.flatMap { arg -> + val argType = arg.type.flatten() as? StructType ?: return default(node.copy(args = args), ctx) + if (argType.contentClosed.not()) { return default(node.copy(args = args), ctx) } + argType.fields.map { it to arg } + }.map { field -> + val fieldName = rex(StaticType.STRING, rexOpLit(stringValue(field.first.key))) + rexOpStructField( + k = fieldName, + v = rex( + field.first.value, + rexOpPath(field.second, steps = listOf(rexOpPathStepIndex(fieldName))) + ) + ) + } + return fold(rex(ctx, rexOpStruct(fields))) + } + + @OptIn(PartiQLValueExperimental::class) + override fun visitRexOpPath(node: Rex.Op.Path, ctx: StaticType): Rex { + val struct = node.root.op as? Rex.Op.Struct ?: return default(node, ctx) + val step = node.steps.getOrNull(0) ?: return default(node, ctx) + val stepIndex = step as? Rex.Op.Path.Step.Index ?: return default(node, ctx) + val stepName = stepIndex.key.op as? Rex.Op.Lit ?: return default(node, ctx) + val stepNameString = stepName.value as? TextValue<*> ?: return default(node, ctx) + val matches = struct.fields.filter { field -> + val fieldName = field.k.op as? Rex.Op.Lit ?: return default(node, ctx) + val fieldNameString = fieldName.value as? 
StringValue ?: return default(node, ctx) + val fieldNameStringValue = fieldNameString.value ?: return default(node, ctx) + fieldNameStringValue == stepNameString.string + } + return when (matches.size) { + 1 -> matches[0].v + else -> default(node, ctx) + } + } + + /** + * We expect all variants of [visitRexOpCall] to visit their own arguments. + * TODO: Function signature case sensitivity + */ + override fun visitRexOpCall(node: Rex.Op.Call, ctx: StaticType): Rex { + val fn = node.fn as? Fn.Resolved ?: return default(node, ctx) + return when { + fn.signature.name.equals("is_struct", ignoreCase = true) -> visitRexOpCallIsStruct(node, ctx) + fn.signature.name.equals("eq", ignoreCase = true) -> visitRexOpCallEq(node, ctx) + fn.signature.name.equals("not", ignoreCase = true) -> visitRexOpCallNot(node, ctx) + else -> rex(ctx, node) + } + } + + /** + * This relies on the fact that [Rex.equals] works and [PartiQLValue.equals] works. + */ + @OptIn(PartiQLValueExperimental::class) + private fun visitRexOpCallEq(folded: Rex.Op.Call, ctx: StaticType): Rex { + val lhs = folded.args.getOrNull(0) ?: error("EQ should have a LHS argument.") + val rhs = folded.args.getOrNull(1) ?: error("EQ should have a RHS argument.") + val lhsFolded = fold(lhs) + val rhsFolded = fold(rhs) + // Same expressions + if (lhsFolded == rhsFolded) { + return rex(StaticType.BOOL, rexOpLit(boolValue(true))) + } + val lhsLiteral = getLiteral(lhsFolded) ?: return default(folded, ctx) + val rhsLiteral = getLiteral(rhsFolded) ?: return default(folded, ctx) + return rex(StaticType.BOOL, rexOpLit(boolValue(lhsLiteral == rhsLiteral))) + } + + /** + * This relies on the fact that [Rex.equals] works and [PartiQLValue.equals] works. 
+ */ + @OptIn(PartiQLValueExperimental::class) + private fun visitRexOpCallNot(folded: Rex.Op.Call, ctx: StaticType): Rex { + val lhs = folded.args.getOrNull(0) ?: error("NOT should have a LHS argument.") + val lhsFolded = fold(lhs) + val lhsLiteral = getLiteral(lhsFolded) ?: return default(folded, ctx) + val booleanValue = lhsLiteral as? BoolValue ?: return default(folded, ctx) + val boolean = booleanValue.value ?: return default(folded, ctx) + return rex(StaticType.BOOL, rexOpLit(boolValue(boolean.not()))) + } + + @OptIn(PartiQLValueExperimental::class) + private fun visitRexOpCallIsStruct(folded: Rex.Op.Call, ctx: StaticType): Rex { + val isStructLhs = folded.args.getOrNull(0) ?: error("IS STRUCT should have a LHS argument.") + return when (val resultType = isStructLhs.type.flatten()) { + is StructType -> rex(StaticType.BOOL, rexOpLit(boolValue(true))) + is AnyType -> default(folded, ctx) + is AnyOfType -> { + when { + resultType.allTypes.any { it is AnyType } -> default(folded, ctx) + resultType.allTypes.any { it is AnyOfType } -> error("Flattened union types shouldn't contain unions") + resultType.allTypes.all { it is StructType } -> rex(StaticType.BOOL, rexOpLit(boolValue(true))) + resultType.allTypes.any { it is StructType }.not() -> rex(StaticType.BOOL, rexOpLit(boolValue(false))) + else -> default(folded, ctx) + } + } + else -> default(folded, ctx) + } + } + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt index cc16b86ad0..2959aba992 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt @@ -21,9 +21,11 @@ import org.partiql.errors.ProblemCallback import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION import org.partiql.plan.Fn import org.partiql.plan.Identifier +import org.partiql.plan.PlanNode import org.partiql.plan.Rel import 
org.partiql.plan.Rex import org.partiql.plan.Statement +import org.partiql.plan.builder.plan import org.partiql.plan.fnResolved import org.partiql.plan.identifierSymbol import org.partiql.plan.rel @@ -129,6 +131,11 @@ internal class PlanTyper( return rel(type, op) } + override fun visitRelOpErr(node: Rel.Op.Err, ctx: Rel.Type?): Rel { + val type = ctx ?: relType(emptyList(), emptySet()) + return rel(type, node) + } + /** * The output schema of a `rel.op.scan_index` is the value binding and index binding. */ @@ -531,16 +538,44 @@ internal class PlanTyper( } } - override fun visitRexOpCase(node: Rex.Op.Case, ctx: StaticType?): Rex { - val visitedBranches = node.branches.map { visitRexOpCaseBranch(it, null) } + override fun visitRexOpCase(node: Rex.Op.Case, ctx: StaticType?): Rex = plan { + val visitedBranches = node.branches.map { visitRexOpCaseBranch(it, it.rex.type) } val resultTypes = visitedBranches.map { it.rex }.map { it.type } - return rex(AnyOfType(resultTypes.toSet()).flatten(), node.copy(branches = visitedBranches)) + rex(AnyOfType(resultTypes.toSet()).flatten(), node.copy(branches = visitedBranches)) } + /** + * We need special handling for: + * ``` + * CASE + * WHEN a IS STRUCT THEN a + * ELSE { 'a': a } + * END + * ``` + * When we type the above, we can't just assume + */ override fun visitRexOpCaseBranch(node: Rex.Op.Case.Branch, ctx: StaticType?): Rex.Op.Case.Branch { - val visitedCondition = visitRex(node.condition, null) - val visitedReturn = visitRex(node.rex, null) - return node.copy(condition = visitedCondition, rex = visitedReturn) + val visitedCondition = visitRex(node.condition, node.condition.type) + val visitedReturn = visitRex(node.rex, node.rex.type) + val rex = handleSmartCasts(visitedCondition, visitedReturn) ?: visitedReturn + return node.copy(condition = visitedCondition, rex = rex) + } + + /** + * This takes in a [Rex.Op.Case.Branch.condition] and [Rex.Op.Case.Branch.rex]. 
If the [condition] is a type check, + * AKA ` IS STRUCT`, then this function will return a [Rex.Op.Case.Branch.rex] that assumes that the `` + * is a struct. + */ + private fun handleSmartCasts(condition: Rex, result: Rex): Rex? { + val call = condition.op as? Rex.Op.Call ?: return null + val fn = call.fn as? Fn.Resolved ?: return null + if (fn.signature.name.equals("is_struct", ignoreCase = true).not()) { return null } + val ref = call.args.getOrNull(0) ?: error("IS STRUCT requires an argument.") + if (ref.op !is Rex.Op.Var.Resolved) { return null } + val structTypes = ref.type.allTypes.filterIsInstance() + val type = AnyOfType(structTypes.toSet()) + val replacementVal = ref.copy(type = type) + return RexReplacer.replace(result, ref, replacementVal) } override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: StaticType?): Rex { @@ -548,7 +583,7 @@ internal class PlanTyper( handleUnexpectedType(ctx, setOf(StaticType.LIST, StaticType.BAG, StaticType.SEXP)) return rex(StaticType.NULL_OR_MISSING, rexOpErr("Expected collection type")) } - val values = node.values.map { visitRex(it, null) } + val values = node.values.map { visitRex(it, it.type) } val t = values.toUnionType() val type = when (ctx as CollectionType) { is BagType -> BagType(t) @@ -561,8 +596,8 @@ internal class PlanTyper( @OptIn(PartiQLValueExperimental::class) override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: StaticType?): Rex { val fields = node.fields.map { - val k = visitRex(it.k, null) - val v = visitRex(it.v, null) + val k = visitRex(it.k, it.k.type) + val v = visitRex(it.v, it.v.type) rexOpStructField(k, v) } var structIsClosed = true @@ -620,6 +655,7 @@ internal class PlanTyper( var constructorType = constructor.type // add the ordered property to the constructor if (constructorType is StructType) { + // TODO: We shouldn't need to copy the ordered constraint. 
constructorType = constructorType.copy( constraints = constructorType.constraints + setOf(TupleConstraint.Ordered) ) @@ -632,60 +668,99 @@ internal class PlanTyper( return rex(type, rexOpSelect(constructor, rel)) } - override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: StaticType?): Rex { - val args = node.args.map { visitTupleUnionArg(it) } + override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: StaticType?): Rex = plan { + val args = node.args.map { visitRex(it, ctx) } + val argTypes = args.map { it.type } + val potentialTypes = buildArgumentPermutations(argTypes).map { argumentList -> + calculateTupleUnionOutputType(argumentList) + } + val op = rexOpTupleUnion(args) + rex(StaticType.unionOf(potentialTypes.toSet()).flatten(), op) + } + + override fun visitRexOpErr(node: Rex.Op.Err, ctx: StaticType?): PlanNode { + val type = ctx ?: StaticType.ANY + return rex(type, node) + } + + // Helpers + + private fun calculateTupleUnionOutputType(args: List): StaticType { val structFields = mutableListOf() + var structAmount = 0 var structIsClosed = true - for (arg in args) { + var structIsOrdered = true + var canReturnStruct = false + var uniqueAttrs = true + val possibleOutputTypes = mutableListOf() + args.forEach { arg -> when (arg) { - is Rex.Op.TupleUnion.Arg.Spread -> { - val t = arg.v.type - if (t is StructType) { - // arg is definitely a struct - structFields.addAll(t.fields) - structIsClosed = structIsClosed && t.contentClosed - } else if (t.allTypes.filterIsInstance().isNotEmpty()) { - // arg is possibly a struct, just declare OPEN content - structIsClosed = false - } else { - // arg is definitely NOT a struct - val field = StructType.Field(arg.k, arg.v.type) - structFields.add(field) - } + is StructType -> { + structAmount += 1 + canReturnStruct = true + structFields.addAll(arg.fields) + structIsClosed = structIsClosed && arg.constraints.contains(TupleConstraint.Open(false)) + structIsOrdered = structIsOrdered && 
arg.constraints.contains(TupleConstraint.Ordered) + uniqueAttrs = uniqueAttrs && arg.constraints.contains(TupleConstraint.UniqueAttrs(true)) } - is Rex.Op.TupleUnion.Arg.Struct -> { - val field = StructType.Field(arg.k, arg.v.type) - structFields.add(field) + is AnyOfType -> { + onProblem.invoke( + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.CompileError("TupleUnion wasn't normalized to exclude union types.") + ) + ) + possibleOutputTypes.add(StaticType.MISSING) } + is NullType -> { possibleOutputTypes.add(StaticType.NULL) } + else -> { possibleOutputTypes.add(StaticType.MISSING) } } } - val type = StructType( - fields = structFields, - contentClosed = structIsClosed, - constraints = setOf( - TupleConstraint.Open(!structIsClosed), - TupleConstraint.UniqueAttrs(structFields.size == structFields.map { it.key }.distinct().size), - TupleConstraint.Ordered, - ), + uniqueAttrs = when { + structIsClosed.not() && structAmount > 1 -> false + structFields.distinctBy { it.key }.size != structFields.size -> false + else -> uniqueAttrs + } + val orderedConstraint = when (structIsOrdered) { + true -> TupleConstraint.Ordered + false -> null + } + val constraints = setOfNotNull( + TupleConstraint.Open(!structIsClosed), + TupleConstraint.UniqueAttrs(uniqueAttrs), + orderedConstraint ) - val op = rexOpTupleUnion(args) - return rex(type, op) + if (canReturnStruct) { + uniqueAttrs = uniqueAttrs && (structFields.size == structFields.map { it.key }.distinct().size) + possibleOutputTypes.add( + StructType( + fields = structFields.map { it }, + contentClosed = structIsClosed, + constraints = constraints + ) + ) + } + return StaticType.unionOf(possibleOutputTypes.toSet()).flatten() } - private fun visitTupleUnionArg(node: Rex.Op.TupleUnion.Arg) = when (node) { - is Rex.Op.TupleUnion.Arg.Spread -> visitRexOpTupleUnionArgSpread(node, null) - is Rex.Op.TupleUnion.Arg.Struct -> visitRexOpTupleUnionArgStruct(node, null) + private fun buildArgumentPermutations(args: List): 
Sequence> { + val flattenedArgs = args.map { it.flatten().allTypes } + return buildArgumentPermutations(flattenedArgs, accumulator = emptyList()) } - override fun visitRexOpTupleUnionArgStruct( - node: Rex.Op.TupleUnion.Arg.Struct, - ctx: StaticType?, - ) = super.visitRexOpTupleUnionArgStruct(node, ctx) as Rex.Op.TupleUnion.Arg - - override fun visitRexOpTupleUnionArgSpread( - node: Rex.Op.TupleUnion.Arg.Spread, - ctx: StaticType?, - ) = super.visitRexOpTupleUnionArgSpread(node, ctx) as Rex.Op.TupleUnion.Arg + private fun buildArgumentPermutations(args: List>, accumulator: List): Sequence> { + if (args.isEmpty()) { return sequenceOf(accumulator) } + val first = args.first() + val rest = when (args.size) { + 1 -> emptyList() + else -> args.subList(1, args.size) + } + return sequence { + first.forEach { argSubType -> + yieldAll(buildArgumentPermutations(rest, accumulator + listOf(argSubType))) + } + } + } // Helpers @@ -765,7 +840,7 @@ internal class PlanTyper( private fun Rel.type(typeEnv: TypeEnv): Rel = RelTyper(typeEnv).visitRel(this, null) - private fun Rex.type(typeEnv: TypeEnv) = RexTyper(typeEnv).visitRex(this, null) + private fun Rex.type(typeEnv: TypeEnv) = RexTyper(typeEnv).visitRex(this, this.type) /** * I found decorating the tree with the binding names (for resolution) was easier than associating introduced @@ -886,7 +961,7 @@ internal class PlanTyper( /** * Constructs a Rex.Op.Path from a resolved local */ - private fun resolvedLocalPath(local: ResolvedVar.Local): Rex.Op.Path { + private fun resolvedLocalPath(local: ResolvedVar.Local): Rex.Op { val root = rex(local.rootType, rexOpVarResolved(local.ordinal)) val steps = local.tail.map { val case = when (it.bindingCase) { @@ -895,7 +970,10 @@ internal class PlanTyper( } rexOpPathStepSymbol(identifierSymbol(it.name, case)) } - return rexOpPath(root, steps) + return when (steps.isEmpty()) { + true -> root.op + false -> rexOpPath(root, steps) + } } // ERRORS diff --git 
a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/RexReplacer.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/RexReplacer.kt new file mode 100644 index 0000000000..0490b56e70 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/RexReplacer.kt @@ -0,0 +1,42 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.planner.typer + +import org.partiql.plan.Rex +import org.partiql.plan.util.PlanRewriter + +/** + * Uses to replace [Rex]'s within an expression tree. + */ +internal object RexReplacer { + + /** + * Within the [Rex] tree of [rex], replaces all instances of [replace] with the [with]. 
+ */ + internal fun replace(rex: Rex, replace: Rex, with: Rex): Rex { + val params = ReplaceParams(replace, with) + return RexReplacerImpl.visitRex(rex, params) + } + + private class ReplaceParams(val replace: Rex, val with: Rex) + + private object RexReplacerImpl : PlanRewriter() { + + override fun visitRex(node: Rex, ctx: ReplaceParams): Rex { + if (node == ctx.replace) { return ctx.with } + return visitRexOp(node.op, ctx) as Rex + } + } +} diff --git a/partiql-planner/src/testFixtures/kotlin/org/partiql/planner/test/Parsing.kt b/partiql-planner/src/testFixtures/kotlin/org/partiql/planner/test/Parsing.kt index 08d7894ac8..2db05ca7aa 100644 --- a/partiql-planner/src/testFixtures/kotlin/org/partiql/planner/test/Parsing.kt +++ b/partiql-planner/src/testFixtures/kotlin/org/partiql/planner/test/Parsing.kt @@ -28,6 +28,7 @@ import org.partiql.types.StructType import org.partiql.types.SymbolType import org.partiql.types.TimeType import org.partiql.types.TimestampType +import org.partiql.types.TupleConstraint // Use some generated serde eventually @@ -127,7 +128,7 @@ public fun StructElement.toStructType(): StaticType { val type = it.getAngry("type").toStaticType() StructType.Field(name, type) } - return StructType(fields) + return StructType(fields, contentClosed = true, constraints = setOf(TupleConstraint.Open(false))) } public fun StaticType.toIon(): IonElement = when (this) { diff --git a/partiql-planner/src/testFixtures/resources/catalogs/default/pql/employer.ion b/partiql-planner/src/testFixtures/resources/catalogs/default/pql/employer.ion new file mode 100644 index 0000000000..15e2b06215 --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/catalogs/default/pql/employer.ion @@ -0,0 +1,30 @@ +{ + type: "struct", + name: "employer", + fields: [ + { + name: "name", + type: "string" + }, + { + name: "tax_id", + type: "int64" + }, + { + name: "address", + type: { + type: "struct", + fields: [ + { + name: "street", + type: "string" + }, + { + name: "zip", + 
type: "int32" + }, + ] + }, + }, + ] +} diff --git a/partiql-planner/src/testFixtures/resources/catalogs/default/pql/person.ion b/partiql-planner/src/testFixtures/resources/catalogs/default/pql/person.ion new file mode 100644 index 0000000000..6b1ea7e207 --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/catalogs/default/pql/person.ion @@ -0,0 +1,33 @@ +{ + type: "struct", + name: "person", + fields: [ + { + name: "name", + type: { + type: "struct", + fields: [ + { + name: "first", + type: "string" + }, + { + name: "last", + type: "string" + }, + ] + }, + }, + { + name: "ssn", + type: "string" + }, + { + name: "company", + type: [ + "string", + "null" + ] + } + ] +} diff --git a/partiql-planner/src/testFixtures/resources/tests/suite_00.ion b/partiql-planner/src/testFixtures/resources/tests/suite_00.ion index cf073b32cb..3da056b74e 100644 --- a/partiql-planner/src/testFixtures/resources/tests/suite_00.ion +++ b/partiql-planner/src/testFixtures/resources/tests/suite_00.ion @@ -125,28 +125,99 @@ suite::{ } } }, - // TODO: Add support for SELECT * so we can assert on the schema '0004': { statement: ''' - SELECT s_store_sk + SELECT p.*, e.* FROM - tpc_ds.store AS store - LEFT JOIN - tpc_ds.store_returns AS returns - ON s_store_sk = sr_store_sk + pql.person AS p + INNER JOIN + pql.employer AS e + ON p.employer = e.name ''', schema: { type: "bag", items: { - type: "struct", - fields: [ + type:"struct", + fields:[ + { + name:"name", + type:{ + type:"struct", + fields:[ + { + name:"first", + type:"string" + }, + { + name:"last", + type:"string" + } + ] + } + }, + { + name:"ssn", + type:"string" + }, { - name:"s_store_sk", + name:"company", + type:[ + "null", + "string" + ] + }, + { + name:"name", type:"string" + }, + { + name:"tax_id", + type:"int64" + }, + { + name:"address", + type:{ + type:"struct", + fields:[ + { + name:"street", + type:"string" + }, + { + name:"zip", + type:"int32" + } + ] + } } ] } } }, + '0005': { + statement: ''' + SELECT p.name.*, 
(p.name."first" || ' ' || p.name."last") AS full_name FROM pql.person AS p + ''', + schema: { + type:"bag", + items:{ + type:"struct", + fields:[ + { + name:"first", + type:"string" + }, + { + name:"last", + type:"string" + }, + { + name:"full_name", + type:"string" + } + ] + } + } + } }, }