Skip to content

Commit

Permalink
Aggregations without GROUP BY keys
Browse files Browse the repository at this point in the history
  • Loading branch information
RCHowell committed Oct 11, 2023
1 parent ed50800 commit 5d422da
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 86 deletions.
54 changes: 46 additions & 8 deletions partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,13 @@ internal class Header(
returns = BOOL,
parameters = listOf(FunctionParameter("value", BOOL)),
isNullable = true,
)
),
FunctionSignature.Aggregation(
name = "every",
returns = BOOL,
parameters = listOf(FunctionParameter("value", ANY)),
isNullable = true,
),
)

private fun any() = listOf(
Expand All @@ -724,7 +730,13 @@ internal class Header(
returns = BOOL,
parameters = listOf(FunctionParameter("value", BOOL)),
isNullable = true,
)
),
FunctionSignature.Aggregation(
name = "any",
returns = BOOL,
parameters = listOf(FunctionParameter("value", ANY)),
isNullable = true,
),
)

private fun some() = listOf(
Expand All @@ -733,7 +745,13 @@ internal class Header(
returns = BOOL,
parameters = listOf(FunctionParameter("value", BOOL)),
isNullable = true,
)
),
FunctionSignature.Aggregation(
name = "some",
returns = BOOL,
parameters = listOf(FunctionParameter("value", ANY)),
isNullable = true,
),
)

private fun count() = listOf(
Expand All @@ -742,7 +760,7 @@ internal class Header(
returns = INT,
parameters = listOf(FunctionParameter("value", ANY)),
isNullable = false,
)
),
)

private fun min() = numericTypes.map {
Expand All @@ -752,7 +770,12 @@ internal class Header(
parameters = listOf(FunctionParameter("value", it)),
isNullable = true,
)
}
} + FunctionSignature.Aggregation(
name = "min",
returns = ANY,
parameters = listOf(FunctionParameter("value", ANY)),
isNullable = true,
)

private fun max() = numericTypes.map {
FunctionSignature.Aggregation(
Expand All @@ -761,7 +784,12 @@ internal class Header(
parameters = listOf(FunctionParameter("value", it)),
isNullable = true,
)
}
} + FunctionSignature.Aggregation(
name = "max",
returns = ANY,
parameters = listOf(FunctionParameter("value", ANY)),
isNullable = true,
)

private fun sum() = numericTypes.map {
FunctionSignature.Aggregation(
Expand All @@ -770,7 +798,12 @@ internal class Header(
parameters = listOf(FunctionParameter("value", it)),
isNullable = true,
)
}
} + FunctionSignature.Aggregation(
name = "sum",
returns = ANY,
parameters = listOf(FunctionParameter("value", ANY)),
isNullable = true,
)

private fun avg() = numericTypes.map {
FunctionSignature.Aggregation(
Expand All @@ -779,7 +812,12 @@ internal class Header(
parameters = listOf(FunctionParameter("value", it)),
isNullable = true,
)
}
} + FunctionSignature.Aggregation(
name = "avg",
returns = ANY,
parameters = listOf(FunctionParameter("value", ANY)),
isNullable = true,
)

// ====================================
// SORTING
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ internal object RelConverter {
return Pair(select, input)
}

// Build the schema -> (aggs... groups...)
// Build the schema -> (calls... groups...)
val schema = mutableListOf<Rel.Binding>()
val props = emptySet<Rel.Prop>()

