Skip to content

Commit

Permalink
Addresses PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
johnedquinn committed Oct 12, 2023
1 parent daafef5 commit 5aa2df0
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,19 @@ import org.partiql.ast.From
import org.partiql.ast.Identifier
import org.partiql.ast.Select
import org.partiql.ast.Statement
import org.partiql.ast.builder.AstBuilder
import org.partiql.ast.builder.ast
import org.partiql.ast.exprCall
import org.partiql.ast.exprCase
import org.partiql.ast.exprCaseBranch
import org.partiql.ast.exprIsType
import org.partiql.ast.exprLit
import org.partiql.ast.exprStruct
import org.partiql.ast.exprStructField
import org.partiql.ast.exprVar
import org.partiql.ast.helpers.toBinder
import org.partiql.ast.identifierSymbol
import org.partiql.ast.selectProjectItemExpression
import org.partiql.ast.selectValue
import org.partiql.ast.typeStruct
import org.partiql.ast.util.AstRewriter
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.stringValue
Expand Down Expand Up @@ -103,30 +113,30 @@ internal object NormalizeSelect : AstPass {
*/
private val col = { index: Int -> "_${index + 1}" }

override fun visitExprSFW(node: Expr.SFW, ctx: Int) = ast {
override fun visitExprSFW(node: Expr.SFW, ctx: Int): Expr.SFW {
val sfw = super.visitExprSFW(node, ctx) as Expr.SFW
when (val select = sfw.select) {
return when (val select = sfw.select) {
is Select.Star -> sfw.copy(select = visitSelectAll(select, sfw.from))
else -> sfw
}
}

override fun visitSelectProject(node: Select.Project, ctx: Int): AstNode = ast {
override fun visitSelectProject(node: Select.Project, ctx: Int): AstNode {
val visitedNode = super.visitSelectProject(node, ctx) as? Select.Project
?: error("VisitSelectProject should have returned a Select.Project")
return@ast when (node.items.any { it is Select.Project.Item.All }) {
return when (node.items.any { it is Select.Project.Item.All }) {
false -> visitSelectProjectWithoutProjectAll(visitedNode)
true -> visitSelectProjectWithProjectAll(visitedNode)
}
}

override fun visitSelectProjectItemExpression(node: Select.Project.Item.Expression, ctx: Int) = ast {
override fun visitSelectProjectItemExpression(node: Select.Project.Item.Expression, ctx: Int): Select.Project.Item.Expression {
val expr = visitExpr(node.expr, 0) as Expr
val alias = when (node.asAlias) {
null -> expr.toBinder(ctx)
else -> node.asAlias
}
if (expr != node.expr || alias != node.asAlias) {
return if (expr != node.expr || alias != node.asAlias) {
selectProjectItemExpression(expr, alias)
} else {
node
Expand All @@ -141,7 +151,7 @@ internal object NormalizeSelect : AstPass {
*
* Note: We assume that [select] and [from] have already been visited.
*/
private fun visitSelectAll(select: Select.Star, from: From): Select.Value = ast {
private fun visitSelectAll(select: Select.Star, from: From): Select.Value {
val tupleUnionArgs = from.aliases().flatMapIndexed { i, binding ->
val asAlias = binding.first
val atAlias = binding.second
Expand All @@ -160,16 +170,16 @@ internal object NormalizeSelect : AstPass {
byAliasItem
)
}
selectValue {
constructor = exprCall {
function = identifierSymbol("TUPLEUNION", Identifier.CaseSensitivity.SENSITIVE)
args.addAll(tupleUnionArgs)
}
return selectValue(
constructor = exprCall(
function = identifierSymbol("TUPLEUNION", Identifier.CaseSensitivity.SENSITIVE),
args = tupleUnionArgs
),
setq = select.setq
}
)
}

private fun visitSelectProjectWithProjectAll(node: Select.Project): AstNode = ast {
private fun visitSelectProjectWithProjectAll(node: Select.Project): AstNode {
val tupleUnionArgs = node.items.mapIndexed { index, item ->
when (item) {
is Select.Project.Item.All -> buildCaseWhenStruct(item.expr, index)
Expand All @@ -180,68 +190,53 @@ internal object NormalizeSelect : AstPass {
)
}
}
selectValue {
setq = node.setq
constructor = exprCall {
function = identifierSymbol("TUPLEUNION", Identifier.CaseSensitivity.SENSITIVE)
args.addAll(tupleUnionArgs)
}
}
return selectValue(
setq = node.setq,
constructor = exprCall(
function = identifierSymbol("TUPLEUNION", Identifier.CaseSensitivity.SENSITIVE),
args = tupleUnionArgs
)
)
}

@OptIn(PartiQLValueExperimental::class)
private fun visitSelectProjectWithoutProjectAll(node: Select.Project): AstNode = ast {
private fun visitSelectProjectWithoutProjectAll(node: Select.Project): AstNode {
val structFields = node.items.map { item ->
val itemExpr = item as? Select.Project.Item.Expression ?: error("Expected the projection to be an expression.")
exprStructField(
name = exprLit(stringValue(itemExpr.asAlias?.symbol!!)),
value = item.expr
)
}
selectValue {
setq = node.setq
constructor = exprStruct {
fields.addAll(structFields)
}
}
return selectValue(
setq = node.setq,
constructor = exprStruct(
fields = structFields
)
)
}

@OptIn(PartiQLValueExperimental::class)
private fun buildCaseWhenStruct(expr: Expr, index: Int): Expr.Case {
return ast {
exprCase {
branches.add(
exprCaseBranch(
condition = exprIsType(expr, typeStruct()),
expr = expr
)
)
default = buildSimpleStruct(expr, col(index))
exprStruct {
fields.add(
exprStructField(
name = exprLit(stringValue(index.toString())),
value = expr
)
)
}
}
}
}
private fun buildCaseWhenStruct(expr: Expr, index: Int): Expr.Case = exprCase(
expr = null,
branches = listOf(
exprCaseBranch(
condition = exprIsType(expr, typeStruct(), null),
expr = expr
)
),
default = buildSimpleStruct(expr, col(index))
)

@OptIn(PartiQLValueExperimental::class)
private fun buildSimpleStruct(expr: Expr, name: String): Expr.Struct {
return ast {
exprStruct {
fields.add(
exprStructField(
name = exprLit(stringValue(name)),
value = expr
)
)
}
}
}
private fun buildSimpleStruct(expr: Expr, name: String): Expr.Struct = exprStruct(
fields = listOf(
exprStructField(
name = exprLit(stringValue(name)),
value = expr
)
)
)

private fun From.aliases(): List<Triple<String, String?, String?>> = when (this) {
is From.Join -> lhs.aliases() + rhs.aliases()
Expand All @@ -254,19 +249,19 @@ internal object NormalizeSelect : AstPass {
}

// t -> t.* AS _i
private fun String.star(i: Int) = ast {
val expr = exprVar(id(this@star), Expr.Var.Scope.DEFAULT)
private fun String.star(i: Int): Select.Project.Item.Expression {
val expr = exprVar(id(this), Expr.Var.Scope.DEFAULT)
val alias = expr.toBinder(i)
selectProjectItemExpression(expr, alias)
return selectProjectItemExpression(expr, alias)
}

// t -> t AS t
private fun String.simple() = ast {
val expr = exprVar(id(this@simple), Expr.Var.Scope.DEFAULT)
val alias = id(this@simple)
selectProjectItemExpression(expr, alias)
private fun String.simple(): Select.Project.Item.Expression {
val expr = exprVar(id(this), Expr.Var.Scope.DEFAULT)
val alias = id(this)
return selectProjectItemExpression(expr, alias)
}

private fun AstBuilder.id(symbol: String) = identifierSymbol(symbol, Identifier.CaseSensitivity.INSENSITIVE)
private fun id(symbol: String) = identifierSymbol(symbol, Identifier.CaseSensitivity.INSENSITIVE)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ internal object RelConverter {
}
// SELECT VALUE ... FROM
is Select.Value -> {
val projectionOp = rel.op as? Rel.Op.Project ?: error("SELECT VALUE should have a PROJECT underneath")
assert(projectionOp.projections.size == 1) {
"Expected SELECT VALUE projection to have a single binding. However, it looked like ${projectionOp.projections}"
assert(rel.type.schema.size == 1) {
"Expected SELECT VALUE's input to have a single binding. " +
"However, it contained: ${rel.type.schema.map { it.name }}."
}
val constructor = rex(StaticType.ANY, rexOpVarResolved(0))
val op = rexOpSelect(constructor, rel)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,9 @@ internal class PlanTyper(
* ELSE { 'a': a }
* END
* ```
* When we type the above, we can't just assume
* When we type the above, if we know that `a` can be many different types (one of them being a struct),
* then when we see the top-level `a IS STRUCT`, then we can assume that the `a` on the RHS is definitely a
* struct. We handle this by using [handleSmartCasts].
*/
override fun visitRexOpCaseBranch(node: Rex.Op.Case.Branch, ctx: StaticType?): Rex.Op.Case.Branch {
val visitedCondition = visitRex(node.condition, node.condition.type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package org.partiql.planner.typer

import org.partiql.plan.Rex
// import org.partiql.plan.rex
import org.partiql.plan.util.PlanRewriter

/**
Expand Down

0 comments on commit 5aa2df0

Please sign in to comment.