Skip to content

Commit

Permalink
Enable min & max to work on all the data types and across data types
Browse files Browse the repository at this point in the history
  • Loading branch information
lziq authored May 13, 2022
1 parent 0bea845 commit b633476
Show file tree
Hide file tree
Showing 8 changed files with 329 additions and 42 deletions.
51 changes: 30 additions & 21 deletions lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ import org.partiql.lang.types.toTypedOpParameter
import org.partiql.lang.util.bigDecimalOf
import org.partiql.lang.util.checkThreadInterrupted
import org.partiql.lang.util.codePointSequence
import org.partiql.lang.util.compareTo
import org.partiql.lang.util.div
import org.partiql.lang.util.drop
import org.partiql.lang.util.foldLeftProduct
Expand Down Expand Up @@ -180,8 +179,8 @@ internal class EvaluatingCompiler(
* Base class for [ExprAggregator] instances which accumulate values and perform a final computation.
*/
private inner class Accumulator(
var current: Number? = 0L,
val nextFunc: (Number?, ExprValue) -> Number,
var current: ExprValue?,
val nextFunc: (ExprValue?, ExprValue) -> ExprValue,
val valueFilter: (ExprValue) -> Boolean = { _ -> true }
) : ExprAggregator {

Expand All @@ -192,37 +191,45 @@ internal class EvaluatingCompiler(
}
}

override fun compute() = current?.exprValue() ?: valueFactory.nullValue
override fun compute() = current ?: valueFactory.nullValue
}