Expand Down
123 changes: 110 additions & 13 deletions partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,18 @@ package org.partiql.planner.typer
import org.partiql.errors.Problem
import org.partiql.errors.ProblemCallback
import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION
import org.partiql.plan.Agg
import org.partiql.plan.Fn
import org.partiql.plan.Identifier
import org.partiql.plan.Rel
import org.partiql.plan.Rex
import org.partiql.plan.Statement
import org.partiql.plan.aggResolved
import org.partiql.plan.fnResolved
import org.partiql.plan.identifierSymbol
import org.partiql.plan.rel
import org.partiql.plan.relBinding
import org.partiql.plan.relOpAggregateCall
import org.partiql.plan.relOpErr
import org.partiql.plan.relOpFilter
import org.partiql.plan.relOpJoin
Expand Down Expand Up @@ -293,19 +296,20 @@ internal class PlanTyper(

/**
* Initial implementation of `EXCLUDE` schema inference. Until an RFC is finalized for `EXCLUDE`
* (https://github.com/partiql/partiql-spec/issues/39), this behavior is considered experimental and subject to
* change.
* (https://github.com/partiql/partiql-spec/issues/39),
*
* So far this implementation includes
* This behavior is considered experimental and subject to change.
*
* This implementation includes
* - Excluding tuple bindings (e.g. t.a.b.c)
* - Excluding tuple wildcards (e.g. t.a.*.b)
* - Excluding collection indexes (e.g. t.a[0].b -- behavior subject to change; see below discussion)
* - Excluding collection wildcards (e.g. t.a[*].b)
*
* There are still discussion points regarding the following edge cases:
* - EXCLUDE on a tuple bindingibute that doesn't exist -- give an error/warning?
* - EXCLUDE on a tuple attribute that doesn't exist -- give an error/warning?
* - currently no error
* - EXCLUDE on a tuple bindingibute that has duplicates -- give an error/warning? exclude one? exclude both?
* - EXCLUDE on a tuple attribute that has duplicates -- give an error/warning? exclude one? exclude both?
* - currently excludes both w/ no error
* - EXCLUDE on a collection index as the last step -- mark element type as optional?
* - currently element type as-is
Expand All @@ -315,7 +319,7 @@ internal class PlanTyper(
* - currently a parser error
* - EXCLUDE on a union type -- give an error/warning? no-op? exclude on each type in union?
* - currently exclude on each union type
* - If SELECT list includes an bindingibute that is excluded, we could consider giving an error in PlanTyper or
* - If SELECT list includes an attribute that is excluded, we could consider giving an error in PlanTyper or
* some other semantic pass
* - currently does not give an error
*/
Expand All @@ -333,7 +337,33 @@ internal class PlanTyper(
}

override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: Rel.Type?): Rel {
TODO("Type RelOp Aggregate")
// compute input schema
val input = visitRel(node.input, ctx)

// type the calls and groups
val typer = RexTyper(locals = TypeEnv(input.type.schema, ResolutionStrategy.LOCAL))

// typing of aggregate calls it slightly more complicated because they are not expressions.
val calls = node.calls.mapIndexed { i, call ->
when (val agg = call.agg) {
is Agg.Resolved -> call to ctx!!.schema[i].type
is Agg.Unresolved -> typer.resolveAgg(agg, call.args)
}
}
val groups = node.groups.map { typer.visitRex(it, null) }

// Compute schema using order (calls...groups...)
val schema = mutableListOf<StaticType>()
schema += calls.map { it.second }
schema += groups.map { it.type }

// rewrite with typed calls and groups
val type = ctx!!.copyWithSchema(schema)
val op = node.copy(
calls = calls.map { it.first },
groups = groups,
)
return rel(type, op)
}

