diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt index 038f4edb0..912da8080 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamic.kt @@ -103,7 +103,10 @@ internal class ExprCallDynamic( exactMatches = currentExactMatches } } - return if (currentMatch == null) null else Candidate(functions[currentMatch!!].getInstance(args.toTypedArray())) + return if (currentMatch == null) null else { + val instance = functions[currentMatch!!].getInstance(args.toTypedArray()) ?: return null + Candidate(instance) + } } /** diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/PTestCase.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/PTestCase.kt new file mode 100644 index 000000000..3811fe8b3 --- /dev/null +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/PTestCase.kt @@ -0,0 +1,8 @@ +package org.partiql.eval + +interface PTestCase : Runnable { + /** + * Executes the test case + */ + override fun run() +} diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt index bef79ef0c..b8f1960ee 100644 --- a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEvaluatorTest.kt @@ -8,18 +8,6 @@ import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.MethodSource import org.partiql.eval.Mode import org.partiql.eval.compiler.PartiQLCompiler -import org.partiql.parser.PartiQLParser -import org.partiql.plan.Plan -import org.partiql.planner.PartiQLPlanner -import org.partiql.spi.catalog.Catalog -import org.partiql.spi.catalog.Name -import org.partiql.spi.catalog.Session -import org.partiql.spi.catalog.Table -import org.partiql.spi.value.Datum -import org.partiql.spi.value.DatumReader -import org.partiql.types.PType -import org.partiql.types.StaticType -import org.partiql.types.fromStaticType import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.bagValue @@ -27,16 +15,12 @@ import org.partiql.value.boolValue import org.partiql.value.decimalValue import org.partiql.value.int32Value import org.partiql.value.int64Value -import org.partiql.value.io.PartiQLValueIonWriterBuilder import org.partiql.value.listValue import org.partiql.value.missingValue import org.partiql.value.nullValue import org.partiql.value.stringValue import org.partiql.value.structValue -import java.io.ByteArrayOutputStream import java.math.BigDecimal -import kotlin.test.assertEquals -import kotlin.test.assertNotNull /** * This holds sanity tests during the development of the [PartiQLCompiler.standard] implementation. @@ -47,37 +31,37 @@ class PartiQLEvaluatorTest { @ParameterizedTest @MethodSource("sanityTestsCases") @Execution(ExecutionMode.CONCURRENT) - fun sanityTests(tc: SuccessTestCase) = tc.assert() + fun sanityTests(tc: SuccessTestCase) = tc.run() @ParameterizedTest @MethodSource("typingModeTestCases") @Execution(ExecutionMode.CONCURRENT) - fun typingModeTests(tc: TypingTestCase) = tc.assert() + fun typingModeTests(tc: TypingTestCase) = tc.run() @ParameterizedTest @MethodSource("subqueryTestCases") @Execution(ExecutionMode.CONCURRENT) - fun subqueryTests(tc: SuccessTestCase) = tc.assert() + fun subqueryTests(tc: SuccessTestCase) = tc.run() @ParameterizedTest @MethodSource("aggregationTestCases") @Execution(ExecutionMode.CONCURRENT) - fun aggregationTests(tc: SuccessTestCase) = tc.assert() + fun aggregationTests(tc: SuccessTestCase) = tc.run() @ParameterizedTest @MethodSource("joinTestCases") @Execution(ExecutionMode.CONCURRENT) - fun joinTests(tc: SuccessTestCase) = tc.assert() + fun joinTests(tc: SuccessTestCase) = tc.run() @ParameterizedTest @MethodSource("globalsTestCases") @Execution(ExecutionMode.CONCURRENT) - fun globalsTests(tc: SuccessTestCase) = tc.assert() + fun globalsTests(tc: SuccessTestCase) = tc.run() @ParameterizedTest @MethodSource("castTestCases") @Execution(ExecutionMode.CONCURRENT) - fun castTests(tc: SuccessTestCase) = tc.assert() + fun castTests(tc: SuccessTestCase) = tc.run() companion object { @@ -631,14 +615,14 @@ class PartiQLEvaluatorTest { structValue( "sensor" to int32Value(1), "readings" to bagValue( - org.partiql.value.decimalValue(0.4.toBigDecimal()), - org.partiql.value.decimalValue(0.2.toBigDecimal()) + decimalValue(0.4.toBigDecimal()), + decimalValue(0.2.toBigDecimal()) ) ), structValue( "sensor" to int32Value(2), "readings" to bagValue( - org.partiql.value.decimalValue(0.3.toBigDecimal()) + decimalValue(0.3.toBigDecimal()) ) ), ) @@ -1291,148 +1275,6 @@ class PartiQLEvaluatorTest { ) } - public class SuccessTestCase @OptIn(PartiQLValueExperimental::class) constructor( - val input: String, - val expected: PartiQLValue, - val mode: Mode = Mode.PERMISSIVE(), - val globals: List = emptyList(), - ) { - - private val compiler = PartiQLCompiler.standard() - private val parser = PartiQLParser.standard() - private val planner = PartiQLPlanner.standard() - - /** - * @property value is a serialized Ion value. - */ - class Global( - val name: String, - val value: String, - val type: StaticType = StaticType.ANY, - ) - - internal fun assert() { - val parseResult = parser.parse(input) - assertEquals(1, parseResult.statements.size) - val statement = parseResult.statements[0] - val catalog = Catalog.builder() - .name("memory") - .apply { - globals.forEach { - val table = Table.standard( - name = Name.of(it.name), - schema = fromStaticType(it.type), - datum = DatumReader.ion(it.value.byteInputStream()).next()!! - ) - define(table) - } - } - .build() - val session = Session.builder() - .catalog("memory") - .catalogs(catalog) - .build() - val plan = planner.plan(statement, session).plan - val result = compiler.prepare(plan, mode).execute() - val output = result.toPartiQLValue() // TODO: Assert directly on Datum - assert(expected == output) { - comparisonString(expected, output, plan) - } - } - - @OptIn(PartiQLValueExperimental::class) - private fun comparisonString(expected: PartiQLValue, actual: PartiQLValue, plan: Plan): String { - val expectedBuffer = ByteArrayOutputStream() - val expectedWriter = PartiQLValueIonWriterBuilder.standardIonTextBuilder().build(expectedBuffer) - expectedWriter.append(expected) - return buildString { - // TODO pretty-print V1 plans! - appendLine(plan) - appendLine("Expected : $expectedBuffer") - expectedBuffer.reset() - expectedWriter.append(actual) - appendLine("Actual : $expectedBuffer") - } - } - - override fun toString(): String { - return input - } - } - - public class TypingTestCase @OptIn(PartiQLValueExperimental::class) constructor( - val name: String, - val input: String, - val expectedPermissive: PartiQLValue, - ) { - - private val compiler = PartiQLCompiler.standard() - private val parser = PartiQLParser.standard() - private val planner = PartiQLPlanner.standard() - - internal fun assert() { - val (permissiveResult, plan) = run(mode = Mode.PERMISSIVE()) - val permissiveResultPValue = permissiveResult.toPartiQLValue() - val assertionCondition = try { - expectedPermissive == permissiveResultPValue // TODO: Assert using Datum - } catch (t: Throwable) { - val str = buildString { - appendLine("Test Name: $name") - // TODO pretty-print V1 plans! - appendLine(plan) - } - throw RuntimeException(str, t) - } - assert(assertionCondition) { - comparisonString(expectedPermissive, permissiveResultPValue, plan) - } - var error: Throwable? = null - try { - val (strictResult, _) = run(mode = Mode.STRICT()) - when (strictResult.type.kind) { - PType.Kind.BAG, PType.Kind.ARRAY -> strictResult.toList() - else -> strictResult - } - } catch (e: Throwable) { - error = e - } - assertNotNull(error) - } - - private fun run(mode: Mode): Pair { - val parseResult = parser.parse(input) - assertEquals(1, parseResult.statements.size) - val statement = parseResult.statements[0] - val catalog = Catalog.builder().name("memory").build() - val session = Session.builder() - .catalog("memory") - .catalogs(catalog) - .build() - val plan = planner.plan(statement, session).plan - val result = compiler.prepare(plan, mode).execute() - return result to plan - } - - @OptIn(PartiQLValueExperimental::class) - private fun comparisonString(expected: PartiQLValue, actual: PartiQLValue, plan: Plan): String { - val expectedBuffer = ByteArrayOutputStream() - val expectedWriter = PartiQLValueIonWriterBuilder.standardIonTextBuilder().build(expectedBuffer) - expectedWriter.append(expected) - return buildString { - // TODO pretty-print V1 plans! - appendLine(plan) - appendLine("Expected : $expectedBuffer") - expectedBuffer.reset() - expectedWriter.append(actual) - appendLine("Actual : $expectedBuffer") - } - } - - override fun toString(): String { - return "$name -- $input" - } - } - @Test @Disabled fun developmentTest() { @@ -1456,7 +1298,7 @@ class PartiQLEvaluatorTest { SuccessTestCase.Global("d", "3.") ) ) - tc.assert() + tc.run() } @Test @@ -1469,7 +1311,7 @@ class PartiQLEvaluatorTest { "v" to int32Value(5) ) ) - ).assert() + ).run() @Test @Disabled("This is just a placeholder. We should add support for this. Grouping is not yet supported.") @@ -1480,7 +1322,7 @@ class PartiQLEvaluatorTest { PLACEHOLDER FOR THE EXAMPLE IN THE RELEVANT SECTION. GROUPING NOT YET SUPPORTED. """.trimIndent(), expectedPermissive = missingValue() - ).assert() + ).run() @Test @Disabled("The planner fails this, though it should pass for permissive mode.") @@ -1494,7 +1336,7 @@ class PartiQLEvaluatorTest { "n" to stringValue("_1") ) ) - ).assert() + ).run() @Test @Disabled("We don't yet support arrays.") @@ -1512,7 +1354,7 @@ class PartiQLEvaluatorTest { missingValue() ) ) - ).assert() + ).run() @Test @Disabled("There is a bug in the planner which makes this always return missing.") @@ -1521,7 +1363,7 @@ class PartiQLEvaluatorTest { name = "PartiQL Specification Section 4.2 -- non integer index", input = "SELECT VALUE [1,2,3][v] FROM <<1, 1.0>> AS v;", expectedPermissive = bagValue(int32Value(2), missingValue()) - ).assert() + ).run() @Test @Disabled("CASTs aren't supported yet.") @@ -1538,7 +1380,7 @@ class PartiQLEvaluatorTest { "a" to int32Value(6), ), ) - ).assert() + ).run() @Test @Disabled("Arrays aren't supported yet.") @@ -1559,7 +1401,7 @@ class PartiQLEvaluatorTest { "i" to int64Value(2), ), ) - ).assert() + ).run() @Test @Disabled( @@ -1608,7 +1450,7 @@ class PartiQLEvaluatorTest { ), ), mode = Mode.PERMISSIVE() - ).assert() + ).run() // PartiQL Specification Section 8 @Test @@ -1617,7 +1459,7 @@ class PartiQLEvaluatorTest { SuccessTestCase( input = "MISSING AND TRUE;", expected = boolValue(null), - ).assert() + ).run() // PartiQL Specification Section 8 @Test @@ -1626,7 +1468,7 @@ class PartiQLEvaluatorTest { input = "MISSING AND TRUE;", expected = boolValue(null), // TODO: Is this right? mode = Mode.STRICT() - ).assert() + ).run() @Test @Disabled("Support for ORDER BY needs to be added for this to pass.") @@ -1637,7 +1479,7 @@ class PartiQLEvaluatorTest { (4, 5) < (SELECT VALUE t.a FROM << { 'a': 3 }, { 'a': 4 } >> AS t ORDER BY t.a) """.trimIndent(), expected = boolValue(false) - ).assert() + ).run() @Test @Disabled("This is appropriately coerced, but this test is failing because LT currently doesn't support LISTS.") @@ -1647,7 +1489,7 @@ class PartiQLEvaluatorTest { (4, 5) < (SELECT t.a, t.a FROM << { 'a': 3 } >> AS t) """.trimIndent(), expected = boolValue(false) - ).assert() + ).run() @Test @Disabled("This broke in its introduction to the codebase on merge. See 5fb9a1ccbc7e630b0df62aa8b161d319c763c1f6.") @@ -1691,7 +1533,7 @@ class PartiQLEvaluatorTest { stringValue("John"), stringValue("Doe"), stringValue("Zoe"), stringValue("Bill") ) ) - ).assert() + ).run() @Test @Disabled("This broke in its introduction to the codebase on merge. See 5fb9a1ccbc7e630b0df62aa8b161d319c763c1f6.") @@ -1754,5 +1596,5 @@ class PartiQLEvaluatorTest { stringValue("John"), stringValue("Doe"), stringValue("Zoe"), stringValue("Bill") ) ) - ).assert() + ).run() } diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PlusTest.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PlusTest.kt new file mode 100644 index 000000000..d41117887 --- /dev/null +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PlusTest.kt @@ -0,0 +1,88 @@ +package org.partiql.eval.internal + +import org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.partiql.eval.Mode +import org.partiql.spi.value.Datum +import java.math.BigDecimal + +class PlusTest { + + @ParameterizedTest + @MethodSource("plusTestCases") + @Execution(ExecutionMode.CONCURRENT) + fun plusTests(tc: SuccessTestCase) = tc.run() + + companion object { + + // Result precision: max(s1, s2) + max(p1 - s1, p2 - s2) + 1 + // Result scale: max(s1, s2) + @JvmStatic + fun plusTestCases() = listOf( + SuccessTestCase( + input = """ + -- DEC(1, 0) + DEC(6, 5) + -- P = 5 + MAX(1, 1) + 1 = 7 + -- S = MAX(0, 5) = 5 + 1 + 2.00000; + """.trimIndent(), + mode = Mode.STRICT(), + expected = Datum.decimal(BigDecimal.valueOf(300000, 5), 7, 5), + jvmEquality = true + ), + SuccessTestCase( + input = """ + -- DEC(2, 1) + DEC(6, 5) + -- P = 5 + MAX(1, 1) + 1 = 7 + -- S = MAX(1, 5) = 5 + 1.0 + 2.00000; + """.trimIndent(), + mode = Mode.STRICT(), + expected = Datum.decimal(BigDecimal.valueOf(300000, 5), 7, 5), + jvmEquality = true + ), + SuccessTestCase( + input = """ + -- DEC(5, 4) + DEC(6, 5) + -- P = 5 + MAX(1, 1) + 1 = 7 + -- S = MAX(4, 5) = 5 + 1.0000 + 2.00000; + """.trimIndent(), + mode = Mode.STRICT(), + expected = Datum.decimal(BigDecimal.valueOf(300000, 5), 7, 5), + jvmEquality = true + ), + SuccessTestCase( + input = """ + -- DEC(7, 4) + DEC(13, 7) + -- P = 7 + MAX(3, 6) + 1 = 14 + -- S = MAX(4, 7) = 7 + 234.0000 + 456789.0000000; + """.trimIndent(), + mode = Mode.STRICT(), + expected = Datum.decimal(BigDecimal.valueOf(457023), 14, 7), + jvmEquality = true + ), + SuccessTestCase( + input = """ + -- This shows that the value, while dynamic, still produces the right precision/scale + -- DEC(7, 4) + DEC(13, 7) + -- P = 7 + MAX(3, 6) + 1 = 14 + -- S = MAX(4, 7) = 7 + 234.0000 + dynamic_decimal; + """.trimIndent(), + mode = Mode.STRICT(), + expected = Datum.decimal(BigDecimal.valueOf(457023), 14, 7), + globals = listOf( + SuccessTestCase.Global( + "dynamic_decimal", + "456789.0000000" + ) + ), + jvmEquality = true + ), + ) + } +} diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/SuccessTestCase.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/SuccessTestCase.kt new file mode 100644 index 000000000..8da22b4d1 --- /dev/null +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/SuccessTestCase.kt @@ -0,0 +1,95 @@ +package org.partiql.eval.internal + +import org.partiql.eval.Mode +import org.partiql.eval.PTestCase +import org.partiql.eval.compiler.PartiQLCompiler +import org.partiql.parser.PartiQLParser +import org.partiql.plan.Plan +import org.partiql.planner.PartiQLPlanner +import org.partiql.spi.catalog.Catalog +import org.partiql.spi.catalog.Name +import org.partiql.spi.catalog.Session +import org.partiql.spi.catalog.Table +import org.partiql.spi.value.Datum +import org.partiql.spi.value.DatumReader +import org.partiql.types.StaticType +import org.partiql.types.fromStaticType +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import kotlin.test.assertEquals + +public class SuccessTestCase( + val input: String, + val expected: Datum, + val mode: Mode = Mode.PERMISSIVE(), + val globals: List = emptyList(), + val jvmEquality: Boolean = false +) : PTestCase { + + @OptIn(PartiQLValueExperimental::class) + constructor( + input: String, + expected: PartiQLValue, + mode: Mode = Mode.PERMISSIVE(), + globals: List = emptyList(), + ) : this(input, Datum.of(expected), mode, globals) + + private val compiler = PartiQLCompiler.standard() + private val parser = PartiQLParser.standard() + private val planner = PartiQLPlanner.standard() + + /** + * @property value is a serialized Ion value. + */ + class Global( + val name: String, + val value: String, + val type: StaticType = StaticType.ANY, + ) + + override fun run() { + val parseResult = parser.parse(input) + assertEquals(1, parseResult.statements.size) + val statement = parseResult.statements[0] + val catalog = Catalog.builder() + .name("memory") + .apply { + globals.forEach { + val table = Table.standard( + name = Name.of(it.name), + schema = fromStaticType(it.type), + datum = DatumReader.ion(it.value.byteInputStream()).next()!! + ) + define(table) + } + } + .build() + val session = Session.builder() + .catalog("memory") + .catalogs(catalog) + .build() + val plan = planner.plan(statement, session).plan + val result = compiler.prepare(plan, mode).execute() + val comparison = when (jvmEquality) { + true -> expected == result + false -> Datum.comparator().compare(expected, result) == 0 + } + assert(comparison) { + comparisonString(expected, result, plan) + } + } + + private fun comparisonString(expected: Datum, actual: Datum, plan: Plan): String { + return buildString { + // TODO pretty-print V1 plans! + appendLine(plan) + // TODO: Add DatumWriter + appendLine("Expected : $expected") + appendLine("Actual : $actual") + } + } + + override fun toString(): String { + return input + } +} diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/TypingTestCase.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/TypingTestCase.kt new file mode 100644 index 000000000..1e9d5a646 --- /dev/null +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/TypingTestCase.kt @@ -0,0 +1,92 @@ +package org.partiql.eval.internal + +import org.partiql.eval.Mode +import org.partiql.eval.PTestCase +import org.partiql.eval.compiler.PartiQLCompiler +import org.partiql.parser.PartiQLParser +import org.partiql.plan.Plan +import org.partiql.planner.PartiQLPlanner +import org.partiql.spi.catalog.Catalog +import org.partiql.spi.catalog.Session +import org.partiql.spi.value.Datum +import org.partiql.types.PType +import org.partiql.value.PartiQLValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.io.PartiQLValueIonWriterBuilder +import java.io.ByteArrayOutputStream +import kotlin.test.assertEquals +import kotlin.test.assertNotNull + +public class TypingTestCase @OptIn(PartiQLValueExperimental::class) constructor( + val name: String, + val input: String, + val expectedPermissive: PartiQLValue, +) : PTestCase { + + private val compiler = PartiQLCompiler.standard() + private val parser = PartiQLParser.standard() + private val planner = PartiQLPlanner.standard() + + @OptIn(PartiQLValueExperimental::class) + override fun run() { + val (permissiveResult, plan) = run(mode = Mode.PERMISSIVE()) + val permissiveResultPValue = permissiveResult.toPartiQLValue() + val assertionCondition = try { + expectedPermissive == permissiveResultPValue // TODO: Assert using Datum + } catch (t: Throwable) { + val str = buildString { + appendLine("Test Name: $name") + // TODO pretty-print V1 plans! + appendLine(plan) + } + throw RuntimeException(str, t) + } + assert(assertionCondition) { + comparisonString(expectedPermissive, permissiveResultPValue, plan) + } + var error: Throwable? = null + try { + val (strictResult, _) = run(mode = Mode.STRICT()) + when (strictResult.type.kind) { + PType.Kind.BAG, PType.Kind.ARRAY -> strictResult.toList() + else -> strictResult + } + } catch (e: Throwable) { + error = e + } + assertNotNull(error) + } + + private fun run(mode: Mode): Pair { + val parseResult = parser.parse(input) + assertEquals(1, parseResult.statements.size) + val statement = parseResult.statements[0] + val catalog = Catalog.builder().name("memory").build() + val session = Session.builder() + .catalog("memory") + .catalogs(catalog) + .build() + val plan = planner.plan(statement, session).plan + val result = compiler.prepare(plan, mode).execute() + return result to plan + } + + @OptIn(PartiQLValueExperimental::class) + private fun comparisonString(expected: PartiQLValue, actual: PartiQLValue, plan: Plan): String { + val expectedBuffer = ByteArrayOutputStream() + val expectedWriter = PartiQLValueIonWriterBuilder.standardIonTextBuilder().build(expectedBuffer) + expectedWriter.append(expected) + return buildString { + // TODO pretty-print V1 plans! + appendLine(plan) + appendLine("Expected : $expectedBuffer") + expectedBuffer.reset() + expectedWriter.append(actual) + appendLine("Actual : $expectedBuffer") + } + } + + override fun toString(): String { + return "$name -- $input" + } +} diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamicTest.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamicTest.kt index 387e1e936..72735cc10 100644 --- a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamicTest.kt +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/operator/rex/ExprCallDynamicTest.kt @@ -80,6 +80,7 @@ class ExprCallDynamicTest { override fun getInstance(args: Array): Function.Instance { return object : Function.Instance( + name = "example", returns = PType.integer(), parameters = arrayOf(it.first.toPType(), it.second.toPType()) ) { diff --git a/partiql-plan/src/main/kotlin/org/partiql/plan/builder/PlanFactory.kt b/partiql-plan/src/main/kotlin/org/partiql/plan/builder/PlanFactory.kt index 788bb2cbf..ed9b9eaab 100644 --- a/partiql-plan/src/main/kotlin/org/partiql/plan/builder/PlanFactory.kt +++ b/partiql-plan/src/main/kotlin/org/partiql/plan/builder/PlanFactory.kt @@ -391,9 +391,10 @@ public interface PlanFactory { /** * Create a [RexCallDynamic] instance. * - * @param functions - * @param args - * @return + * @param name TODO + * @param functions TODO + * @param args TODO + * @return TODO */ public fun rexCallDynamic(name: String, functions: List, args: List): RexCallDynamic = RexCallDynamicImpl(name, functions, args) diff --git a/partiql-plan/src/main/kotlin/org/partiql/plan/rex/RexCallDynamic.kt b/partiql-plan/src/main/kotlin/org/partiql/plan/rex/RexCallDynamic.kt index b48579012..2c85dee79 100644 --- a/partiql-plan/src/main/kotlin/org/partiql/plan/rex/RexCallDynamic.kt +++ b/partiql-plan/src/main/kotlin/org/partiql/plan/rex/RexCallDynamic.kt @@ -2,6 +2,7 @@ package org.partiql.plan.rex import org.partiql.plan.Visitor import org.partiql.spi.function.Function +import org.partiql.types.PType /** * Logical operator for a dynamic dispatch call. @@ -35,10 +36,11 @@ internal class RexCallDynamicImpl( private var name: String, private var functions: List, private var args: List, + type: PType = PType.dynamic() ) : RexCallDynamic { // DO NOT USE FINAL - private var _type: RexType = RexType.dynamic() + private var _type: RexType = RexType(type) override fun getName(): String = name diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/CoercionFamily.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/CoercionFamily.kt new file mode 100644 index 000000000..3fb8ee34e --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/CoercionFamily.kt @@ -0,0 +1,84 @@ +package org.partiql.planner.internal + +import org.partiql.types.PType + +/** + * This represents SQL:1999 Section 4.1.2 "Type conversions and mixing of data types" and breaks down the different + * coercion groups. + * + * TODO: [UNKNOWN] should likely be removed in the future. However, it is needed due to literal nulls and missings. + * TODO: [DYNAMIC] should likely be removed in the future. This is currently only kept to map function signatures. + */ +internal enum class CoercionFamily { + NUMBER, + STRING, + BINARY, + BOOLEAN, + STRUCTURE, + DATE, + TIME, + TIMESTAMP, + COLLECTION, + UNKNOWN, + DYNAMIC; + + companion object { + + /** + * Gets the coercion family for the given [PType.Kind]. + * + * @see CoercionFamily + * @see PType.Kind + * @see family + */ + @JvmStatic + fun family(type: PType.Kind): CoercionFamily { + return when (type) { + PType.Kind.TINYINT -> NUMBER + PType.Kind.SMALLINT -> NUMBER + PType.Kind.INTEGER -> NUMBER + PType.Kind.NUMERIC -> NUMBER + PType.Kind.BIGINT -> NUMBER + PType.Kind.REAL -> NUMBER + PType.Kind.DOUBLE -> NUMBER + PType.Kind.DECIMAL -> NUMBER + PType.Kind.STRING -> STRING + PType.Kind.BOOL -> BOOLEAN + PType.Kind.TIMEZ -> TIME + PType.Kind.TIME -> TIME + PType.Kind.TIMESTAMPZ -> TIMESTAMP + PType.Kind.TIMESTAMP -> TIMESTAMP + PType.Kind.DATE -> DATE + PType.Kind.STRUCT -> STRUCTURE + PType.Kind.ARRAY -> COLLECTION + PType.Kind.BAG -> COLLECTION + PType.Kind.ROW -> STRUCTURE + PType.Kind.CHAR -> STRING + PType.Kind.VARCHAR -> STRING + PType.Kind.DYNAMIC -> DYNAMIC // TODO: REMOVE + PType.Kind.BLOB -> BINARY + PType.Kind.CLOB -> STRING + PType.Kind.UNKNOWN -> UNKNOWN // TODO: REMOVE + PType.Kind.VARIANT -> UNKNOWN // TODO: HANDLE VARIANT + } + } + + /** + * Determines if the [from] type can be coerced to the [to] type. + * + * @see CoercionFamily + * @see PType + * @see family + */ + @JvmStatic + fun canCoerce(from: PType, to: PType): Boolean { + if (from.kind == PType.Kind.UNKNOWN) { + return true + } + if (to.kind == PType.Kind.DYNAMIC) { + return true + } + return family(from.kind) == family(to.kind) + } + } +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt index 488a98cab..da8e6ed8f 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt @@ -134,7 +134,6 @@ internal class Env(private val session: Session) { // 2. Search along the PATH. // TODO - val match = FnResolver.resolve(variants, args.map { it.type }) // If Type mismatch, then we return a missingOp whose trace is all possible candidates. if (match == null) { @@ -148,30 +147,24 @@ internal class Env(private val session: Session) { fn = refFn( catalog = catalog.getName(), name = Name.of(name), - signature = it.function, + signature = it, ), - coercions = it.mapping.toList(), + coercions = emptyList(), // TODO: Remove this from the plan ) } // Rewrite as a dynamic call to be typed by PlanTyper Rex(CompilerType(PType.dynamic()), Rex.Op.Call.Dynamic(args, candidates)) } is FnMatch.Static -> { - // Create an internal typed reference - val ref = refFn( - catalog = catalog.getName(), - name = Name.of(name), - signature = match.function, - ) // Apply the coercions as explicit casts val coercions: List = args.mapIndexed { i, arg -> when (val cast = match.mapping[i]) { null -> arg - else -> Rex(CompilerType(PType.dynamic()), Rex.Op.Cast.Resolved(cast, arg)) + else -> Rex(cast.target, Rex.Op.Cast.Resolved(cast, arg)) } } // Rewrite as a static call to be typed by PlanTyper - Rex(CompilerType(PType.dynamic()), Rex.Op.Call.Static(ref, coercions)) + Rex(CompilerType(PType.dynamic()), Rex.Op.Call.Static(match.function, coercions)) } } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnMatch.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnMatch.kt index 31cb15d43..726026d6b 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnMatch.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnMatch.kt @@ -16,7 +16,7 @@ internal sealed class FnMatch { * @property mapping */ class Static( - val function: Function, + val function: Function.Instance, val mapping: Array, ) : FnMatch() { @@ -51,5 +51,5 @@ internal sealed class FnMatch { * * @property candidates Ordered list of potentially applicable functions to dispatch dynamically. */ - class Dynamic(val candidates: List) : FnMatch() + class Dynamic(val candidates: List) : FnMatch() } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt index 43253d3aa..e773886b6 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/FnResolver.kt @@ -42,15 +42,15 @@ internal object FnResolver { // 1. Look for exact match for (candidate in candidates) { if (candidate.matchesExactly(args)) { - return FnMatch.Static(candidate, arrayOfNulls(args.size)) + val fn = candidate.getInstance(args.toTypedArray()) ?: error("This shouldn't have happened. Matching exactly should produce a function instance.") + return FnMatch.Static(fn, arrayOfNulls(args.size)) } } // 2. If there are DYNAMIC arguments, return all candidates val isDynamic = args.any { it.kind == Kind.DYNAMIC } if (isDynamic) { - val matches = match(candidates, args).ifEmpty { return null } - val orderedMatches = matches.sortedWith(MatchResultComparator).map { it.match } + val orderedMatches = candidates.sortedWith(FnComparator) return FnMatch.Dynamic(orderedMatches) } @@ -62,13 +62,17 @@ internal object FnResolver { // 3. Discard functions that cannot be matched (via implicit coercion or exact matches) val invocableMatches = match(candidates, args).ifEmpty { return null } if (invocableMatches.size == 1) { - return invocableMatches.first().match + val match = invocableMatches.first() + val fn = match.match.getInstance(args.toTypedArray()) ?: return null + return FnMatch.Static(fn, match.mapping) } // 4. Run through all candidates and keep those with the most exact matches on input types. val matches = matchOn(invocableMatches) { it.numberOfExactInputTypes } if (matches.size == 1) { - return matches.first().match + val match = matches.first() + val fn = match.match.getInstance(args.toTypedArray()) ?: return null + return FnMatch.Static(fn, match.mapping) } // TODO: Do we care about preferred types? This is a PostgreSQL concept. @@ -76,7 +80,10 @@ internal object FnResolver { // 6. Find the highest precedence one. NOTE: This is a remnant of the previous implementation. Whether we want // to keep this is up to us. - return matches.sortedWith(MatchResultComparator).first().match + val match = matches.sortedWith(MatchResultComparator).first() + val fn = match.match + val instance = fn.getInstance(args.toTypedArray()) ?: return null + return FnMatch.Static(instance, match.mapping) } /** @@ -117,11 +124,12 @@ internal object FnResolver { * Check if this function accepts the exact input argument types. Assume same arity. */ private fun Function.matchesExactly(args: List): Boolean { - val parameters = getParameters() + val instance = getInstance(args.toTypedArray()) ?: return false + val parameters = instance.parameters for (i in args.indices) { val a = args[i] val p = parameters[i] - if (p.getMatch(a) != a) return false + if (p != a) return false } return true } @@ -133,7 +141,8 @@ internal object FnResolver { * @return */ private fun Function.match(args: List): MatchResult? { - val parameters = getParameters() + val instance = this.getInstance(args.toTypedArray()) ?: return null + val parameters = instance.parameters val mapping = arrayOfNulls(args.size) var exactInputTypes = 0 for (i in args.indices) { @@ -143,31 +152,34 @@ internal object FnResolver { } // check match val p = parameters[i] - val m = p.getMatch(a) when { - m == null -> return null // short-circuit - m == a -> exactInputTypes++ - else -> mapping[i] = coercion(a, m) + p == a -> exactInputTypes++ + else -> mapping[i] = coercion(a, p) ?: return null } } return MatchResult( - FnMatch.Static(this, mapping), + this, + mapping, exactInputTypes, ) } - private fun coercion(arg: PType, target: PType): Ref.Cast { - return Ref.Cast(arg.toCType(), target.toCType(), Ref.Cast.Safety.COERCION, true) + private fun coercion(arg: PType, target: PType): Ref.Cast? { + return when (CoercionFamily.canCoerce(arg, target)) { + true -> Ref.Cast(arg.toCType(), target.toCType(), Ref.Cast.Safety.COERCION, true) + false -> return null + } } private class MatchResult( - val match: FnMatch.Static, + val match: Function, + val mapping: Array, val numberOfExactInputTypes: Int, ) private object MatchResultComparator : Comparator { override fun compare(o1: MatchResult, o2: MatchResult): Int { - return FnComparator.reversed().compare(o1.match.function, o2.match.function) + return FnComparator.reversed().compare(o1.match, o2.match) } } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt index ccf470c48..ad5f69a85 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt @@ -442,12 +442,11 @@ internal data class Rex( } internal data class Static( - @JvmField internal val fn: Ref.Fn, + @JvmField internal val fn: Function.Instance, @JvmField internal val args: List, ) : Call() { public override val children: List by lazy { val kids = mutableListOf() - kids.add(fn) kids.addAll(args) kids.filterNotNull() } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt index 088e41eb6..215de0db4 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt @@ -163,15 +163,14 @@ internal class PlanTransform(private val flags: Set) { } override fun visitRexOpCallDynamic(node: IRex.Op.Call.Dynamic, ctx: PType): Any { - val fns = node.candidates.map { it.fn.signature } - val args = node.args.map { visitRex(it, ctx) } // TODO assert on function name in plan typer .. here is not the place. + val args = node.args.map { visitRex(it, ctx) } + val fns = node.candidates.map { it.fn.signature } return factory.rexCallDynamic("unknown", fns, args) } override fun visitRexOpCallStatic(node: IRex.Op.Call.Static, ctx: PType): Any { - // TODO add argument types and move to PlanTyper!! - val fn = node.fn.signature.getInstance(emptyArray()) + val fn = node.fn val args = node.args.map { visitRex(it, ctx) } return factory.rexCall(fn, args) } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt index ed96bf983..1bccfab9e 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt @@ -88,6 +88,7 @@ import org.partiql.spi.catalog.Identifier import org.partiql.types.PType import org.partiql.value.DecimalValue import org.partiql.value.MissingValue +import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.StringValue import org.partiql.value.boolValue @@ -146,6 +147,10 @@ internal object RexConverter { return rex(cType, op) } + private fun PartiQLValue.toPType(): PType { + return this.type.toPType() + } + /** * TODO PartiQLValue will be replaced by Datum (i.e. IonDatum) is a subsequent PR. */ diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index 37985062c..364769c56 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -779,15 +779,12 @@ internal class PlanTyper(private val env: Env, config: Context) { else -> it } } - // TODO pass argument types to compute the return type. - val returnType = node.fn.signature.getReturnType(emptyArray()) + val instance = node.fn + val returnType: PType = instance.returns // Check if any arg is always missing val argIsAlwaysMissing = args.any { it.type.isMissingValue } - // TODO REMOVE ME !!! THIS IS A HACK (: - val instance = node.fn.signature.getInstance(emptyArray()) - if (argIsAlwaysMissing && instance.isMissingCall) { return errorRexAndReport(_listener, PErrors.alwaysMissing(null)) } @@ -804,12 +801,7 @@ internal class PlanTyper(private val env: Env, config: Context) { * @return */ override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: CompilerType?): Rex { - // TODO pass argument types to compute the return type - val types = node.candidates - .map { it.fn.signature.getReturnType(emptyArray()) } - .toMutableSet() - // TODO: Should this always be DYNAMIC? - return Rex(type = CompilerType(anyOf(types) ?: PType.dynamic()), op = node) + return Rex(type = CompilerType(PType.dynamic()), op = node) } override fun visitRexOpCase(node: Rex.Op.Case, ctx: CompilerType?): Rex { @@ -1151,7 +1143,7 @@ internal class PlanTyper(private val env: Env, config: Context) { if (firstBranchCondition !is Rex.Op.Call.Static) { return null } - if (!firstBranchCondition.fn.signature.getName().equals("is_struct", ignoreCase = true)) { + if (!firstBranchCondition.fn.name.equals("is_struct", ignoreCase = true)) { return null } val firstBranchResultType = firstBranch.rex.type diff --git a/partiql-planner/src/main/resources/partiql_plan_internal.ion b/partiql-planner/src/main/resources/partiql_plan_internal.ion index 3a6243ad7..efc9b4560 100644 --- a/partiql-planner/src/main/resources/partiql_plan_internal.ion +++ b/partiql-planner/src/main/resources/partiql_plan_internal.ion @@ -6,6 +6,7 @@ imports::{ partiql_value_type::'org.partiql.planner.internal.typer.CompilerType', static_type::'org.partiql.planner.internal.typer.CompilerType', fn_signature::'org.partiql.spi.function.Function', + fn_instance::'org.partiql.spi.function.Function.Instance', agg_signature::'org.partiql.spi.function.Aggregation', table::'org.partiql.spi.catalog.Table', ], @@ -127,7 +128,7 @@ rex::{ }, static::{ - fn: '.ref.fn', + fn: fn_instance, args: list::[rex], }, diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt index 8bf4f4f8a..5116f8b32 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt @@ -3842,6 +3842,10 @@ internal class PlanTyperTestsPorted { when (val statement = plan.getOperation()) { is org.partiql.plan.Operation.Query -> { assert(collector.problems.isEmpty()) { + // Throw internal error for debugging + collector.problems.firstOrNull { it.code() == PError.INTERNAL_ERROR }?.let { pError -> + pError.getOrNull("CAUSE", Throwable::class.java)?.let { throw it } + } buildString { appendLine(collector.problems.toString()) appendLine() diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpArithmeticTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpArithmeticTest.kt index ea27ba1c4..427bfa8ac 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpArithmeticTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpArithmeticTest.kt @@ -37,6 +37,8 @@ class OpArithmeticTest : PartiQLTyperTestBase() { val arg1 = args[1] val output = when { arg0 == arg1 -> arg1 + // TODO arg0 == StaticType.DECIMAL && arg1 == StaticType.FLOAT -> arg1 // TODO: The cast table is wrong. Honestly, it should be deleted. + // TODO arg1 == StaticType.DECIMAL && arg0 == StaticType.FLOAT -> arg0 // TODO: The cast table is wrong castTablePType(arg1, arg0) == CastType.COERCION -> arg0 castTablePType(arg0, arg1) == CastType.COERCION -> arg1 else -> error("Arguments do not conform to parameters. Args: $args") diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpBitwiseAndTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpBitwiseAndTest.kt index a28dd9e2c..93ec9815e 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpBitwiseAndTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/operator/OpBitwiseAndTest.kt @@ -34,8 +34,8 @@ class OpBitwiseAndTest : PartiQLTyperTestBase() { val arg1 = args[1] val output = when { arg0 !in allIntPType && arg1 !in allIntPType -> PType.numeric() - arg0 in allIntPType && arg1 !in allIntPType -> arg0 - arg0 !in allIntPType && arg1 in allIntPType -> arg1 + arg0 in allIntPType && arg1 !in allIntPType -> PType.numeric() + arg0 !in allIntPType && arg1 in allIntPType -> PType.numeric() arg0 == arg1 -> arg1 castTablePType(arg1, arg0) == CastType.COERCION -> arg0 castTablePType(arg0, arg1) == CastType.COERCION -> arg1 diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpBetweenTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpBetweenTest.kt index 1b9d788db..049e5d2b0 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpBetweenTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpBetweenTest.kt @@ -17,6 +17,7 @@ import java.util.stream.Stream // TODO: Finalize the semantics for Between operator when operands contain MISSING // For now, Between propagates MISSING. class OpBetweenTest : PartiQLTyperTestBase() { + @TestFactory fun between(): Stream { val tests = listOf( diff --git a/partiql-spi/api/partiql-spi.api b/partiql-spi/api/partiql-spi.api index 410203dcb..2492caeaf 100644 --- a/partiql-spi/api/partiql-spi.api +++ b/partiql-spi/api/partiql-spi.api @@ -343,6 +343,7 @@ public final class org/partiql/spi/errors/PError : org/partiql/spi/Enum { public class org/partiql/spi/errors/PErrorException : org/partiql/spi/errors/PErrorListenerException { public field error Lorg/partiql/spi/errors/PError; public fun (Lorg/partiql/spi/errors/PError;)V + public fun (Lorg/partiql/spi/errors/PError;Ljava/lang/Throwable;)V public fun equals (Ljava/lang/Object;)Z public fun hashCode ()I public fun toString ()Ljava/lang/String; @@ -407,10 +408,13 @@ public final class org/partiql/spi/function/Aggregation$DefaultImpls { public abstract interface class org/partiql/spi/function/Function : org/partiql/spi/function/Routine { public static final field Companion Lorg/partiql/spi/function/Function$Companion; public abstract fun getInstance ([Lorg/partiql/types/PType;)Lorg/partiql/spi/function/Function$Instance; + public static fun instance (Ljava/lang/String;[Lorg/partiql/spi/function/Parameter;Lorg/partiql/types/PType;ZZLkotlin/jvm/functions/Function1;)Lorg/partiql/spi/function/Function$Instance; public static fun static (Ljava/lang/String;[Lorg/partiql/spi/function/Parameter;Lorg/partiql/types/PType;ZZLkotlin/jvm/functions/Function1;)Lorg/partiql/spi/function/Function; } public final class org/partiql/spi/function/Function$Companion { + public final fun instance (Ljava/lang/String;[Lorg/partiql/spi/function/Parameter;Lorg/partiql/types/PType;ZZLkotlin/jvm/functions/Function1;)Lorg/partiql/spi/function/Function$Instance; + public static synthetic fun instance$default (Lorg/partiql/spi/function/Function$Companion;Ljava/lang/String;[Lorg/partiql/spi/function/Parameter;Lorg/partiql/types/PType;ZZLkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/spi/function/Function$Instance; public final fun static (Ljava/lang/String;[Lorg/partiql/spi/function/Parameter;Lorg/partiql/types/PType;ZZLkotlin/jvm/functions/Function1;)Lorg/partiql/spi/function/Function; public static synthetic fun static$default (Lorg/partiql/spi/function/Function$Companion;Ljava/lang/String;[Lorg/partiql/spi/function/Parameter;Lorg/partiql/types/PType;ZZLkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lorg/partiql/spi/function/Function; } @@ -423,10 +427,11 @@ public final class org/partiql/spi/function/Function$DefaultImpls { public abstract class org/partiql/spi/function/Function$Instance { public final field isMissingCall Z public final field isNullCall Z + public final field name Ljava/lang/String; public final field parameters [Lorg/partiql/types/PType; public final field returns Lorg/partiql/types/PType; - public fun ([Lorg/partiql/types/PType;Lorg/partiql/types/PType;ZZ)V - public synthetic fun ([Lorg/partiql/types/PType;Lorg/partiql/types/PType;ZZILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Ljava/lang/String;[Lorg/partiql/types/PType;Lorg/partiql/types/PType;ZZ)V + public synthetic fun (Ljava/lang/String;[Lorg/partiql/types/PType;Lorg/partiql/types/PType;ZZILkotlin/jvm/internal/DefaultConstructorMarker;)V public abstract fun invoke ([Lorg/partiql/spi/value/Datum;)Lorg/partiql/spi/value/Datum; } @@ -438,14 +443,14 @@ public final class org/partiql/spi/function/Parameter { public final fun getMatch (Lorg/partiql/types/PType;)Lorg/partiql/types/PType; public final fun getName ()Ljava/lang/String; public final fun getType ()Lorg/partiql/types/PType; - public static final fun numeric (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; + public static final fun number (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; public static final fun text (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; } public final class org/partiql/spi/function/Parameter$Companion { public final fun collection (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; public final fun dynamic (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; - public final fun numeric (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; + public final fun number (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; public final fun text (Ljava/lang/String;)Lorg/partiql/spi/function/Parameter; } diff --git a/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorException.java b/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorException.java index 91f1e0de5..1761b26b3 100644 --- a/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorException.java +++ b/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorException.java @@ -22,6 +22,16 @@ public PErrorException(@NotNull PError error) { this.error = error; } + /** + * Creates an exception that holds an error. + * @param error the error that is wrapped + * @param cause the cause of the error + */ + public PErrorException(@NotNull PError error, @NotNull Throwable cause) { + super(cause); + this.error = error; + } + @Override public String toString() { return "ErrorException{" + diff --git a/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorListener.java b/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorListener.java index 75c44958e..b25f35908 100644 --- a/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorListener.java +++ b/partiql-spi/src/main/java/org/partiql/spi/errors/PErrorListener.java @@ -26,6 +26,12 @@ public interface PErrorListener { static PErrorListener abortOnError() { return error -> { if (error.severity.code() == Severity.ERROR) { + if (error.code() == PError.INTERNAL_ERROR) { + Throwable cause = error.getOrNull("CAUSE", Throwable.class); + if (cause != null) { + throw new PErrorException(error, cause); + } + } throw new PErrorException(error); } }; diff --git a/partiql-spi/src/main/java/org/partiql/spi/value/DatumDecimal.java b/partiql-spi/src/main/java/org/partiql/spi/value/DatumDecimal.java index 01d22b4c6..63a7988ee 100644 --- a/partiql-spi/src/main/java/org/partiql/spi/value/DatumDecimal.java +++ b/partiql-spi/src/main/java/org/partiql/spi/value/DatumDecimal.java @@ -4,6 +4,7 @@ import org.partiql.types.PType; import java.math.BigDecimal; +import java.util.Objects; /** * This shall always be package-private (internal). @@ -35,4 +36,25 @@ public BigDecimal getBigDecimal() { public PType getType() { return _type; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof Datum)) return false; + Datum data = (Datum) o; + return Objects.equals(_type, data.getType()) && Objects.equals(_value, data.getBigDecimal()); + } + + @Override + public int hashCode() { + return Objects.hash(_value, _type); + } + + @Override + public String toString() { + return "DatumDecimal{" + + "_value=" + _value + + ", _type=" + _type + + '}'; + } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/Builtins.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/Builtins.kt index f3555089d..e56d6b6cf 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/Builtins.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/Builtins.kt @@ -44,12 +44,7 @@ internal object Builtins { Fn_BETWEEN__TIMESTAMP_TIMESTAMP_TIMESTAMP__BOOL, Fn_BIT_LENGTH__STRING__INT32, Fn_BIT_LENGTH__CLOB__INT32, - - Fn_BITWISE_AND__INT8_INT8__INT8, - Fn_BITWISE_AND__INT16_INT16__INT16, - Fn_BITWISE_AND__INT32_INT32__INT32, - Fn_BITWISE_AND__INT64_INT64__INT64, - Fn_BITWISE_AND__INT_INT__INT, + FnBitwiseAnd, Fn_CARDINALITY__BAG__INT32, Fn_CARDINALITY__LIST__INT32, @@ -136,7 +131,7 @@ internal object Builtins { Fn_DIVIDE__FLOAT32_FLOAT32__FLOAT32, Fn_DIVIDE__FLOAT64_FLOAT64__FLOAT64, Fn_DIVIDE__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, - Fn_EQ__ANY_ANY__BOOL, + FnEq, Fn_EXTRACT_DAY__DATE__INT32, Fn_EXTRACT_DAY__TIMESTAMP__INT32, Fn_EXTRACT_HOUR__TIME__INT32, @@ -254,14 +249,7 @@ internal object Builtins { Fn_LTE__DATE_DATE__BOOL, Fn_LTE__TIME_TIME__BOOL, Fn_LTE__TIMESTAMP_TIMESTAMP__BOOL, - Fn_MINUS__INT8_INT8__INT8, - Fn_MINUS__INT16_INT16__INT16, - Fn_MINUS__INT32_INT32__INT32, - Fn_MINUS__INT64_INT64__INT64, - Fn_MINUS__INT_INT__INT, - Fn_MINUS__FLOAT32_FLOAT32__FLOAT32, - Fn_MINUS__FLOAT64_FLOAT64__FLOAT64, - Fn_MINUS__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, + FnMinus, Fn_MODULO__INT8_INT8__INT8, Fn_MODULO__INT16_INT16__INT16, Fn_MODULO__INT32_INT32__INT32, @@ -282,15 +270,7 @@ internal object Builtins { Fn_OR__BOOL_BOOL__BOOL, Fn_OCTET_LENGTH__STRING__INT32, Fn_OCTET_LENGTH__CLOB__INT32, - - Fn_PLUS__INT8_INT8__INT8, - Fn_PLUS__INT16_INT16__INT16, - Fn_PLUS__INT32_INT32__INT32, - Fn_PLUS__INT64_INT64__INT64, - Fn_PLUS__INT_INT__INT, - Fn_PLUS__FLOAT32_FLOAT32__FLOAT32, - Fn_PLUS__FLOAT64_FLOAT64__FLOAT64, - Fn_PLUS__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__DECIMAL_ARBITRARY, + FnPlus, Fn_POS__INT8__INT8, Fn_POS__INT16__INT16, Fn_POS__INT32__INT32, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/Function.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/Function.kt index 2941429b3..64cc0a649 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/Function.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/Function.kt @@ -11,7 +11,7 @@ public interface Function : Routine { /** * Returns an invocable implementation. Optional. */ - public fun getInstance(args: Array): Instance { + public fun getInstance(args: Array): Instance? { throw Error("Function ${getName()} has no implementations.") } @@ -23,6 +23,7 @@ public interface Function : Routine { * @see Function.getInstance */ public abstract class Instance( + @JvmField public val name: String, @JvmField public val parameters: Array, @JvmField public val returns: PType, @JvmField public val isNullCall: Boolean = true, @@ -42,6 +43,37 @@ public interface Function : Routine { */ public companion object { + /** + * TODO INTERNALIZE TO SPI AND REPLACE WITH A BUILDER (OR SOMETHING..) + * + * @param name + * @param parameters + * @param returns + * @param isNullCall + * @param isMissingCall + * @param invoke + * @return + */ + @JvmStatic + public fun instance( + name: String, + parameters: Array, + returns: PType, + isNullCall: Boolean = true, + isMissingCall: Boolean = true, + invoke: (Array) -> Datum, + ): Instance { + return object : Instance( + name, + Array(parameters.size) { parameters[it].getType() }, + returns, + isNullCall, + isMissingCall, + ) { + override fun invoke(args: Array): Datum = invoke(args) + } + } + /** * TODO INTERNALIZE TO SPI AND REPLACE WITH A BUILDER (OR SOMETHING..) * @@ -64,6 +96,7 @@ public interface Function : Routine { ): Function = _Function( name, parameters, returns, object : Instance( + name, Array(parameters.size) { parameters[it].getType() }, returns, isNullCall, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/Parameter.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/Parameter.kt index 36710acbb..35cad050b 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/Parameter.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/Parameter.kt @@ -94,10 +94,10 @@ public class Parameter private constructor( public fun text(name: String): Parameter = Parameter(name, SqlTypeFamily.TEXT, false) /** - * Create a numeric [Parameter]. + * Create a number [Parameter]. */ @JvmStatic - public fun numeric(name: String): Parameter = Parameter(name, SqlTypeFamily.NUMERIC, false) + public fun number(name: String): Parameter = Parameter(name, SqlTypeFamily.NUMBER, false) /** * Create a collection [Parameter]. diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/ArithmeticDiadicOperator.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/ArithmeticDiadicOperator.kt new file mode 100644 index 000000000..d217897c5 --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/ArithmeticDiadicOperator.kt @@ -0,0 +1,151 @@ +package org.partiql.spi.function.builtins + +import org.partiql.spi.function.Function +import org.partiql.spi.function.Parameter +import org.partiql.spi.function.builtins.TypePrecedence.TYPE_PRECEDENCE +import org.partiql.spi.internal.SqlTypeFamily +import org.partiql.spi.value.Datum +import org.partiql.types.PType + +/** + * This carries along with it a static table containing a mapping between the input types and the implementation. + */ +internal abstract class ArithmeticDiadicOperator : Function { + + companion object { + val allowed = SqlTypeFamily.NUMBER.members + setOf(PType.Kind.UNKNOWN) + } + + override fun getInstance(args: Array): Function.Instance? { + if (!allowed.contains(args[0].kind) || !allowed.contains(args[1].kind)) { + return null + } + val lhs = args[0] + val rhs = args[1] + val lhsPrecedence = TYPE_PRECEDENCE[lhs.kind] ?: throw IllegalArgumentException("Type not supported -- LHS = $lhs") + val rhsPrecedence = TYPE_PRECEDENCE[rhs.kind] ?: throw IllegalArgumentException("Type not supported -- RHS = $rhs") + val (newLhs, newRhs) = when (lhsPrecedence.compareTo(rhsPrecedence)) { + -1 -> (rhs to rhs) + 0 -> (lhs to rhs) + else -> (lhs to lhs) + } + val instance = instances[lhs.kind.ordinal][rhs.kind.ordinal] + return instance(newLhs, newRhs) + } + + /** + * @param integerLhs TODO + * @param integerRhs TODO + * @return TODO + */ + abstract fun getIntegerInstance(integerLhs: PType, integerRhs: PType): Function.Instance + + /** + * @param tinyIntLhs TODO + * @param tinyIntRhs TODO + * @return TODO + */ + abstract fun getTinyIntInstance(tinyIntLhs: PType, tinyIntRhs: PType): Function.Instance + + /** + * @param smallIntLhs TODO + * @param smallIntRhs TODO + * @return TODO + */ + abstract fun getSmallIntInstance(smallIntLhs: PType, smallIntRhs: PType): Function.Instance + + /** + * @param bigIntLhs TODO + * @param bigIntRhs TODO + * @return TODO + */ + abstract fun getBigIntInstance(bigIntLhs: PType, bigIntRhs: PType): Function.Instance + + /** + * TODO: This will soon be removed. + * @param numericLhs TODO + * @param numericRhs TODO + * @return TODO + */ + abstract fun getNumericInstance(numericLhs: PType, numericRhs: PType): Function.Instance + + /** + * @param decimalLhs TODO + * @param decimalRhs TODO + * @return TODO + */ + abstract fun getDecimalInstance(decimalLhs: PType, decimalRhs: PType): Function.Instance + + /** + * @param realLhs TODO + * @param realRhs TODO + * @return TODO + */ + abstract fun getRealInstance(realLhs: PType, realRhs: PType): Function.Instance + + /** + * @param doubleLhs TODO + * @param doubleRhs TODO + * @return TODO + */ + abstract fun getDoubleInstance(doubleLhs: PType, doubleRhs: PType): Function.Instance + + override fun getParameters(): Array { + return arrayOf( + Parameter.number("lhs"), + Parameter.number("rhs"), + ) + } + + override fun getReturnType(args: Array): PType { + return getInstance(args)?.returns ?: PType.dynamic() // TODO: Do we need this method? + } + + /** + * This is a lookup table for finding the appropriate instance for the given types. The table is + * initialized on construction using the get*Instance methods. + */ + private val instances: Array Function.Instance?>> = Array(PType.Kind.entries.size) { + Array(PType.Kind.entries.size) { + { _, _ -> null } + } + } + + private fun fillTable(lhs: PType.Kind, rhs: PType.Kind, instance: (PType, PType) -> Function.Instance) { + instances[lhs.ordinal][rhs.ordinal] = instance + } + + private fun fillTable(highPrecedence: PType.Kind, instance: (PType, PType) -> Function.Instance) { + val numbers = SqlTypeFamily.NUMBER.members + setOf(PType.Kind.UNKNOWN) + numbers.filter { + (TYPE_PRECEDENCE[highPrecedence]!! > TYPE_PRECEDENCE[it]!!) + }.forEach { + fillTable(highPrecedence, it) { lhs, _ -> instance(lhs, lhs) } + fillTable(it, highPrecedence) { _, rhs -> instance(rhs, rhs) } + } + fillTable(highPrecedence, highPrecedence) { lhs, rhs -> instance(lhs, rhs) } + } + + init { + fillTable(PType.Kind.TINYINT) { lhs, rhs -> getTinyIntInstance(lhs, rhs) } + fillTable(PType.Kind.SMALLINT) { lhs, rhs -> getSmallIntInstance(lhs, rhs) } + fillTable(PType.Kind.INTEGER) { lhs, rhs -> getIntegerInstance(lhs, rhs) } + fillTable(PType.Kind.BIGINT) { lhs, rhs -> getBigIntInstance(lhs, rhs) } + fillTable(PType.Kind.DECIMAL) { lhs, rhs -> getDecimalInstance(lhs, rhs) } + fillTable(PType.Kind.NUMERIC) { lhs, rhs -> getNumericInstance(lhs, rhs) } // TODO: Remove this + fillTable(PType.Kind.REAL) { lhs, rhs -> getRealInstance(lhs, rhs) } + fillTable(PType.Kind.DOUBLE) { lhs, rhs -> getDoubleInstance(lhs, rhs) } + } + + protected fun basic(arg: PType, invocation: (Array) -> Datum): Function.Instance { + return Function.instance( + name = getName(), + returns = arg, + parameters = arrayOf( + Parameter("lhs", arg), + Parameter("rhs", arg), + ), + invoke = invocation + ) + } +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnBitwiseAnd.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnBitwiseAnd.kt index 1357a1de4..6e2da54fa 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnBitwiseAnd.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnBitwiseAnd.kt @@ -4,82 +4,65 @@ package org.partiql.spi.function.builtins import org.partiql.spi.function.Function -import org.partiql.spi.function.Parameter import org.partiql.spi.value.Datum import org.partiql.types.PType import kotlin.experimental.and -internal val Fn_BITWISE_AND__INT8_INT8__INT8 = Function.static( - - name = "bitwise_and", - returns = PType.tinyint(), - parameters = arrayOf( - Parameter("lhs", PType.tinyint()), - Parameter("rhs", PType.tinyint()), - ), - -) { args -> - @Suppress("DEPRECATION") val arg0 = args[0].byte - @Suppress("DEPRECATION") val arg1 = args[1].byte - Datum.tinyint(arg0 and arg1) -} - -internal val Fn_BITWISE_AND__INT16_INT16__INT16 = Function.static( - - name = "bitwise_and", - returns = PType.smallint(), - parameters = arrayOf( - Parameter("lhs", PType.smallint()), - Parameter("rhs", PType.smallint()), - ), - -) { args -> - val arg0 = args[0].short - val arg1 = args[1].short - Datum.smallint(arg0 and arg1) -} - -internal val Fn_BITWISE_AND__INT32_INT32__INT32 = Function.static( - - name = "bitwise_and", - returns = PType.integer(), - parameters = arrayOf( - Parameter("lhs", PType.integer()), - Parameter("rhs", PType.integer()), - ), - -) { args -> - val arg0 = args[0].int - val arg1 = args[1].int - Datum.integer(arg0 and arg1) -} - -internal val Fn_BITWISE_AND__INT64_INT64__INT64 = Function.static( - - name = "bitwise_and", - returns = PType.bigint(), - parameters = arrayOf( - Parameter("lhs", PType.bigint()), - Parameter("rhs", PType.bigint()), - ), - -) { args -> - val arg0 = args[0].long - val arg1 = args[1].long - Datum.bigint(arg0 and arg1) -} - -internal val Fn_BITWISE_AND__INT_INT__INT = Function.static( - - name = "bitwise_and", - returns = PType.numeric(), - parameters = arrayOf( - @Suppress("DEPRECATION") Parameter("lhs", PType.numeric()), - @Suppress("DEPRECATION") Parameter("rhs", PType.numeric()), - ), - -) { args -> - val arg0 = args[0].bigInteger - val arg1 = args[1].bigInteger - Datum.numeric(arg0 and arg1) +internal object FnBitwiseAnd : ArithmeticDiadicOperator() { + override fun getName(): String { + return "bitwise_and" + } + + override fun getTinyIntInstance(tinyIntLhs: PType, tinyIntRhs: PType): Function.Instance { + return basic(PType.tinyint()) { args -> + @Suppress("DEPRECATION") val arg0 = args[0].byte + @Suppress("DEPRECATION") val arg1 = args[1].byte + Datum.tinyint(arg0 and arg1) + } + } + + override fun getSmallIntInstance(smallIntLhs: PType, smallIntRhs: PType): Function.Instance { + return basic(PType.smallint()) { args -> + val arg0 = args[0].short + val arg1 = args[1].short + Datum.smallint(arg0 and arg1) + } + } + + override fun getIntegerInstance(integerLhs: PType, integerRhs: PType): Function.Instance { + return basic(PType.integer()) { args -> + val arg0 = args[0].int + val arg1 = args[1].int + Datum.integer(arg0 and arg1) + } + } + + override fun getBigIntInstance(bigIntLhs: PType, bigIntRhs: PType): Function.Instance { + return basic(PType.bigint()) { args -> + val arg0 = args[0].long + val arg1 = args[1].long + Datum.bigint(arg0 and arg1) + } + } + + // TODO: Probably remove this if we don't expose NUMERIC + override fun getNumericInstance(numericLhs: PType, numericRhs: PType): Function.Instance { + return basic(PType.numeric()) { args -> + val arg0 = args[0].bigInteger + val arg1 = args[1].bigInteger + Datum.numeric(arg0 and arg1) + } + } + + override fun getDecimalInstance(decimalLhs: PType, decimalRhs: PType): Function.Instance { + return getNumericInstance(decimalLhs, decimalRhs) + } + + override fun getRealInstance(realLhs: PType, realRhs: PType): Function.Instance { + return getNumericInstance(realLhs, realRhs) + } + + override fun getDoubleInstance(doubleLhs: PType, doubleRhs: PType): Function.Instance { + return getNumericInstance(doubleLhs, doubleRhs) + } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnCollAgg.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnCollAgg.kt index e75f212cb..09a09f782 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnCollAgg.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnCollAgg.kt @@ -32,6 +32,7 @@ internal abstract class Fn_COLL_AGG__BAG__ANY( override fun getInstance(args: Array): Function.Instance = instance private val instance = object : Function.Instance( + name, parameters = arrayOf(PType.bag()), returns = PType.dynamic(), ) { diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnEq.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnEq.kt index c4da29280..33e78fdf6 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnEq.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnEq.kt @@ -8,9 +8,6 @@ import org.partiql.spi.function.Parameter import org.partiql.spi.value.Datum import org.partiql.types.PType -// Memoize the comparator. -private val comparator = Datum.comparator() - /** * According to SQL:1999: * > If either XV or YV is the null value, then `X Y` is unknown @@ -25,20 +22,41 @@ private val comparator = Datum.comparator() * TODO: The PartiQL Specification needs to clearly define the semantics of MISSING. That being said, this implementation * follows the existing conformance tests and SQL:1999. */ -internal val Fn_EQ__ANY_ANY__BOOL = Function.static( - name = "eq", - returns = PType.bool(), - parameters = arrayOf( - Parameter("lhs", PType.dynamic()), - Parameter("rhs", PType.dynamic()), - ), - isNullCall = true, - isMissingCall = false, -) { args -> - val lhs = args[0] - val rhs = args[1] - if (lhs.isMissing || rhs.isMissing) { - return@static Datum.nullValue(PType.bool()) +internal object FnEq : Function { + + // Memoize shared variables + private val comparator = Datum.comparator() + private val boolType = PType.bool() + private val nullValue = Datum.nullValue(boolType) + + override fun getName(): String { + return "eq" + } + + override fun getParameters(): Array { + return arrayOf(Parameter.dynamic("lhs"), Parameter.dynamic("rhs")) + } + + override fun getInstance(args: Array): Function.Instance { + return object : Function.Instance( + "eq", + args, + boolType, + isNullCall = false, + isMissingCall = false + ) { + override fun invoke(args: Array): Datum { + val lhs = args[0] + val rhs = args[1] + if (lhs.isMissing || rhs.isMissing) { + return nullValue + } + return Datum.bool(comparator.compare(lhs, rhs) == 0) + } + } + } + + override fun getReturnType(args: Array): PType { + return getInstance(args).returns } - Datum.bool(comparator.compare(lhs, rhs) == 0) } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsMissing.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsMissing.kt index 62bd9af45..7acb5188c 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsMissing.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsMissing.kt @@ -23,6 +23,7 @@ internal val Fn_IS_MISSING__ANY__BOOL = object : Function { * IS MISSING implementation. */ private var instance = object : Function.Instance( + name, parameters = arrayOf(PType.dynamic()), returns = PType.bool(), isNullCall = false, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsNull.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsNull.kt index 1d5c35bf4..d3d258db8 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsNull.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnIsNull.kt @@ -23,6 +23,7 @@ internal val Fn_IS_NULL__ANY__BOOL = object : Function { * IS NULL implementation. */ private var instance = object : Function.Instance( + name, parameters = arrayOf(PType.dynamic()), returns = PType.bool(), isNullCall = false, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnMinus.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnMinus.kt index 8a3981116..924066d4c 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnMinus.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnMinus.kt @@ -8,123 +8,88 @@ import org.partiql.spi.function.Parameter import org.partiql.spi.value.Datum import org.partiql.types.PType -// TODO: Handle Overflow -internal val Fn_MINUS__INT8_INT8__INT8 = Function.static( - - name = "minus", - returns = PType.tinyint(), - parameters = arrayOf( - Parameter("lhs", PType.tinyint()), - Parameter("rhs", PType.tinyint()), - ), - -) { args -> - @Suppress("DEPRECATION") val arg0 = args[0].byte - @Suppress("DEPRECATION") val arg1 = args[1].byte - Datum.tinyint((arg0 - arg1).toByte()) -} - -internal val Fn_MINUS__INT16_INT16__INT16 = Function.static( - - name = "minus", - returns = PType.smallint(), - parameters = arrayOf( - Parameter("lhs", PType.smallint()), - Parameter("rhs", PType.smallint()), - ), - -) { args -> - val arg0 = args[0].short - val arg1 = args[1].short - Datum.smallint((arg0 - arg1).toShort()) -} - -internal val Fn_MINUS__INT32_INT32__INT32 = Function.static( - - name = "minus", - returns = PType.integer(), - parameters = arrayOf( - Parameter("lhs", PType.integer()), - Parameter("rhs", PType.integer()), - ), - -) { args -> - val arg0 = args[0].int - val arg1 = args[1].int - Datum.integer((arg0 - arg1)) -} - -internal val Fn_MINUS__INT64_INT64__INT64 = Function.static( - - name = "minus", - returns = PType.bigint(), - parameters = arrayOf( - Parameter("lhs", PType.bigint()), - Parameter("rhs", PType.bigint()), - ), - -) { args -> - val arg0 = args[0].long - val arg1 = args[1].long - Datum.bigint((arg0 - arg1)) -} - -internal val Fn_MINUS__INT_INT__INT = Function.static( - - name = "minus", - returns = PType.numeric(), - parameters = arrayOf( - @Suppress("DEPRECATION") Parameter("lhs", PType.numeric()), - @Suppress("DEPRECATION") Parameter("rhs", PType.numeric()), - ), - -) { args -> - val arg0 = args[0].bigInteger - val arg1 = args[1].bigInteger - Datum.numeric((arg0 - arg1)) -} - -internal val Fn_MINUS__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__DECIMAL_ARBITRARY = Function.static( - - name = "minus", - returns = PType.decimal(), - parameters = arrayOf( - @Suppress("DEPRECATION") Parameter("lhs", PType.decimal()), - @Suppress("DEPRECATION") Parameter("rhs", PType.decimal()), - ), - -) { args -> - val arg0 = args[0].bigDecimal - val arg1 = args[1].bigDecimal - Datum.decimal(arg0 - arg1) -} - -internal val Fn_MINUS__FLOAT32_FLOAT32__FLOAT32 = Function.static( - - name = "minus", - returns = PType.real(), - parameters = arrayOf( - Parameter("lhs", PType.real()), - Parameter("rhs", PType.real()), - ), - -) { args -> - val arg0 = args[0].float - val arg1 = args[1].float - Datum.real(arg0 - arg1) -} - -internal val Fn_MINUS__FLOAT64_FLOAT64__FLOAT64 = Function.static( - - name = "minus", - returns = PType.doublePrecision(), - parameters = arrayOf( - Parameter("lhs", PType.doublePrecision()), - Parameter("rhs", PType.doublePrecision()), - ), - -) { args -> - val arg0 = args[0].double - val arg1 = args[1].double - Datum.doublePrecision(arg0 - arg1) +internal object FnMinus : ArithmeticDiadicOperator() { + + override fun getName(): String { + return "minus" + } + + override fun getIntegerInstance(integerLhs: PType, integerRhs: PType): Function.Instance { + return basic(PType.integer()) { args -> + val arg0 = args[0].int + val arg1 = args[1].int + Datum.integer(arg0 - arg1) + } + } + + override fun getTinyIntInstance(tinyIntLhs: PType, tinyIntRhs: PType): Function.Instance { + return basic(PType.tinyint()) { args -> + @Suppress("DEPRECATION") val arg0 = args[0].byte + @Suppress("DEPRECATION") val arg1 = args[1].byte + Datum.tinyint((arg0 - arg1).toByte()) + } + } + + override fun getSmallIntInstance(smallIntLhs: PType, smallIntRhs: PType): Function.Instance { + return basic(PType.smallint()) { args -> + val arg0 = args[0].short + val arg1 = args[1].short + Datum.smallint((arg0 - arg1).toShort()) + } + } + + override fun getBigIntInstance(bigIntLhs: PType, bigIntRhs: PType): Function.Instance { + return basic(PType.bigint()) { args -> + val arg0 = args[0].long + val arg1 = args[1].long + Datum.bigint((arg0 - arg1)) + } + } + + // TODO: Delete this + override fun getNumericInstance(numericLhs: PType, numericRhs: PType): Function.Instance { + return basic(PType.numeric()) { args -> + val arg0 = args[0].bigInteger + val arg1 = args[1].bigInteger + Datum.numeric(arg0 - arg1) + } + } + + /** + * Precision and scale calculation: + * P = max(s1, s2) + max(p1 - s1, p2 - s2) + 1 + * S = max(s1, s2) + */ + override fun getDecimalInstance(decimalLhs: PType, decimalRhs: PType): Function.Instance { + val p = Math.min(38, Math.max(decimalLhs.scale, decimalRhs.scale) + Math.max(decimalLhs.precision - decimalLhs.scale, decimalRhs.precision - decimalRhs.scale) + 1) + val s = Math.min(38, Math.max(decimalLhs.scale, decimalRhs.scale)) + return Function.instance( + name = "plus", + returns = PType.decimal(p, s), + parameters = arrayOf( + Parameter("lhs", decimalLhs), + Parameter("rhs", decimalRhs), + ) + ) { args -> + val arg0 = args[0].bigDecimal + val arg1 = args[1].bigDecimal + Datum.decimal(arg0 - arg1, p, s) + } + } + + override fun getRealInstance(realLhs: PType, realRhs: PType): Function.Instance { + return basic(PType.real()) { args -> + val arg0 = args[0].float + val arg1 = args[1].float + Datum.real((arg0 - arg1)) + } + } + + override fun getDoubleInstance(doubleLhs: PType, doubleRhs: PType): Function.Instance { + return basic(PType.doublePrecision()) { args -> + val arg0 = args[0].double + val arg1 = args[1].double + Datum.doublePrecision((arg0 - arg1)) + } + } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnNot.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnNot.kt index 0da34462f..c605f7800 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnNot.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnNot.kt @@ -17,6 +17,7 @@ internal val Fn_NOT__BOOL__BOOL = object : Function { private var returns = PType.bool() private var instance = object : Function.Instance( + name, parameters = arrayOf(PType.dynamic()), returns = PType.bool(), isNullCall = true, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnPlus.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnPlus.kt index 4b44fa83e..a427f4964 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnPlus.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/FnPlus.kt @@ -8,123 +8,87 @@ import org.partiql.spi.function.Parameter import org.partiql.spi.value.Datum import org.partiql.types.PType -// TODO: Handle Overflow -internal val Fn_PLUS__INT8_INT8__INT8 = Function.static( - - name = "plus", - returns = PType.tinyint(), - parameters = arrayOf( - Parameter("lhs", PType.tinyint()), - Parameter("rhs", PType.tinyint()), - ), - -) { args -> - @Suppress("DEPRECATION") val arg0 = args[0].byte - @Suppress("DEPRECATION") val arg1 = args[1].byte - Datum.tinyint((arg0 + arg1).toByte()) -} - -internal val Fn_PLUS__INT16_INT16__INT16 = Function.static( - - name = "plus", - returns = PType.smallint(), - parameters = arrayOf( - Parameter("lhs", PType.smallint()), - Parameter("rhs", PType.smallint()), - ), - -) { args -> - val arg0 = args[0].short - val arg1 = args[1].short - Datum.smallint((arg0 + arg1).toShort()) -} - -internal val Fn_PLUS__INT32_INT32__INT32 = Function.static( - - name = "plus", - returns = PType.integer(), - parameters = arrayOf( - Parameter("lhs", PType.integer()), - Parameter("rhs", PType.integer()), - ), - -) { args -> - val arg0 = args[0].int - val arg1 = args[1].int - Datum.integer(arg0 + arg1) -} - -internal val Fn_PLUS__INT64_INT64__INT64 = Function.static( - - name = "plus", - returns = PType.bigint(), - parameters = arrayOf( - Parameter("lhs", PType.bigint()), - Parameter("rhs", PType.bigint()), - ), - -) { args -> - val arg0 = args[0].long - val arg1 = args[1].long - Datum.bigint(arg0 + arg1) -} - -internal val Fn_PLUS__INT_INT__INT = Function.static( - - name = "plus", - returns = PType.numeric(), - parameters = arrayOf( - @Suppress("DEPRECATION") Parameter("lhs", PType.numeric()), - @Suppress("DEPRECATION") Parameter("rhs", PType.numeric()), - ), - -) { args -> - val arg0 = args[0].bigInteger - val arg1 = args[1].bigInteger - Datum.numeric(arg0 + arg1) -} - -internal val Fn_PLUS__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__DECIMAL_ARBITRARY = Function.static( - - name = "plus", - returns = PType.decimal(), - parameters = arrayOf( - Parameter("lhs", PType.decimal()), - Parameter("rhs", PType.decimal()), - ), - -) { args -> - val arg0 = args[0].bigDecimal - val arg1 = args[1].bigDecimal - Datum.decimal(arg0 + arg1) -} - -internal val Fn_PLUS__FLOAT32_FLOAT32__FLOAT32 = Function.static( - - name = "plus", - returns = PType.real(), - parameters = arrayOf( - Parameter("lhs", PType.real()), - Parameter("rhs", PType.real()), - ), - -) { args -> - val arg0 = args[0].float - val arg1 = args[1].float - Datum.real(arg0 + arg1) -} - -internal val Fn_PLUS__FLOAT64_FLOAT64__FLOAT64 = Function.static( - - name = "plus", - returns = PType.doublePrecision(), - parameters = arrayOf( - Parameter("lhs", PType.doublePrecision()), - Parameter("rhs", PType.doublePrecision()), - ), - -) { args -> - val arg0 = args[0].double - val arg1 = args[1].double - Datum.doublePrecision(arg0 + arg1) +internal object FnPlus : ArithmeticDiadicOperator() { + override fun getName(): String { + return "plus" + } + + override fun getTinyIntInstance(tinyIntLhs: PType, tinyIntRhs: PType): Function.Instance { + return basic(PType.tinyint()) { args -> + @Suppress("DEPRECATION") val arg0 = args[0].byte + @Suppress("DEPRECATION") val arg1 = args[1].byte + Datum.tinyint((arg0 + arg1).toByte()) + } + } + + override fun getSmallIntInstance(smallIntLhs: PType, smallIntRhs: PType): Function.Instance { + return basic(PType.smallint()) { args -> + val arg0 = args[0].short + val arg1 = args[1].short + Datum.smallint((arg0 + arg1).toShort()) + } + } + + override fun getIntegerInstance(integerLhs: PType, integerRhs: PType): Function.Instance { + return basic(PType.integer()) { args -> + val arg0 = args[0].int + val arg1 = args[1].int + Datum.integer(arg0 + arg1) + } + } + + override fun getBigIntInstance(bigIntLhs: PType, bigIntRhs: PType): Function.Instance { + return basic(PType.bigint()) { args -> + val arg0 = args[0].long + val arg1 = args[1].long + Datum.bigint(arg0 + arg1) + } + } + + // TODO: Probably remove this if we don't expose NUMERIC + override fun getNumericInstance(numericLhs: PType, numericRhs: PType): Function.Instance { + return basic(PType.numeric()) { args -> + val arg0 = args[0].bigInteger + val arg1 = args[1].bigInteger + Datum.numeric(arg0 + arg1) + } + } + + /** + * Precision and scale calculation: + * P = max(s1, s2) + max(p1 - s1, p2 - s2) + 1 + * S = max(s1, s2) + */ + override fun getDecimalInstance(decimalLhs: PType, decimalRhs: PType): Function.Instance { + val p = Math.min(38, Math.max(decimalLhs.scale, decimalRhs.scale) + Math.max(decimalLhs.precision - decimalLhs.scale, decimalRhs.precision - decimalRhs.scale) + 1) + val s = Math.min(38, Math.max(decimalLhs.scale, decimalRhs.scale)) + return Function.instance( + name = "plus", + returns = PType.decimal(p, s), + parameters = arrayOf( + Parameter("lhs", decimalLhs), + Parameter("rhs", decimalRhs), + ) + ) { args -> + val arg0 = args[0].bigDecimal + val arg1 = args[1].bigDecimal + Datum.decimal(arg0 + arg1, p, s) + } + } + + override fun getRealInstance(realLhs: PType, realRhs: PType): Function.Instance { + return basic(PType.real()) { args -> + val arg0 = args[0].float + val arg1 = args[1].float + Datum.real(arg0 + arg1) + } + } + + override fun getDoubleInstance(doubleLhs: PType, doubleRhs: PType): Function.Instance { + return basic(PType.doublePrecision()) { args -> + val arg0 = args[0].double + val arg1 = args[1].double + Datum.doublePrecision(arg0 + arg1) + } + } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/TypePrecedence.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/TypePrecedence.kt new file mode 100644 index 000000000..fdb4b03a2 --- /dev/null +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/builtins/TypePrecedence.kt @@ -0,0 +1,38 @@ +package org.partiql.spi.function.builtins + +import org.partiql.types.PType.Kind + +internal object TypePrecedence { + + /** + * @return the precedence of the types for the PartiQL comparator. + * @see .TYPE_PRECEDENCE + */ + internal val TYPE_PRECEDENCE: Map = listOf( + Kind.UNKNOWN, + Kind.BOOL, + Kind.TINYINT, + Kind.SMALLINT, + Kind.INTEGER, + Kind.BIGINT, + Kind.NUMERIC, + Kind.DECIMAL, + Kind.REAL, + Kind.DOUBLE, + Kind.CHAR, + Kind.VARCHAR, + Kind.STRING, + Kind.CLOB, + Kind.BLOB, + Kind.DATE, + Kind.TIME, + Kind.TIMEZ, + Kind.TIMESTAMP, + Kind.TIMESTAMPZ, + Kind.ARRAY, + Kind.BAG, + Kind.ROW, + Kind.STRUCT, + Kind.DYNAMIC + ).mapIndexed { precedence, type -> type to precedence }.toMap() +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypeFamily.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypeFamily.kt index bf07ea213..1cd6ae1cf 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypeFamily.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypeFamily.kt @@ -59,7 +59,7 @@ internal class SqlTypeFamily private constructor( ) @JvmStatic - val NUMERIC = SqlTypeFamily( + val NUMBER = SqlTypeFamily( preferred = PType.decimal(), members = setOf( Kind.TINYINT, diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypes.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypes.kt index 68818c85d..550c751f6 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypes.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/internal/SqlTypes.kt @@ -119,7 +119,7 @@ internal object SqlTypes { * ``` */ private fun areAssignableNumberTypes(input: PType, target: PType): Boolean { - return input in SqlTypeFamily.NUMERIC && target in SqlTypeFamily.NUMERIC + return input in SqlTypeFamily.NUMBER && target in SqlTypeFamily.NUMBER } /** diff --git a/partiql-types/src/main/java/org/partiql/types/PType.java b/partiql-types/src/main/java/org/partiql/types/PType.java index c6bfb59b6..094e3c769 100644 --- a/partiql-types/src/main/java/org/partiql/types/PType.java +++ b/partiql-types/src/main/java/org/partiql/types/PType.java @@ -482,11 +482,12 @@ static PType numeric() { } /** - * @return a decimal with the default precision (38) and default scale (0) + * @return a PartiQL decimal type + * @deprecated this API is experimental and is subject to modification/deletion without prior notice. */ @NotNull static PType decimal() { - return decimal(38, 0); + return new PTypeDecimal(38, 0); } /**