private fun comparisonAccumulator(cmpFunc: (Number, Number) -> Boolean): (Number?, ExprValue) -> Number =
{ curr, next ->
val nextNum = next.numberValue()
when (curr) {
null -> nextNum
else -> when {
cmpFunc(nextNum, curr) -> nextNum
else -> curr
}
private fun comparisonAccumulator(comparator: NaturalExprValueComparators): (ExprValue?, ExprValue) -> ExprValue =
{ left, right ->
when {
left == null || comparator.compare(left, right) > 0 -> right
else -> left
}
}

/** Dispatch table for built-in aggregate functions. */
private val builtinAggregates: Map<Pair<String, PartiqlAst.SetQuantifier>, ExprAggregatorFactory> =
run {
val countAccFunc: (Number?, ExprValue) -> Number = { curr, _ -> curr!! + 1L }
val sumAccFunc: (Number?, ExprValue) -> Number = { curr, next ->
curr?.let { it + next.numberValue() } ?: next.numberValue()
/**
 * Guards numeric-only aggregates: raises an evaluation error when [value] is not of a number type.
 *
 * @param funcName aggregate function name interpolated into the error message (e.g. "SUM").
 * @param value the item about to be accumulated.
 */
fun checkIsNumberType(funcName: String, value: ExprValue) {
    // Number inputs are accepted as-is; nothing else to do.
    if (value.type.isNumber) return

    errNoContext(
        message = "Aggregate function $funcName expects arguments of NUMBER type but the following value was provided: ${value.ionValue}, with type of ${value.type}",
        errorCode = ErrorCode.EVALUATOR_INVALID_ARGUMENTS_FOR_AGG_FUNCTION,
        internal = false
    )
}

val countAccFunc: (ExprValue?, ExprValue) -> ExprValue = { accumulated, _ -> (accumulated!!.longValue() + 1L).exprValue() }
val sumAccFunc: (ExprValue?, ExprValue) -> ExprValue = { accumulated, nextItem ->
checkIsNumberType("SUM", nextItem)
accumulated?.let { (it.numberValue() + nextItem.numberValue()).exprValue() } ?: nextItem
}
val minAccFunc = comparisonAccumulator { left, right -> left < right }
val maxAccFunc = comparisonAccumulator { left, right -> left > right }
val minAccFunc = comparisonAccumulator(NaturalExprValueComparators.NULLS_LAST_ASC)
val maxAccFunc = comparisonAccumulator(NaturalExprValueComparators.NULLS_LAST_DESC)
val avgAggregateGenerator = { filter: (ExprValue) -> Boolean ->
object : ExprAggregator {
var sum: Number? = null
var count = 0L

override fun next(value: ExprValue) {
if (value.isNotUnknown() && filter.invoke(value)) {
checkIsNumberType("AVG", value)
sum = sum?.let { it + value.numberValue() } ?: value.numberValue()
count++
}
Expand All @@ -237,11 +244,11 @@ internal class EvaluatingCompiler(
// each distinct ExprAggregator must get its own createUniqueExprValueFilter()
mapOf(
Pair("count", PartiqlAst.SetQuantifier.All()) to ExprAggregatorFactory.over {
Accumulator(0L, countAccFunc, allFilter)
Accumulator((0L).exprValue(), countAccFunc, allFilter)
},

Pair("count", PartiqlAst.SetQuantifier.Distinct()) to ExprAggregatorFactory.over {
Accumulator(0L, countAccFunc, createUniqueExprValueFilter())
Accumulator((0L).exprValue(), countAccFunc, createUniqueExprValueFilter())
},

Pair("sum", PartiqlAst.SetQuantifier.All()) to ExprAggregatorFactory.over {
Expand Down Expand Up @@ -2189,7 +2196,9 @@ internal class EvaluatingCompiler(
thunkFactory.thunkEnv(metas) { env ->
val aggregator = aggFactory.create()
val argValue = argThunk(env)
argValue.forEach { aggregator.next(it) }
argValue.forEach {
aggregator.next(it)
}
aggregator.compute()
}
ExpressionContext.SELECT_LIST -> {
Expand Down
4 changes: 2 additions & 2 deletions lang/src/org/partiql/lang/eval/ExprValueExtensions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,10 @@ fun ExprValue.timestampValue(): Timestamp =
scalar.timestampValue() ?: errNoContext("Expected timestamp: $ionValue", errorCode = ErrorCode.EVALUATOR_UNEXPECTED_VALUE_TYPE, internal = false)

fun ExprValue.stringValue(): String =
scalar.stringValue() ?: errNoContext("Expected text: $ionValue", errorCode = ErrorCode.EVALUATOR_UNEXPECTED_VALUE_TYPE, internal = false)
scalar.stringValue() ?: errNoContext("Expected string: $ionValue", errorCode = ErrorCode.EVALUATOR_UNEXPECTED_VALUE_TYPE, internal = false)

/**
 * Returns the [ByteArray] held by this value's scalar view.
 *
 * When `scalar.bytesValue()` yields null, an error with
 * [ErrorCode.EVALUATOR_UNEXPECTED_VALUE_TYPE] is raised via [errNoContext]. The message names
 * the expected kind as "LOB" (not "boolean") so it describes what this accessor actually requires.
 */
fun ExprValue.bytesValue(): ByteArray =
    scalar.bytesValue() ?: errNoContext("Expected LOB: $ionValue", errorCode = ErrorCode.EVALUATOR_UNEXPECTED_VALUE_TYPE, internal = false)

internal fun ExprValue.dateTimePartValue(): DateTimePart =
try {
Expand Down
19 changes: 0 additions & 19 deletions lang/test/org/partiql/lang/eval/SimpleEvaluatingCompilerTests.kt
Original file line number Diff line number Diff line change
Expand Up @@ -113,23 +113,4 @@ class SimpleEvaluatingCompilerTests : EvaluatorTestBase() {
propertyValueMapOf(1, 1, Property.CAST_FROM to "SYMBOL", Property.CAST_TO to "INT"),
expectedPermissiveModeResult = "MISSING"
)

@Test
fun sum() {
runEvaluatorTestCase("SUM(`[1, 2, 3]`)", expectedResult = "6")
runEvaluatorTestCase("SUM(`[1, 2e0, 3e0]`)", expectedResult = "6e0")
runEvaluatorTestCase("SUM(`[1, 2d0, 3d0]`)", expectedResult = "6d0")
runEvaluatorTestCase("SUM(`[1, 2e0, 3d0]`)", expectedResult = "6d0")
runEvaluatorTestCase("SUM(`[1, 2d0, 3e0]`)", expectedResult = "6d0")
}

@Test
fun max() {
runEvaluatorTestCase("max(`[1, 2, 3]`)", expectedResult = "3")
runEvaluatorTestCase("max(`[1, 2.0, 3]`)", expectedResult = "3")
runEvaluatorTestCase("max(`[1, 2e0, 3e0]`)", expectedResult = "3e0")
runEvaluatorTestCase("max(`[1, 2d0, 3d0]`)", expectedResult = "3d0")
runEvaluatorTestCase("max(`[1, 2e0, 3d0]`)", expectedResult = "3d0")
runEvaluatorTestCase("max(`[1, 2d0, 3e0]`)", expectedResult = "3e0")
}
}
42 changes: 42 additions & 0 deletions lang/test/org/partiql/lang/eval/builtins/aggfunctions/AvgTests.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package org.partiql.lang.eval.builtins.aggfunctions

import org.junit.Test
import org.partiql.lang.errors.ErrorCode
import org.partiql.lang.eval.EvaluatorTestBase

/**
 * Evaluator test cases for the `AVG` aggregate function: inputs consisting only of
 * NULL/MISSING, integer input, mixed Ion numeric notations, and integer overflow while summing.
 */
class AvgTests : EvaluatorTestBase() {

    // ---- inputs made up solely of NULL / MISSING ----

    @Test
    fun avgNull() {
        runEvaluatorTestCase("AVG([null, null])", expectedResult = "null")
    }

    @Test
    fun avgMissing() {
        runEvaluatorTestCase("AVG([missing, missing])", expectedResult = "null")
    }

    // ---- numeric inputs, including mixed Ion numeric notations ----

    @Test
    fun avgInt() {
        runEvaluatorTestCase("AVG(`[1, 2, 3]`)", expectedResult = "2.")
    }

    @Test
    fun avgMixed0() {
        runEvaluatorTestCase("AVG(`[1, 2e0, 3e0]`)", expectedResult = "2.")
    }

    @Test
    fun avgMixed1() {
        runEvaluatorTestCase("AVG(`[1, 2d0, 3d0]`)", expectedResult = "2.")
    }

    @Test
    fun avgMixed2() {
        runEvaluatorTestCase("AVG(`[1, 2e0, 3d0]`)", expectedResult = "2.")
    }

    @Test
    fun avgMixed3() {
        runEvaluatorTestCase("AVG(`[1, 2d0, 3e0]`)", expectedResult = "2.")
    }

    // ---- summing past the 64-bit integer bounds ----

    @Test
    fun avgOverflow() {
        runEvaluatorErrorTestCase(
            "AVG([1, 9223372036854775807])",
            ErrorCode.EVALUATOR_INTEGER_OVERFLOW,
            expectedPermissiveModeResult = "MISSING"
        )
    }

    @Test
    fun avgUnderflow() {
        runEvaluatorErrorTestCase(
            "AVG([-1, -9223372036854775808])",
            ErrorCode.EVALUATOR_INTEGER_OVERFLOW,
            expectedPermissiveModeResult = "MISSING"
        )
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package org.partiql.lang.eval.builtins.aggfunctions

import org.junit.Test
import org.partiql.lang.eval.EvaluatorTestBase

/**
 * Evaluator test cases for the `COUNT` aggregate function.
 *
 * Each case feeds a literal collection to `COUNT` and checks the resulting integer. The cases
 * cover an empty input, NULL/MISSING items (expected to be left out of the count), one case per
 * value kind, and heterogeneous inputs.
 */
class CountTests : EvaluatorTestBase() {
    @Test
    fun countEmpty() = runEvaluatorTestCase("COUNT(`[]`)", expectedResult = "0")

    @Test
    fun countNull() = runEvaluatorTestCase("COUNT([null, null])", expectedResult = "0")

    @Test
    fun countMissing() = runEvaluatorTestCase("COUNT([missing])", expectedResult = "0")

    @Test
    fun countBoolean() = runEvaluatorTestCase("COUNT(`[true, false]`)", expectedResult = "2")

    @Test
    fun countInt() = runEvaluatorTestCase("COUNT(`[1, 2, 3]`)", expectedResult = "3")

    // Uses the Ion decimal notation (`d` exponent) so this case is distinct from countFloat,
    // which uses the float notation (`e` exponent).
    @Test
    fun countDecimal() = runEvaluatorTestCase("COUNT(`[1d0, 2d0, 3d0]`)", expectedResult = "3")

    @Test
    fun countFloat() = runEvaluatorTestCase("COUNT(`[1e0, 2e0, 3e0]`)", expectedResult = "3")

    @Test
    fun countString() = runEvaluatorTestCase("COUNT(`[\"1\", \"2\", \"3\"]`)", expectedResult = "3")

    @Test
    fun countTimestamp() = runEvaluatorTestCase("COUNT(`[2020-01-01T00:00:00Z, 2020-01-01T00:00:01Z]`)", expectedResult = "2")

    @Test
    fun countBlob() = runEvaluatorTestCase("COUNT(`[{{ aaaa }}, {{ aaab }}]`)", expectedResult = "2")

    @Test
    fun countClob() = runEvaluatorTestCase("COUNT(`[{{ \"aaaa\" }}, {{ \"aaab\" }}]`)", expectedResult = "2")

    @Test
    fun countSexp() = runEvaluatorTestCase("COUNT(`[(1), (2)]`)", expectedResult = "2")

    @Test
    fun countList() = runEvaluatorTestCase("COUNT(`[[1], [2]]`)", expectedResult = "2")

    @Test
    fun countBag() = runEvaluatorTestCase("COUNT([<<1>>, <<2>>])", expectedResult = "2")

    @Test
    fun countStruct() = runEvaluatorTestCase("COUNT(`[{'a':1}, {'a':2}]`)", expectedResult = "2")

    // NULL and MISSING are skipped; only the two integers are counted.
    @Test
    fun countMixed0() = runEvaluatorTestCase("COUNT([null, missing, 1, 2])", expectedResult = "2")

    @Test
    fun countMixed1() = runEvaluatorTestCase("COUNT([1, '2', true, `2020-01-01T00:00:00Z`, `{{ aaaa }}`])", expectedResult = "5")
}
78 changes: 78 additions & 0 deletions lang/test/org/partiql/lang/eval/builtins/aggfunctions/MaxTests.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package org.partiql.lang.eval.builtins.aggfunctions

import org.junit.Test
import org.partiql.lang.eval.EvaluatorTestBase

/**
 * Evaluator test cases for the `max` aggregate function across value kinds: NULL/MISSING-only
 * input, numbers in mixed Ion notations, strings, timestamps, booleans, blobs, clobs, and
 * inputs that mix different kinds of values.
 */
class MaxTests : EvaluatorTestBase() {

    // ---- inputs made up solely of NULL / MISSING ----

    @Test
    fun maxNull() {
        runEvaluatorTestCase("max([null, null])", expectedResult = "null")
    }

    @Test
    fun maxMissing() {
        runEvaluatorTestCase("max([missing, missing])", expectedResult = "null")
    }

    // ---- numbers ----

    @Test
    fun maxNumber0() {
        runEvaluatorTestCase("max(`[1, 2, 3]`)", expectedResult = "3")
    }

    @Test
    fun maxNumber1() {
        runEvaluatorTestCase("max(`[1, 2.0, 3]`)", expectedResult = "3")
    }

    @Test
    fun maxNumber2() {
        runEvaluatorTestCase("max(`[1, 2e0, 3e0]`)", expectedResult = "3e0")
    }

    @Test
    fun maxNumber3() {
        runEvaluatorTestCase("max(`[1, 2d0, 3d0]`)", expectedResult = "3d0")
    }

    @Test
    fun maxNumber4() {
        runEvaluatorTestCase("max(`[1, 2e0, 3d0]`)", expectedResult = "3d0")
    }

    @Test
    fun maxNumber5() {
        runEvaluatorTestCase("max(`[1, 2d0, 3e0]`)", expectedResult = "3e0")
    }

    // ---- strings ----

    @Test
    fun maxString0() {
        runEvaluatorTestCase("max(['a', 'abc', '3'])", expectedResult = "\"abc\"")
    }

    @Test
    fun maxString1() {
        runEvaluatorTestCase("max(['1', '2', '3', null])", expectedResult = "\"3\"")
    }

    // ---- timestamps: each case varies a different field (second, minute, hour, day, month, year) ----

    @Test
    fun maxTimestamp0() {
        runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, `2020-01-01T00:00:01Z`, `2020-01-01T00:00:02Z`])", expectedResult = "2020-01-01T00:00:02Z")
    }

    @Test
    fun maxTimestamp1() {
        runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, `2020-01-01T00:01:00Z`, `2020-01-01T00:02:00Z`])", expectedResult = "2020-01-01T00:02:00Z")
    }

    @Test
    fun maxTimestamp2() {
        runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, `2020-01-01T01:00:00Z`, `2020-01-01T02:00:00Z`])", expectedResult = "2020-01-01T02:00:00Z")
    }

    @Test
    fun maxTimestamp3() {
        runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, `2020-01-02T00:00:00Z`, `2020-01-03T00:00:00Z`])", expectedResult = "2020-01-03T00:00:00Z")
    }

    @Test
    fun maxTimestamp4() {
        runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, `2020-02-01T00:00:00Z`, `2020-03-01T00:00:00Z`])", expectedResult = "2020-03-01T00:00:00Z")
    }

    @Test
    fun maxTimestamp5() {
        runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, `2021-01-01T00:00:00Z`, `2022-01-01T00:00:00Z`])", expectedResult = "2022-01-01T00:00:00Z")
    }

    @Test
    fun maxTimestamp6() {
        runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, `2020-01-01T00:00:01Z`, `2020-01-01T00:00:02Z`, null])", expectedResult = "2020-01-01T00:00:02Z")
    }

    // ---- booleans and LOBs ----

    @Test
    fun maxBoolean() {
        runEvaluatorTestCase("max([false, true])", expectedResult = "true")
    }

    @Test
    fun maxBlob() {
        runEvaluatorTestCase("max([`{{ aaaa }}`, `{{ aaab }}`])", expectedResult = "{{aaab}}")
    }

    @Test
    fun maxClob() {
        runEvaluatorTestCase("max([`{{\"a\"}}`, `{{\"b\"}}`])", expectedResult = "{{\"b\"}}")
    }

    // ---- inputs mixing different kinds of values ----

    @Test
    fun maxMixed0() {
        runEvaluatorTestCase("max([false, 1])", expectedResult = "1")
    }

    @Test
    fun maxMixed1() {
        runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, 1])", expectedResult = "2020-01-01T00:00:00Z")
    }

    @Test
    fun maxMixed2() {
        runEvaluatorTestCase("max([`2020-01-01T00:00:00Z`, '1'])", expectedResult = "\"1\"")
    }

    @Test
    fun maxMixed3() {
        runEvaluatorTestCase("max([`{{\"abcd\"}}`, '1'])", expectedResult = "{{\"abcd\"}}")
    }
}
Loading

0 comments on commit b633476

Please sign in to comment.