override fun visitRelBinding(node: Rel.Binding, ctx: Rel.Type?): Rel {
Expand Down Expand Up @@ -441,18 +471,15 @@ internal class PlanTyper(
// 4. Invalid path reference; always MISSING
if (type == StaticType.MISSING) {
handleAlwaysMissing()
return rex(type, rexOpErr("Unknown identifier $node"))
return rexErr("Unknown identifier $node")
}

// 5. Non-missing, root is resolved
return rex(type, rexOpPath(root, steps))
}

/**
* Typing of functions is
*
* 1. If any argument is MISSING, the function return type is MISSING
* 2. If all arguments are NULL
* Resolve and type scalar function calls.
*
* @param node
* @param ctx
Expand Down Expand Up @@ -522,7 +549,7 @@ internal class PlanTyper(
}
is FnMatch.Error -> {
handleUnknownFunction(match)
rex(StaticType.MISSING, rexOpErr("Unknown scalar function $fn"))
rexErr("Unknown scalar function $fn")
}
}
}
Expand Down Expand Up @@ -755,6 +782,74 @@ internal class PlanTyper(
false -> StaticType.ANY
}
}

/**
* Resolution and typing of aggregation function calls.
*
* I've chosen to place this in RexTyper because all arguments will be typed using the same locals.
* There's no need to create new RexTyper instances for each argument. There is no reason to limit aggregations
* to a single argument (covar, corr, pct, etc.) but in practice we typically only have single <value expression>.
*
* This method is _very_ similar to scalar function resolution, so it is temping to DRY these two out; but the
* separation is cleaner as the typing of NULLS is subtly different.
*
* SQL-99 6.16 General Rules on <set function specification>
* Let TX be the single-column table that is the result of applying the <value expression>
* to each row of T and eliminating null values <--- all NULL values are eliminated as inputs
*/
public fun resolveAgg(agg: Agg.Unresolved, arguments: List<Rex>): Pair<Rel.Op.Aggregate.Call, StaticType> {
var missingArg = false
val args = arguments.map {
val arg = visitRex(it, null)
if (arg.type == MissingType) missingArg = true
arg
}

//
if (missingArg) {
handleAlwaysMissing()
return relOpAggregateCall(agg, listOf(rexErr("MISSING"))) to MissingType
}

// Try to match the arguments to functions defined in the catalog
return when (val match = env.resolveAgg(agg, args)) {
is FnMatch.Ok -> {
// Found a match!
val newAgg = aggResolved(match.signature)
val newArgs = rewriteFnArgs(match.mapping, args)
val returns = newAgg.signature.returns

// Determine the nullability of the return type
var isNullable = false // True iff has a NULLABLE arg and is a NULLABLE operator
if (newAgg.signature.isNullable) {
for (arg in newArgs) {
if (arg.type.isNullable()) {
isNullable = true
break
}
}
}

// Return type with calculated nullability
var type = when {
isNullable -> returns.toStaticType()
else -> returns.toNonNullStaticType()
}

// Some operators can return MISSING during runtime
if (match.isMissable) {
type = StaticType.unionOf(type, StaticType.MISSING)
}

// Finally, rewrite this node
relOpAggregateCall(newAgg, newArgs) to type
}
is FnMatch.Error -> {
handleUnknownFunction(match)
return relOpAggregateCall(agg, listOf(rexErr("MISSING"))) to MissingType
}
}
}
}

// HELPERS
Expand All @@ -763,6 +858,8 @@ internal class PlanTyper(

private fun Rex.type(typeEnv: TypeEnv) = RexTyper(typeEnv).visitRex(this, null)

private fun rexErr(message: String) = rex(StaticType.MISSING, rexOpErr(message))

/**
* I found decorating the tree with the binding names (for resolution) was easier than associating introduced
* bindings with a node via an id->list<string> map. ONLY because right now I don't think we have a good way
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package org.partiql.planner
import com.amazon.ionelement.api.field
import com.amazon.ionelement.api.ionString
import com.amazon.ionelement.api.ionStructOf
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.DynamicContainer
import org.junit.jupiter.api.DynamicContainer.dynamicContainer
import org.junit.jupiter.api.DynamicNode
Expand All @@ -14,6 +13,7 @@ import org.junit.jupiter.api.fail
import org.partiql.errors.ProblemSeverity
import org.partiql.parser.PartiQLParserBuilder
import org.partiql.plan.Statement
import org.partiql.plan.debug.PlanPrinter
import org.partiql.planner.test.PlannerTest
import org.partiql.planner.test.PlannerTestProvider
import org.partiql.planner.test.PlannerTestSuite
Expand Down Expand Up @@ -83,7 +83,15 @@ class PlannerTestJunit {
}
val expected = test.schema.toIon()
val actual = statement.root.type.toIon()
assertEquals(expected, actual)
assert(expected == actual) {
buildString {
appendLine()
appendLine("Expect: $expected")
appendLine("Actual: $actual")
appendLine()
PlanPrinter.append(this, statement)
}
}
}
}
}
Expand Down
Loading

0 comments on commit 5d422da

Please sign in to comment.