diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt index 61267b6f6..08d1d7310 100644 --- a/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt +++ b/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt @@ -15,15 +15,17 @@ package org.partiql.cli -import AstPrinter import com.amazon.ion.system.IonSystemBuilder +import com.amazon.ion.system.IonTextWriterBuilder import org.partiql.cli.pico.PartiQLCommand import org.partiql.cli.shell.info import org.partiql.lang.eval.EvaluationSession import org.partiql.parser.PartiQLParser +import org.partiql.plan.Statement import org.partiql.plan.debug.PlanPrinter import org.partiql.planner.PartiQLPlanner import org.partiql.plugins.local.LocalConnector +import org.partiql.plugins.local.toIon import picocli.CommandLine import java.io.PrintStream import java.nio.file.Paths @@ -80,6 +82,16 @@ object Debug { out.info("-- Plan ----------") PlanPrinter.append(out, result.statement) + when (val plan = result.statement) { + is Statement.Query -> { + out.info("-- Schema ----------") + val outputSchema = java.lang.StringBuilder() + val ionWriter = IonTextWriterBuilder.minimal().withPrettyPrinting().build(outputSchema) + plan.root.type.toIon().writeTo(ionWriter) + out.info(outputSchema.toString()) + } + } + return "OK" } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/DynamicTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/DynamicTyper.kt new file mode 100644 index 000000000..eb271817b --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/DynamicTyper.kt @@ -0,0 +1,373 @@ +@file:OptIn(PartiQLValueExperimental::class) + +package org.partiql.planner.internal.typer + +import org.partiql.types.MissingType +import org.partiql.types.NullType +import org.partiql.types.StaticType +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType +import org.partiql.value.PartiQLValueType.ANY +import org.partiql.value.PartiQLValueType.BAG +import org.partiql.value.PartiQLValueType.BINARY +import org.partiql.value.PartiQLValueType.BLOB +import org.partiql.value.PartiQLValueType.BOOL +import org.partiql.value.PartiQLValueType.BYTE +import org.partiql.value.PartiQLValueType.CHAR +import org.partiql.value.PartiQLValueType.CLOB +import org.partiql.value.PartiQLValueType.DATE +import org.partiql.value.PartiQLValueType.DECIMAL +import org.partiql.value.PartiQLValueType.DECIMAL_ARBITRARY +import org.partiql.value.PartiQLValueType.FLOAT32 +import org.partiql.value.PartiQLValueType.FLOAT64 +import org.partiql.value.PartiQLValueType.INT +import org.partiql.value.PartiQLValueType.INT16 +import org.partiql.value.PartiQLValueType.INT32 +import org.partiql.value.PartiQLValueType.INT64 +import org.partiql.value.PartiQLValueType.INT8 +import org.partiql.value.PartiQLValueType.INTERVAL +import org.partiql.value.PartiQLValueType.LIST +import org.partiql.value.PartiQLValueType.MISSING +import org.partiql.value.PartiQLValueType.NULL +import org.partiql.value.PartiQLValueType.SEXP +import org.partiql.value.PartiQLValueType.STRING +import org.partiql.value.PartiQLValueType.STRUCT +import org.partiql.value.PartiQLValueType.SYMBOL +import org.partiql.value.PartiQLValueType.TIME +import org.partiql.value.PartiQLValueType.TIMESTAMP + +/** + * Graph of super types for quick lookup because we don't have a tree. + */ +internal typealias SuperGraph = Array> + +/** + * For lack of a better name, this is the "dynamic typer" which implements the typing rules of SQL-99 9.3. + * + * SQL-99 9.3 Data types of results of aggregations (, , ) + * > https://web.cecs.pdx.edu/~len/sql1999.pdf#page=359 + * + * Usage, + * To calculate the type of an "aggregation" create a new instance and "accumulate" each possible type. + * This is a pain with StaticType... + */ +@OptIn(PartiQLValueExperimental::class) +internal class DynamicTyper { + + private var supertype: PartiQLValueType? = null + private var args = mutableListOf() + + private var nullable = false + private var missable = false + private val types = mutableSetOf() + + /** + * This primarily unpacks a StaticType because of NULL, MISSING. + * + * - T + * - NULL + * - MISSING + * - (NULL) + * - (MISSING) + * - (T..) + * - (T..|NULL) + * - (T..|MISSING) + * - (T..|NULL|MISSING) + * - (NULL|MISSING) + * + * @param type + */ + fun accumulate(type: StaticType) { + val nonAbsentTypes = mutableSetOf() + for (t in type.flatten().allTypes) { + when (t) { + is NullType -> nullable = true + is MissingType -> missable = true + else -> nonAbsentTypes.add(t) + } + } + when (nonAbsentTypes.size) { + 0 -> { + // Ignore in calculating supertype. + args.add(NULL) + } + 1 -> { + // Had single type + val single = nonAbsentTypes.first() + val singleRuntime = single.toRuntimeType() + types.add(single) + args.add(singleRuntime) + calculate(singleRuntime) + } + else -> { + // Had a union; use ANY runtime + types.addAll(nonAbsentTypes) + args.add(ANY) + calculate(ANY) + } + } + } + + /** + * Returns a pair of the return StaticType and the coercion. + * + * If the list is null, then no mapping is required. + * + * @return + */ + fun mapping(): Pair>?> { + val modifiers = mutableSetOf() + if (nullable) modifiers.add(StaticType.NULL) + if (missable) modifiers.add(StaticType.MISSING) + // If at top supertype, then return union of all accumulated types + if (supertype == ANY) { + return StaticType.unionOf(types + modifiers) to null + } + // If a collection, then return union of all accumulated types as these coercion rules are not defined by SQL. + if (supertype == STRUCT || supertype == BAG || supertype == LIST || supertype == SEXP) { + return StaticType.unionOf(types + modifiers) to null + } + // If not initialized, then return null, missing, or null|missing. + val s = supertype + if (s == null) { + val t = if (modifiers.isEmpty()) StaticType.MISSING else StaticType.unionOf(modifiers).flatten() + return t to null + } + // Otherwise, return the supertype along with the coercion mapping + val type = s.toNonNullStaticType() + val mapping = args.map { it to s } + return if (modifiers.isEmpty()) { + type to mapping + } else { + StaticType.unionOf(setOf(type) + modifiers).flatten() to mapping + } + } + + private fun calculate(type: PartiQLValueType) { + val s = supertype + // Initialize + if (s == null) { + supertype = type + return + } + // Don't bother calculating the new supertype, we've already hit `dynamic`. + if (s == ANY) return + // Lookup and set the new minimum common supertype + supertype = when { + type == ANY -> type + type == NULL || type == MISSING || s == type -> return // skip + else -> graph[s][type] ?: ANY // lookup, if missing then go to top. + } + } + + private operator fun Array.get(t: PartiQLValueType): T = get(t.ordinal) + + /** + * !! IMPORTANT !! + * + * This is duplicated from the TypeLattice because that was removed in v1.0.0. I wanted to implement this as + * a standalone component so that it is easy to merge (and later merge with CastTable) into v1.0.0. + */ + companion object { + + private operator fun Array.set(t: PartiQLValueType, value: T): Unit = this.set(t.ordinal, value) + + @JvmStatic + private val N = PartiQLValueType.values().size + + @JvmStatic + private fun edges(vararg edges: Pair): Array { + val arr = arrayOfNulls(N) + for (type in edges) { + arr[type.first] = type.second + } + return arr + } + + /** + * This table defines the rules in the SQL-99 section 9.3 BUT we don't have type constraints yet. + * + * TODO collection supertypes + * TODO datetime supertypes + */ + @JvmStatic + internal val graph: SuperGraph = run { + val graph = arrayOfNulls>(N) + for (type in PartiQLValueType.values()) { + // initialize all with empty edges + graph[type] = arrayOfNulls(N) + } + graph[ANY] = edges() + graph[NULL] = edges() + graph[MISSING] = edges() + graph[BOOL] = edges( + BOOL to BOOL + ) + graph[INT8] = edges( + INT8 to INT8, + INT16 to INT16, + INT32 to INT32, + INT64 to INT64, + INT to INT, + DECIMAL to DECIMAL, + DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[INT16] = edges( + INT8 to INT16, + INT16 to INT16, + INT32 to INT32, + INT64 to INT64, + INT to INT, + DECIMAL to DECIMAL, + DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[INT32] = edges( + INT8 to INT32, + INT16 to INT32, + INT32 to INT32, + INT64 to INT64, + INT to INT, + DECIMAL to DECIMAL, + DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[INT64] = edges( + INT8 to INT64, + INT16 to INT64, + INT32 to INT64, + INT64 to INT64, + INT to INT, + DECIMAL to DECIMAL, + DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[INT] = edges( + INT8 to INT, + INT16 to INT, + INT32 to INT, + INT64 to INT, + INT to INT, + DECIMAL to DECIMAL, + DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[DECIMAL] = edges( + INT8 to DECIMAL, + INT16 to DECIMAL, + INT32 to DECIMAL, + INT64 to DECIMAL, + INT to DECIMAL, + DECIMAL to DECIMAL, + DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[DECIMAL_ARBITRARY] = edges( + INT8 to DECIMAL_ARBITRARY, + INT16 to DECIMAL_ARBITRARY, + INT32 to DECIMAL_ARBITRARY, + INT64 to DECIMAL_ARBITRARY, + INT to DECIMAL_ARBITRARY, + DECIMAL to DECIMAL_ARBITRARY, + DECIMAL_ARBITRARY to DECIMAL_ARBITRARY, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[FLOAT32] = edges( + INT8 to FLOAT32, + INT16 to FLOAT32, + INT32 to FLOAT32, + INT64 to FLOAT32, + INT to FLOAT32, + DECIMAL to FLOAT32, + DECIMAL_ARBITRARY to FLOAT32, + FLOAT32 to FLOAT32, + FLOAT64 to FLOAT64, + ) + graph[FLOAT64] = edges( + INT8 to FLOAT64, + INT16 to FLOAT64, + INT32 to FLOAT64, + INT64 to FLOAT64, + INT to FLOAT64, + DECIMAL to FLOAT64, + DECIMAL_ARBITRARY to FLOAT64, + FLOAT32 to FLOAT64, + FLOAT64 to FLOAT64, + ) + graph[CHAR] = edges( + CHAR to CHAR, + STRING to STRING, + SYMBOL to STRING, + CLOB to CLOB, + ) + graph[STRING] = edges( + CHAR to STRING, + STRING to STRING, + SYMBOL to STRING, + CLOB to CLOB, + ) + graph[SYMBOL] = edges( + CHAR to SYMBOL, + STRING to STRING, + SYMBOL to SYMBOL, + CLOB to CLOB, + ) + graph[BINARY] = edges( + BINARY to BINARY, + ) + graph[BYTE] = edges( + BYTE to BYTE, + BLOB to BLOB, + ) + graph[BLOB] = edges( + BYTE to BLOB, + BLOB to BLOB, + ) + graph[DATE] = edges( + DATE to DATE, + ) + graph[CLOB] = edges( + CHAR to CLOB, + STRING to CLOB, + SYMBOL to CLOB, + CLOB to CLOB, + ) + graph[TIME] = edges( + TIME to TIME, + ) + graph[TIMESTAMP] = edges( + TIMESTAMP to TIMESTAMP, + ) + graph[INTERVAL] = edges( + INTERVAL to INTERVAL, + ) + graph[LIST] = edges( + LIST to LIST, + SEXP to SEXP, + BAG to BAG, + ) + graph[SEXP] = edges( + LIST to SEXP, + SEXP to SEXP, + BAG to BAG, + ) + graph[BAG] = edges( + LIST to BAG, + SEXP to BAG, + BAG to BAG, + ) + graph[STRUCT] = edges( + STRUCT to STRUCT, + ) + graph.requireNoNulls() + } + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index 42fd0b948..9ee9bfbbb 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -594,7 +594,8 @@ internal class PlanTyper( } val candidates = match.candidates.map { candidate -> val rex = toRexCall(candidate, args, isNotMissable) - val staticCall = rex.op as? Rex.Op.Call.Static ?: error("ToRexCall should always return a static call.") + val staticCall = + rex.op as? Rex.Op.Call.Static ?: error("ToRexCall should always return a static call.") val resolvedFn = staticCall.fn as? Fn.Resolved ?: error("This should have been resolved") types.add(rex.type) val coercions = candidate.mapping.map { it?.let { fnResolved(it) } } @@ -614,7 +615,11 @@ internal class PlanTyper( return rex(ANY, rexOpErr("Direct dynamic calls are not supported. This should have been a static call.")) } - private fun toRexCall(match: FnMatch.Ok, args: List, isNotMissable: Boolean): Rex { + private fun toRexCall( + match: FnMatch.Ok, + args: List, + isNotMissable: Boolean, + ): Rex { // Found a match! val newFn = fnResolved(match.signature) val newArgs = rewriteFnArgs(match.mapping, args) @@ -681,34 +686,78 @@ internal class PlanTyper( } override fun visitRexOpCase(node: Rex.Op.Case, ctx: StaticType?): Rex { - // Type branches and prune branches known to never execute - val newBranches = node.branches.map { visitRexOpCaseBranch(it, it.rex.type) } - .filterNot { isLiteralBool(it.condition, false) } + // Rewrite CASE-WHEN branches + val oldBranches = node.branches.toTypedArray() + val newBranches = mutableListOf() + val typer = DynamicTyper() + for (i in oldBranches.indices) { + + // Type the branch + var branch = oldBranches[i] + branch = visitRexOpCaseBranch(branch, branch.rex.type) + + // Check if branch condition is a literal + if (boolOrNull(branch.condition.op) == false) { + continue // prune + } - newBranches.forEach { branch -> - if (canBeBoolean(branch.condition.type).not()) { + // Emit typing error if a branch condition is never a boolean (prune) + if (!canBeBoolean(branch.condition.type)) { onProblem.invoke( Problem( UNKNOWN_PROBLEM_LOCATION, PlanningProblemDetails.IncompatibleTypesForOp(branch.condition.type.allTypes, "CASE_WHEN") ) ) + // prune, always false + continue } + + // Accumulate typing information + typer.accumulate(branch.rex.type) + newBranches.add(branch) } - val default = visitRex(node.default, node.default.type) - - // Calculate final expression (short-circuit to first branch if the condition is always TRUE). - val resultTypes = newBranches.map { it.rex }.map { it.type } + listOf(default.type) - return when (newBranches.size) { - 0 -> default - else -> when (isLiteralBool(newBranches[0].condition, true)) { - true -> newBranches[0].rex - false -> rex( - type = StaticType.unionOf(resultTypes.toSet()).flatten(), - node.copy(branches = newBranches, default = default) - ) + + // Rewrite ELSE branch + var newDefault = visitRex(node.default, null) + if (newBranches.isEmpty()) { + return newDefault + } + typer.accumulate(newDefault.type) + + // Compute the CASE-WHEN type from the accumulator + val (type, mapping) = typer.mapping() + + // Rewrite branches if we have coercions. + if (mapping != null) { + val msize = mapping.size + val bsize = newBranches.size + 1 + assert(msize == bsize) { "Coercion mappings `len $msize` did not match the number of CASE-WHEN branches `len $bsize`" } + // Rewrite branches + for (i in newBranches.indices) { + val (operand, target) = mapping[i] + if (operand == target) continue // skip + val cast = env.fnResolver.cast(operand, target) + val branch = newBranches[i] + val rex = rex(type, rexOpCallStatic(fnResolved(cast), listOf(branch.rex))) + newBranches[i] = branch.copy(rex = rex) + } + // Rewrite default + val (operand, target) = mapping.last() + if (operand != target) { + val cast = env.fnResolver.cast(operand, target) + newDefault = rex(type, rexOpCallStatic(fnResolved(cast), listOf(newDefault))) } } + + // TODO constant folding in planner which also means branch pruning + // This is added for backwards compatibility, we return the first branch if it's true + if (boolOrNull(newBranches[0].condition.op) == true) { + return newBranches[0].rex + } + + val op = Rex.Op.Case(newBranches, newDefault) + return rex(type, op) } /** @@ -723,11 +772,12 @@ internal class PlanTyper( } } + /** + * Returns the boolean value of the expression. For now, only handle literals. + */ @OptIn(PartiQLValueExperimental::class) - private fun isLiteralBool(rex: Rex, bool: Boolean): Boolean { - val op = rex.op as? Rex.Op.Lit ?: return false - val value = op.value as? BoolValue ?: return false - return value.value == bool + private fun boolOrNull(op: Rex.Op): Boolean? { + return if (op is Rex.Op.Lit && op.value is BoolValue) op.value.value else null } /** @@ -1242,7 +1292,8 @@ internal class PlanTyper( is Identifier.Symbol -> BindingPath(listOf(this.toBindingName())) } - private fun Identifier.Qualified.toBindingPath() = BindingPath(steps = listOf(this.root.toBindingName()) + steps.map { it.toBindingName() }) + private fun Identifier.Qualified.toBindingPath() = + BindingPath(steps = listOf(this.root.toBindingName()) + steps.map { it.toBindingName() }) private fun Identifier.Symbol.toBindingName() = BindingName( name = symbol, diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt index ad8114067..de6fa6cbc 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt @@ -33,6 +33,7 @@ import org.partiql.types.BagType import org.partiql.types.ListType import org.partiql.types.SexpType import org.partiql.types.StaticType +import org.partiql.types.StaticType.Companion.MISSING import org.partiql.types.StaticType.Companion.unionOf import org.partiql.types.StructType import org.partiql.types.TupleConstraint @@ -46,7 +47,7 @@ class PlanTyperTestsPorted { sealed class TestCase { class SuccessTestCase( - val name: String, + val name: String? = null, val key: PartiQLTest.Key? = null, val query: String? = null, val catalog: String? = null, @@ -54,7 +55,12 @@ class PlanTyperTestsPorted { val expected: StaticType, val warnings: ProblemHandler? = null, ) : TestCase() { - override fun toString(): String = "$name : $query" + override fun toString(): String { + if (key != null) { + return "${key.group} : ${key.name}" + } + return "${name!!} : $query" + } } class ErrorTestCase( @@ -2401,6 +2407,171 @@ class PlanTyperTestsPorted { ) ) ), + // + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-00"), + catalog = "pql", + expected = StaticType.INT4 + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-02"), + catalog = "pql", + expected = StaticType.INT4 + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-03"), + catalog = "pql", + expected = StaticType.INT8 + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-04"), + catalog = "pql", + expected = StaticType.INT + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-05"), + catalog = "pql", + expected = StaticType.INT + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-06"), + catalog = "pql", + expected = StaticType.INT + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-07"), + catalog = "pql", + expected = StaticType.INT8 + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-08"), + catalog = "pql", + expected = unionOf(StaticType.INT, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-09"), + catalog = "pql", + expected = unionOf(StaticType.INT, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-10"), + catalog = "pql", + expected = unionOf(StaticType.DECIMAL, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-11"), + catalog = "pql", + expected = unionOf(StaticType.INT, StaticType.NULL, StaticType.MISSING), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-12"), + catalog = "pql", + expected = StaticType.FLOAT + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-13"), + catalog = "pql", + expected = unionOf(StaticType.FLOAT, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-14"), + catalog = "pql", + expected = StaticType.STRING, + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-15"), + catalog = "pql", + expected = unionOf(StaticType.STRING, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-16"), + catalog = "pql", + expected = StaticType.CLOB, + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-17"), + catalog = "pql", + expected = unionOf(StaticType.CLOB, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-18"), + catalog = "pql", + expected = unionOf(StaticType.STRING, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-19"), + catalog = "pql", + expected = unionOf(StaticType.STRING, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-20"), + catalog = "pql", + expected = StaticType.NULL, + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-21"), + catalog = "pql", + expected = unionOf(StaticType.STRING, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-22"), + catalog = "pql", + expected = unionOf(StaticType.INT4, StaticType.NULL, StaticType.MISSING), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-23"), + catalog = "pql", + expected = StaticType.INT4, + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-24"), + catalog = "pql", + expected = unionOf(StaticType.INT4, StaticType.INT8, StaticType.STRING), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-25"), + catalog = "pql", + expected = unionOf(StaticType.INT4, StaticType.INT8, StaticType.STRING, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-26"), + catalog = "pql", + expected = unionOf(StaticType.INT4, StaticType.INT8, StaticType.STRING, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-27"), + catalog = "pql", + expected = unionOf(StaticType.INT2, StaticType.INT4, StaticType.INT8, StaticType.INT, StaticType.DECIMAL, StaticType.STRING, StaticType.CLOB), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-28"), + catalog = "pql", + expected = unionOf(StaticType.INT2, StaticType.INT4, StaticType.INT8, StaticType.INT, StaticType.DECIMAL, StaticType.STRING, StaticType.CLOB, StaticType.NULL), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-29"), + catalog = "pql", + expected = unionOf( + StructType( + fields = listOf( + StructType.Field("x", StaticType.INT4), + StructType.Field("y", StaticType.INT4), + ), + ), + StructType( + fields = listOf( + StructType.Field("x", StaticType.INT8), + StructType.Field("y", StaticType.INT8), + ), + ), + StaticType.NULL, + ), + ), + SuccessTestCase( + key = PartiQLTest.Key("basics", "case-when-30"), + catalog = "pql", + expected = MISSING + ), ) @JvmStatic diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/functions/NullIfTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/functions/NullIfTest.kt index 8b8fb9fc4..6f1a56a84 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/functions/NullIfTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/functions/NullIfTest.kt @@ -8,27 +8,34 @@ import org.partiql.planner.util.cartesianProduct import org.partiql.types.StaticType import java.util.stream.Stream -// TODO: Model handling of Truth Value in typer better. +/** + * The NULLIF() function returns NULL if two expressions are equal, otherwise it returns the first expression + * + * The type of NULLIF(arg_0: T_0, arg_1: arg_1) should be (null|T_0). + * + * CASE + * WHEN x = y THEN NULL + * ELSE x + * END + * + * TODO: Model handling of Truth Value in typer better. + */ class NullIfTest : PartiQLTyperTestBase() { @TestFactory fun nullIf(): Stream { - val tests = listOf( - "func-00", - ).map { inputs.get("basics", it)!! } - val argsMap = buildMap { - val successArgs = cartesianProduct(allSupportedType, allSupportedType) + val tests = listOf("func-00").map { inputs.get("basics", it)!! } + val argsMap = mutableMapOf>>() - successArgs.forEach { args: List -> - val returnType = StaticType.unionOf(args.first(), StaticType.NULL).flatten() - (this[TestResult.Success(returnType)] ?: setOf(args)).let { - put(TestResult.Success(returnType), it + setOf(args)) - } - Unit - } - put(TestResult.Failure, emptySet>()) + // Generate all success cases + cartesianProduct(allSupportedType, allSupportedType).forEach { args -> + val expected = StaticType.unionOf(args[0], StaticType.NULL).flatten() + val result = TestResult.Success(expected) + argsMap[result] = setOf(args) } + // No failure case + argsMap[TestResult.Failure] = emptySet() return super.testGen("nullIf", tests, argsMap) } diff --git a/partiql-planner/src/testFixtures/resources/catalogs/default/pql/t_item.ion b/partiql-planner/src/testFixtures/resources/catalogs/default/pql/t_item.ion new file mode 100644 index 000000000..e4435d624 --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/catalogs/default/pql/t_item.ion @@ -0,0 +1,184 @@ +// simple item which various types for testing +{ + type: "struct", + constraints: [ closed, unique ], + fields: [ + // Boolean + { + name: "t_bool", + type: "bool", + }, + { + name: "t_bool_nul", + type: ["bool","null"], + }, + // Exact Numeric +// { +// name: "t_int8", +// type: "int8", +// }, +// { +// name: "t_int8_null", +// type: ["int8", "null"], +// }, + { + name: "t_int16", + type: "int16", + }, + { + name: "t_int16_null", + type: ["int16", "null"], + }, + { + name: "t_int32", + type: "int32", + }, + { + name: "t_int32_null", + type: ["int32", "null"], + }, + { + name: "t_int64", + type: "int64", + }, + { + name: "t_int64_null", + type: ["int64", "null"], + }, + { + name: "t_int", + type: "int", + }, + { + name: "t_int_null", + type: ["int", "null"], + }, + { + name: "t_decimal", + type: "decimal", + }, + { + name: "t_decimal_null", + type: ["decimal", "null"], + }, + // Approximate Numeric + { + name: "t_float32", + type: "float32", + }, + { + name: "t_float32_null", + type: ["float32", "null"], + }, + { + name: "t_float64", + type: "float64", + }, + { + name: "t_float64_null", + type: ["float64", "null"], + }, + // Strings + { + name: "t_string", + type: "string", + }, + { + name: "t_string_null", + type: ["string", "null"], + }, + { + name: "t_clob", + type: "clob", + }, + { + name: "t_clob_null", + type: ["clob", "null"], + }, + // absent + { + name: "t_null", + type: "null", + }, + { + name: "t_missing", + type: "missing", + }, + { + name: "t_absent", + type: ["null", "missing"], + }, + // collections + { + name: "t_bag", + type: { + type: "bag", + items: "any", + }, + }, + { + name: "t_list", + type: { + type: "list", + items: "any", + } + }, + { + name: "t_sexp", + type: { + type: "sexp", + items: "any", + } + }, + // structs + { + name: "t_struct_a", + type: { + type: "struct", + fields: [ + { + name: "x", + type: "int32", + }, + { + name: "y", + type: "int32", + }, + ] + }, + }, + { + name: "t_struct_b", + type: { + type: "struct", + fields: [ + { + name: "x", + type: "int64", + }, + { + name: "y", + type: "int64", + }, + ] + }, + }, + { + name: "t_any", + type: "any", + }, + // unions + { + name: "t_num_exact", + type: [ "int16", "int32", "int64", "int", "decimal" ], + }, + { + name: "t_num_exact_null", + type: [ "int16", "int32", "int64", "int", "decimal", "null" ], + }, + { + name: "t_str", + type: [ "clob", "string" ], + } + ] +} diff --git a/partiql-planner/src/testFixtures/resources/inputs/basics/case.sql b/partiql-planner/src/testFixtures/resources/inputs/basics/case.sql index f7099d53c..ba1589e80 100644 --- a/partiql-planner/src/testFixtures/resources/inputs/basics/case.sql +++ b/partiql-planner/src/testFixtures/resources/inputs/basics/case.sql @@ -1,30 +1,293 @@ ---#[case-00] +-- ----------------------------- +-- Exact Numeric +-- ----------------------------- + +--#[case-when-00] +-- type: (int32) +CASE t_item.t_bool + WHEN true THEN 0 + WHEN false THEN 1 + ELSE 2 +END; + +--#[case-when-02] +-- type: (int32) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int16 -- cast(.. AS INT4) + ELSE t_item.t_int32 -- INT4 +END; + +--#[case-when-03] +-- type: (int64) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int16 -- cast(.. AS INT8) + WHEN 'b' THEN t_item.t_int32 -- cast(.. AS INT8) + ELSE t_item.t_int64 -- INT8 +END; + +--#[case-when-04] +-- type: (int) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int16 -- cast(.. AS INT) + WHEN 'b' THEN t_item.t_int32 -- cast(.. AS INT) + WHEN 'c' THEN t_item.t_int64 -- cast(.. AS INT) + ELSE t_item.t_int -- INT +END; + +--#[case-when-05] +-- type: (int) +CASE t_item.t_string + WHEN 'b' THEN t_item.t_int32 -- cast(.. AS INT) + WHEN 'c' THEN t_item.t_int64 -- cast(.. AS INT) + ELSE t_item.t_int -- INT +END; + +--#[case-when-06] +-- type: (int) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int16 -- cast(.. AS INT) + WHEN 'b' THEN t_item.t_int32 -- cast(.. AS INT) + ELSE t_item.t_int -- INT +END; + +--#[case-when-07] +-- type: (int64) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int32 -- cast(.. AS INT8) + WHEN 'b' THEN t_item.t_int64 -- INT8 + ELSE t_item.t_int16 -- cast(.. AS INT8) +END; + +--#[case-when-08] +-- type: (int|null) +-- nullable default +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int16 -- cast(.. AS INT) + WHEN 'b' THEN t_item.t_int32 -- cast(.. AS INT) + ELSE t_item.t_int_null -- INT +END; + +--#[case-when-09] +-- type: (int|null) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int16_null -- cast(.. AS INT) + WHEN 'b' THEN t_item.t_int32 -- cast(.. AS INT) + ELSE t_item.t_int +END; + +--#[case-when-10] +-- type: (decimal|null) +-- nullable branch +CASE t_item.t_string + WHEN 'a' THEN t_item.t_decimal + WHEN 'b' THEN t_item.t_int32 + ELSE NULL +END; + +--#[case-when-11] +-- type: (int|null|missing) +-- TODO should really be (int|missing) but our translation of coalesce doesn't consider types. +COALESCE(CAST(t_item.t_string AS INT), 1); + +-- ----------------------------- +-- Approximate Numeric +-- ----------------------------- + +-- TODO model approximate numeric +-- We do not have the appropriate StaticType for this. + +--#[case-when-12] +-- type: (float64) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int + ELSE t_item.t_float64 +END; + +--#[case-when-13] +-- type: (float64|null) +-- nullable branch +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int + WHEN 'b' THEN t_item.t_float64 + ELSE NULL +END; + +-- ----------------------------- +-- Character Strings +-- ----------------------------- + +--#[case-when-14] +-- type: string +CASE t_item.t_string + WHEN 'a' THEN t_item.t_string + ELSE 'default' +END; + +--#[case-when-15] +-- type: (string|null) +-- null default +CASE t_item.t_string + WHEN 'a' THEN t_item.t_string + ELSE NULL +END; + +--#[case-when-16] +-- type: clob +CASE t_item.t_string + WHEN 'a' THEN t_item.t_string + WHEN 'b' THEN t_item.t_clob + ELSE 'default' +END; + +--#[case-when-17] +-- type: (clob|null) +-- null default +CASE t_item.t_string + WHEN 'a' THEN t_item.t_string + WHEN 'b' THEN t_item.t_clob + ELSE NULL +END; + +-- ---------------------------------- +-- Variations of null and missing +-- ---------------------------------- + +--#[case-when-18] +-- type: (string|null) +CASE t_item.t_string + WHEN 'a' THEN NULL + ELSE 'default' +END; + +--#[case-when-19] +-- type: (string|null) +CASE t_item.t_string + WHEN 'a' THEN NULL + WHEN 'b' THEN NULL + WHEN 'c' THEN NULL + WHEN 'd' THEN NULL + ELSE 'default' +END; + +--#[case-when-20] +-- type: null +-- no default, null anyways +CASE t_item.t_string + WHEN 'a' THEN NULL +END; + +--#[case-when-21] +-- type: (string|null) +-- no default +CASE t_item.t_string + WHEN 'a' THEN 'ok!' +END; + +--#[case-when-22] +-- type: (null|missing|int32) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_absent + ELSE -1 +END; + +--#[case-when-23] +-- type: int32 +-- false branch is pruned +CASE + WHEN false THEN t_item.t_absent + ELSE -1 +END; + +-- ----------------------------- +-- Heterogeneous Branches +-- ----------------------------- + +--#[case-when-24] +-- type: (int32|int64|string) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int32 + WHEN 'b' THEN t_item.t_int64 + ELSE 'default' +END; + +--#[case-when-25] +-- type: (int32|int64|string|null) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int32 + WHEN 'b' THEN t_item.t_int64 + WHEN 'c' THEN t_item.t_string + ELSE NULL +END; + +--#[case-when-26] +-- type: (int32|int64|string|null) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_int32 + WHEN 'b' THEN t_item.t_int64_null + ELSE 'default' +END; + +--#[case-when-27] +-- type: (int16|int32|int64|int|decimal|string|clob) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_num_exact + WHEN 'b' THEN t_item.t_str + ELSE 'default' +END; + +--#[case-when-28] +-- type: (int16|int32|int64|int|decimal|string|clob|null) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_num_exact + WHEN 'b' THEN t_item.t_str +END; + +--#[case-when-29] +-- type: (struct_a|struct_b|null) +CASE t_item.t_string + WHEN 'a' THEN t_item.t_struct_a + WHEN 'b' THEN t_item.t_struct_b +END; + +--#[case-when-30] +-- type: missing +CASE t_item.t_string + WHEN 'a' THEN MISSING + WHEN 'b' THEN MISSING + ELSE MISSING +END; + +-- ----------------------------- +-- (Unused) old tests +-- ----------------------------- + +--#[old-case-when-00] CASE WHEN FALSE THEN 0 WHEN TRUE THEN 1 ELSE 2 END; ---#[case-01] +--#[old-case-when-01] CASE WHEN 1 = 2 THEN 0 WHEN 2 = 3 THEN 1 ELSE 3 END; ---#[case-02] +--#[old-case-when-02] CASE 1 WHEN 1 THEN 'MATCH!' ELSE 'NO MATCH!' END; ---#[case-03] +--#[old-case-when-03] CASE 'Hello World' WHEN 'Hello World' THEN TRUE ELSE FALSE END; ---#[case-04] +--#[old-case-when-04] SELECT CASE a WHEN TRUE THEN 'a IS TRUE' @@ -32,7 +295,7 @@ SELECT END AS result FROM T; ---#[case-05] +--#[old-case-when-05] SELECT CASE WHEN a = TRUE THEN 'a IS TRUE' @@ -40,7 +303,7 @@ SELECT END AS result FROM T; ---#[case-06] +--#[old-case-when-06] SELECT CASE b WHEN 10 THEN 'b IS 10' @@ -48,7 +311,7 @@ SELECT END AS result FROM T; ---#[case-07] +--#[old-case-when-07] -- TODO: This is currently failing as we seemingly cannot search for a nested attribute of a global. SELECT CASE d.e @@ -57,7 +320,7 @@ SELECT END AS result FROM T; ---#[case-08] +--#[old-case-when-08] SELECT CASE x WHEN 'WATER' THEN 'x IS WATER' @@ -66,7 +329,7 @@ SELECT END AS result FROM T; ---#[case-09] +--#[old-case-when-09] -- TODO: When using `x IS STRING` or `x IS DECIMAL`, I found that there are issues with the SqlCalls not receiving -- the length/precision/scale parameters. This doesn't have to do with CASE_WHEN, but it needs to be addressed. SELECT @@ -77,7 +340,7 @@ SELECT END AS result FROM T; ---#[case-10] +--#[old-case-when-10] CASE WHEN FALSE THEN 0 WHEN FALSE THEN 1