diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt index fda25eeac..d66fe18eb 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt @@ -254,19 +254,19 @@ public object PartiQLHeader : Header() { ), ) - private fun lt(): List = types.numeric.map { t -> + private fun lt(): List = (types.numeric + types.text + BOOL).map { t -> binary("lt", BOOL, t, t) } - private fun lte(): List = types.numeric.map { t -> + private fun lte(): List = (types.numeric + types.text + BOOL).map { t -> binary("lte", BOOL, t, t) } - private fun gt(): List = types.numeric.map { t -> + private fun gt(): List = (types.numeric + types.text + BOOL).map { t -> binary("gt", BOOL, t, t) } - private fun gte(): List = types.numeric.map { t -> + private fun gte(): List = (types.numeric + types.text + BOOL).map { t -> binary("gte", BOOL, t, t) } @@ -347,7 +347,7 @@ public object PartiQLHeader : Header() { ) } - private fun between(): List = types.numeric.map { t -> + private fun between(): List = (types.numeric + types.text).map { t -> FunctionSignature.Scalar( name = "between", returns = BOOL, diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpBetweenTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpBetweenTest.kt index 59844fa4d..4167751ab 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpBetweenTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpBetweenTest.kt @@ -24,6 +24,10 @@ class OpBetweenTest : PartiQLTyperTestBase() { allNumberType + listOf(StaticType.NULL), allNumberType + listOf(StaticType.NULL), allNumberType + listOf(StaticType.NULL), + ) + cartesianProduct( + StaticType.TEXT.allTypes + listOf(StaticType.CLOB, StaticType.NULL), + StaticType.TEXT.allTypes + listOf(StaticType.CLOB, StaticType.NULL), + StaticType.TEXT.allTypes + listOf(StaticType.CLOB, StaticType.NULL) ) val failureArgs = cartesianProduct( diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpComparisonTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpComparisonTest.kt index 5ad77f979..69c34eed7 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpComparisonTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/predicate/OpComparisonTest.kt @@ -3,10 +3,8 @@ package org.partiql.planner.internal.typer.predicate import org.junit.jupiter.api.DynamicContainer import org.junit.jupiter.api.TestFactory import org.partiql.planner.internal.typer.PartiQLTyperTestBase -import org.partiql.planner.util.CastType import org.partiql.planner.util.allSupportedType import org.partiql.planner.util.cartesianProduct -import org.partiql.planner.util.castTable import org.partiql.types.StaticType import java.util.stream.Stream @@ -59,6 +57,12 @@ class OpComparisonTest : PartiQLTyperTestBase() { cartesianProduct( StaticType.NUMERIC.allTypes + listOf(StaticType.NULL), StaticType.NUMERIC.allTypes + listOf(StaticType.NULL) + ) + cartesianProduct( + StaticType.TEXT.allTypes + listOf(StaticType.CLOB, StaticType.NULL), + StaticType.TEXT.allTypes + listOf(StaticType.CLOB, StaticType.NULL) + ) + cartesianProduct( + listOf(StaticType.BOOL, StaticType.NULL), + listOf(StaticType.BOOL, StaticType.NULL) ) val failureArgs = cartesianProduct( allSupportedType, @@ -68,24 +72,10 @@ class OpComparisonTest : PartiQLTyperTestBase() { }.toSet() successArgs.forEach { args: List -> - val arg0 = args.first() - val arg1 = args[1] - if (args.contains(StaticType.MISSING)) { - (this[TestResult.Success(StaticType.MISSING)] ?: setOf(args)).let { - put(TestResult.Success(StaticType.MISSING), it + setOf(args)) - } - } else if (args.contains(StaticType.NULL)) { + if (args.contains(StaticType.NULL)) { (this[TestResult.Success(StaticType.NULL)] ?: setOf(args)).let { put(TestResult.Success(StaticType.NULL), it + setOf(args)) } - } else if (arg0 == arg1) { - (this[TestResult.Success(StaticType.BOOL)] ?: setOf(args)).let { - put(TestResult.Success(StaticType.BOOL), it + setOf(args)) - } - } else if (castTable(arg1, arg0) == CastType.COERCION) { - (this[TestResult.Success(StaticType.BOOL)] ?: setOf(args)).let { - put(TestResult.Success(StaticType.BOOL), it + setOf(args)) - } } else { (this[TestResult.Success(StaticType.BOOL)] ?: setOf(args)).let { put(TestResult.Success(StaticType.BOOL), it + setOf(args))