diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt index 93c169264b..2ce5348662 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt @@ -715,7 +715,13 @@ internal class Header( returns = BOOL, parameters = listOf(FunctionParameter("value", BOOL)), isNullable = true, - ) + ), + FunctionSignature.Aggregation( + name = "every", + returns = BOOL, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ), ) private fun any() = listOf( @@ -724,7 +730,13 @@ internal class Header( returns = BOOL, parameters = listOf(FunctionParameter("value", BOOL)), isNullable = true, - ) + ), + FunctionSignature.Aggregation( + name = "any", + returns = BOOL, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ), ) private fun some() = listOf( @@ -733,7 +745,13 @@ internal class Header( returns = BOOL, parameters = listOf(FunctionParameter("value", BOOL)), isNullable = true, - ) + ), + FunctionSignature.Aggregation( + name = "some", + returns = BOOL, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ), ) private fun count() = listOf( @@ -742,7 +760,7 @@ internal class Header( returns = INT, parameters = listOf(FunctionParameter("value", ANY)), isNullable = false, - ) + ), ) private fun min() = numericTypes.map { @@ -752,7 +770,12 @@ internal class Header( parameters = listOf(FunctionParameter("value", it)), isNullable = true, ) - } + } + FunctionSignature.Aggregation( + name = "min", + returns = ANY, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ) private fun max() = numericTypes.map { FunctionSignature.Aggregation( @@ -761,7 +784,12 @@ internal class Header( parameters = listOf(FunctionParameter("value", it)), isNullable = true, ) - } + } + FunctionSignature.Aggregation( + name = "max", + returns = ANY, 
+ parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ) private fun sum() = numericTypes.map { FunctionSignature.Aggregation( @@ -770,7 +798,12 @@ internal class Header( parameters = listOf(FunctionParameter("value", it)), isNullable = true, ) - } + } + FunctionSignature.Aggregation( + name = "sum", + returns = ANY, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ) private fun avg() = numericTypes.map { FunctionSignature.Aggregation( @@ -779,7 +812,12 @@ internal class Header( parameters = listOf(FunctionParameter("value", it)), isNullable = true, ) - } + } + FunctionSignature.Aggregation( + name = "avg", + returns = ANY, + parameters = listOf(FunctionParameter("value", ANY)), + isNullable = true, + ) // ==================================== // SORTING diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt index 9bee396d58..8cfd417bdc 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RelConverter.kt @@ -405,7 +405,7 @@ internal object RelConverter { return Pair(select, input) } - // Build the schema -> (aggs... groups...) + // Build the schema -> (calls... groups...) 
val schema = mutableListOf() val props = emptySet() diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt index cf65fa2f01..262557a4fe 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt @@ -19,15 +19,18 @@ package org.partiql.planner.typer import org.partiql.errors.Problem import org.partiql.errors.ProblemCallback import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION +import org.partiql.plan.Agg import org.partiql.plan.Fn import org.partiql.plan.Identifier import org.partiql.plan.Rel import org.partiql.plan.Rex import org.partiql.plan.Statement +import org.partiql.plan.aggResolved import org.partiql.plan.fnResolved import org.partiql.plan.identifierSymbol import org.partiql.plan.rel import org.partiql.plan.relBinding +import org.partiql.plan.relOpAggregateCall import org.partiql.plan.relOpErr import org.partiql.plan.relOpFilter import org.partiql.plan.relOpJoin @@ -293,19 +296,20 @@ internal class PlanTyper( /** * Initial implementation of `EXCLUDE` schema inference. Until an RFC is finalized for `EXCLUDE` - * (https://github.com/partiql/partiql-spec/issues/39), this behavior is considered experimental and subject to - * change. + * (https://github.com/partiql/partiql-spec/issues/39), * - * So far this implementation includes + * This behavior is considered experimental and subject to change. + * + * This implementation includes * - Excluding tuple bindings (e.g. t.a.b.c) * - Excluding tuple wildcards (e.g. t.a.*.b) * - Excluding collection indexes (e.g. t.a[0].b -- behavior subject to change; see below discussion) * - Excluding collection wildcards (e.g. t.a[*].b) * * There are still discussion points regarding the following edge cases: - * - EXCLUDE on a tuple bindingibute that doesn't exist -- give an error/warning? 
+ * - EXCLUDE on a tuple attribute that doesn't exist -- give an error/warning? * - currently no error - * - EXCLUDE on a tuple bindingibute that has duplicates -- give an error/warning? exclude one? exclude both? + * - EXCLUDE on a tuple attribute that has duplicates -- give an error/warning? exclude one? exclude both? * - currently excludes both w/ no error * - EXCLUDE on a collection index as the last step -- mark element type as optional? * - currently element type as-is @@ -315,7 +319,7 @@ internal class PlanTyper( * - currently a parser error * - EXCLUDE on a union type -- give an error/warning? no-op? exclude on each type in union? * - currently exclude on each union type - * - If SELECT list includes an bindingibute that is excluded, we could consider giving an error in PlanTyper or + * - If SELECT list includes an attribute that is excluded, we could consider giving an error in PlanTyper or * some other semantic pass * - currently does not give an error */ @@ -333,7 +337,33 @@ internal class PlanTyper( } override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: Rel.Type?): Rel { - TODO("Type RelOp Aggregate") + // compute input schema + val input = visitRel(node.input, ctx) + + // type the calls and groups + val typer = RexTyper(locals = TypeEnv(input.type.schema, ResolutionStrategy.LOCAL)) + + // typing of aggregate calls is slightly more complicated because they are not expressions. + val calls = node.calls.mapIndexed { i, call -> + when (val agg = call.agg) { + is Agg.Resolved -> call to ctx!!.schema[i].type + is Agg.Unresolved -> typer.resolveAgg(agg, call.args) + } + } + val groups = node.groups.map { typer.visitRex(it, null) } + + // Compute schema using order (calls...groups...) 
+ val schema = mutableListOf() + schema += calls.map { it.second } + schema += groups.map { it.type } + + // rewrite with typed calls and groups + val type = ctx!!.copyWithSchema(schema) + val op = node.copy( + calls = calls.map { it.first }, + groups = groups, + ) + return rel(type, op) } override fun visitRelBinding(node: Rel.Binding, ctx: Rel.Type?): Rel { @@ -441,7 +471,7 @@ internal class PlanTyper( // 4. Invalid path reference; always MISSING if (type == StaticType.MISSING) { handleAlwaysMissing() - return rex(type, rexOpErr("Unknown identifier $node")) + return rexErr("Unknown identifier $node") } // 5. Non-missing, root is resolved @@ -449,10 +479,7 @@ internal class PlanTyper( } /** - * Typing of functions is - * - * 1. If any argument is MISSING, the function return type is MISSING - * 2. If all arguments are NULL + * Resolve and type scalar function calls. * * @param node * @param ctx @@ -522,7 +549,7 @@ internal class PlanTyper( } is FnMatch.Error -> { handleUnknownFunction(match) - rex(StaticType.MISSING, rexOpErr("Unknown scalar function $fn")) + rexErr("Unknown scalar function $fn") } } } @@ -755,6 +782,74 @@ internal class PlanTyper( false -> StaticType.ANY } } + + /** + * Resolution and typing of aggregation function calls. + * + * I've chosen to place this in RexTyper because all arguments will be typed using the same locals. + * There's no need to create new RexTyper instances for each argument. There is no reason to limit aggregations + * to a single argument (covar, corr, pct, etc.) but in practice we typically only have single <value expression>. + * + * This method is _very_ similar to scalar function resolution, so it is tempting to DRY these two out; but the + * separation is cleaner as the typing of NULLS is subtly different. 
+ * + * SQL-99 6.16 General Rules on <set function specification> + * Let TX be the single-column table that is the result of applying the <value expression> + * to each row of T and eliminating null values <--- all NULL values are eliminated as inputs + */ + public fun resolveAgg(agg: Agg.Unresolved, arguments: List): Pair { + var missingArg = false + val args = arguments.map { + val arg = visitRex(it, null) + if (arg.type == MissingType) missingArg = true + arg + } + + // + if (missingArg) { + handleAlwaysMissing() + return relOpAggregateCall(agg, listOf(rexErr("MISSING"))) to MissingType + } + + // Try to match the arguments to functions defined in the catalog + return when (val match = env.resolveAgg(agg, args)) { + is FnMatch.Ok -> { + // Found a match! + val newAgg = aggResolved(match.signature) + val newArgs = rewriteFnArgs(match.mapping, args) + val returns = newAgg.signature.returns + + // Determine the nullability of the return type + var isNullable = false // True iff has a NULLABLE arg and is a NULLABLE operator + if (newAgg.signature.isNullable) { + for (arg in newArgs) { + if (arg.type.isNullable()) { + isNullable = true + break + } + } + } + + // Return type with calculated nullability + var type = when { + isNullable -> returns.toStaticType() + else -> returns.toNonNullStaticType() + } + + // Some operators can return MISSING during runtime + if (match.isMissable) { + type = StaticType.unionOf(type, StaticType.MISSING) + } + + // Finally, rewrite this node + relOpAggregateCall(newAgg, newArgs) to type + } + is FnMatch.Error -> { + handleUnknownFunction(match) + return relOpAggregateCall(agg, listOf(rexErr("MISSING"))) to MissingType + } + } + } } // HELPERS @@ -763,6 +858,8 @@ internal class PlanTyper( private fun Rex.type(typeEnv: TypeEnv) = RexTyper(typeEnv).visitRex(this, null) + private fun rexErr(message: String) = rex(StaticType.MISSING, rexOpErr(message)) + /** * I found decorating the tree with the binding names (for resolution) was easier than associating introduced * bindings with a
node via an id->list map. ONLY because right now I don't think we have a good way diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerTestJunit.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerTestJunit.kt index e752ec0033..19b41ff064 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerTestJunit.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerTestJunit.kt @@ -3,7 +3,6 @@ package org.partiql.planner import com.amazon.ionelement.api.field import com.amazon.ionelement.api.ionString import com.amazon.ionelement.api.ionStructOf -import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.DynamicContainer import org.junit.jupiter.api.DynamicContainer.dynamicContainer import org.junit.jupiter.api.DynamicNode @@ -14,6 +13,7 @@ import org.junit.jupiter.api.fail import org.partiql.errors.ProblemSeverity import org.partiql.parser.PartiQLParserBuilder import org.partiql.plan.Statement +import org.partiql.plan.debug.PlanPrinter import org.partiql.planner.test.PlannerTest import org.partiql.planner.test.PlannerTestProvider import org.partiql.planner.test.PlannerTestSuite @@ -83,7 +83,15 @@ class PlannerTestJunit { } val expected = test.schema.toIon() val actual = statement.root.type.toIon() - assertEquals(expected, actual) + assert(expected == actual) { + buildString { + appendLine() + appendLine("Expect: $expected") + appendLine("Actual: $actual") + appendLine() + PlanPrinter.append(this, statement) + } + } } } } diff --git a/partiql-planner/src/testFixtures/resources/catalogs/default/pql/numbers.ion b/partiql-planner/src/testFixtures/resources/catalogs/default/pql/numbers.ion new file mode 100644 index 0000000000..14c03f7135 --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/catalogs/default/pql/numbers.ion @@ -0,0 +1,114 @@ +{ + type: "struct", + fields: [ + { + name: "nullable_int16s", + type: { + type: "list", + items: [ + "int16", + "null" + ] + } + }, + { + name: 
"nullable_int32s", + type: { + type: "list", + items: [ + "int32", + "null" + ] + } + }, + { + name: "nullable_int64s", + type: { + type: "list", + items: [ + "int64", + "null" + ] + } + }, + { + name: "nullable_ints", + type: { + type: "list", + items: [ + "int", + "null" + ] + } + }, + { + name: "int16s", + type: { + type: "list", + items: "int16", + }, + }, + { + name: "int32s", + type: { + type: "list", + items: "int32", + }, + }, + { + name: "int64s", + type: { + type: "list", + items: "int64", + }, + }, + { + name: "ints", + type: { + type: "list", + items: "int", + }, + }, + { + name: "decimals", + type: { + type: "list", + items: "decimal", + }, + }, + { + name: "nullable_float32s", + type: { + type: "list", + items: [ + "float32", + "null" + ] + } + }, + { + name: "nullable_float64s", + type: { + type: "list", + items: [ + "float64", + "null" + ] + } + }, + { + name: "float32s", + type: { + type: "list", + items: "float32", + }, + }, + { + name: "float64s", + type: { + type: "list", + items: "float64", + }, + } + ], +} diff --git a/partiql-planner/src/testFixtures/resources/tests/aggregations.ion b/partiql-planner/src/testFixtures/resources/tests/aggregations.ion index b7a650c992..70d7c8ecaa 100644 --- a/partiql-planner/src/testFixtures/resources/tests/aggregations.ion +++ b/partiql-planner/src/testFixtures/resources/tests/aggregations.ion @@ -1,15 +1,16 @@ -// https://web.cecs.pdx.edu/~len/sql1999.pdf#page=181 suite::{ name: "aggregations", session: { - catalog: "default", // session catalog - path: ["tpc_ds"], // session path - vars: {}, // session variables + catalog: "default", + path: [ + "pql" + ], + vars: {}, }, tests: { - 'avg(int32)': { + 'avg(int32|null)': { statement: ''' - SELECT AVG(ss_quantity) FROM store_sales + SELECT AVG(n) as "avg" FROM numbers.nullable_int32s AS n ''', schema: { type: "bag", @@ -17,16 +18,19 @@ suite::{ type: "struct", fields: [ { - name: "_1", - type: "float64", + name: "avg", + type: [ + "int32", + "null", + ], }, ], 
}, }, }, - 'sum(int32)': { + 'count(int32|null)': { statement: ''' - SELECT SUM(ss_quantity) FROM store_sales + SELECT COUNT(n) as "count" FROM numbers.nullable_int32s AS n ''', schema: { type: "bag", @@ -34,63 +38,12 @@ suite::{ type: "struct", fields: [ { - name: "_1", + name: "count", type: "int", }, ], }, }, }, - 'min(int32)': { - statement: ''' - SELECT MIN(ss_quantity) FROM store_sales - ''', - schema: { - type: "bag", - items: { - type: "struct", - fields: [ - { - name: "_1", - type: "int32", - }, - ], - }, - }, - }, - 'max(int32)': { - statement: ''' - SELECT MAX(ss_quantity) FROM store_sales - ''', - schema: { - type: "bag", - items: { - type: "struct", - fields: [ - { - name: "_1", - type: "int32", - }, - ], - }, - }, - }, - 'count(int32)': { - statement: ''' - SELECT COUNT(1) FROM store_sales - ''', - schema: { - type: "bag", - items: { - type: "struct", - fields: [ - { - name: "_1", - type: "int64", - }, - ], - }, - }, - }, }, }