Skip to content

Commit

Permalink
[v1] Add defaults for AstVisitor; port partiql-ast normalization pass…
Browse files Browse the repository at this point in the history
…es to partiql-planner
  • Loading branch information
alancai98 committed Oct 25, 2024
1 parent 8751c18 commit 3688156
Show file tree
Hide file tree
Showing 12 changed files with 675 additions and 195 deletions.
194 changes: 98 additions & 96 deletions partiql-ast/api/partiql-ast.api

Large diffs are not rendered by default.

396 changes: 297 additions & 99 deletions partiql-ast/src/main/java/org/partiql/ast/v1/AstVisitor.java

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import org.partiql.ast.builder.ast
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.StringValue

// TODO DELETE FILE

private val col = { index: () -> Int -> "_${index()}" }

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package org.partiql.ast.normalize

import org.partiql.ast.Statement

// TODO DELETE FILE

/**
* Wraps a rewriter with a default entry point.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package org.partiql.ast.normalize

import org.partiql.ast.Statement

// TODO DELETE FILE

/**
* AST normalization
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import org.partiql.ast.fromJoin
import org.partiql.ast.helpers.toBinder
import org.partiql.ast.util.AstRewriter

// TODO DELETE FILE

/**
* Assign aliases to any FROM source which does not have one.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import org.partiql.ast.groupByKey
import org.partiql.ast.helpers.toBinder
import org.partiql.ast.util.AstRewriter

// TODO DELETE FILE

/**
* Adds a unique binder to each group key.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package org.partiql.planner.internal.helpers

import org.partiql.ast.v1.Ast.identifier
import org.partiql.ast.v1.Identifier
import org.partiql.ast.v1.IdentifierChain
import org.partiql.ast.v1.expr.Expr
import org.partiql.ast.v1.expr.ExprCast
import org.partiql.ast.v1.expr.ExprLit
import org.partiql.ast.v1.expr.ExprPath
import org.partiql.ast.v1.expr.ExprSessionAttribute
import org.partiql.ast.v1.expr.ExprVarRef
import org.partiql.ast.v1.expr.PathStep
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.StringValue

private val col = { index: () -> Int -> "_${index()}" }

/**
* Produces a "binder" (AS alias) for an expression following the given rules:
*
* 1. If item is an id, use the last symbol
* 2. If item is a path with a final symbol step, use the symbol — else 4
* 3. If item is a cast, use the value name
* 4. Else, use item index with prefix _
*
* See https://github.com/partiql/partiql-lang-kotlin/issues/1122
*/
internal fun Expr.toBinder(index: () -> Int): Identifier = when (this) {
is ExprVarRef -> this.identifierChain.toBinder()
is ExprPath -> this.toBinder(index)
is ExprCast -> this.value.toBinder(index)
is ExprSessionAttribute -> this.sessionAttribute.name().uppercase().toBinder()
else -> col(index).toBinder()
}

/**
* Simple toBinder that uses an int literal rather than a closure.
*
* @param index
* @return
*/
internal fun Expr.toBinder(index: Int): Identifier = toBinder { index }

private fun String.toBinder(): Identifier =
// Every binder preserves case
identifier(this@toBinder, true)

private fun IdentifierChain.toBinder(): Identifier {
if (next == null) return root.symbol.toBinder()
var cur = next
var prev = cur
while (cur != null) {
prev = cur
cur = cur.next
}
return prev!!.root.symbol.toBinder()
}

private fun Identifier.toBinder(): Identifier = symbol.toBinder()

@OptIn(PartiQLValueExperimental::class)
private fun ExprPath.toBinder(index: () -> Int): Identifier {
if (next == null) return root.toBinder(index)
var cur = next
var prev = next
while (cur != null) {
prev = cur
cur = cur.next
}
return when (prev) {
is PathStep.Field -> prev.field.toBinder()
is PathStep.Element -> {
val k = prev.element
if (k is ExprLit && k.value is StringValue) {
(k.value as StringValue).value!!.toBinder()
} else {
col(index).toBinder()
}
}
else -> col(index).toBinder()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at:
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
* language governing permissions and limitations under the License.
*/

package org.partiql.planner.internal.normalize

import org.partiql.ast.v1.Statement

/**
* Wraps a rewriter with a default entry point.
*/
internal interface AstPass {
fun apply(statement: Statement): Statement
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
@file:JvmName("Normalize")
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at:
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
* language governing permissions and limitations under the License.
*/

package org.partiql.planner.internal.normalize

import org.partiql.ast.v1.Statement

/**
* AST normalization
*/
internal fun Statement.normalize(): Statement {
// could be a fold, but this is nice for setting breakpoints
var ast = this
ast = NormalizeFromSource.apply(ast)
ast = NormalizeGroupBy.apply(ast)
return ast
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at:
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
* language governing permissions and limitations under the License.
*/

package org.partiql.planner.internal.normalize

import org.partiql.ast.v1.Ast.fromExpr
import org.partiql.ast.v1.Ast.fromJoin
import org.partiql.ast.v1.AstNode
import org.partiql.ast.v1.AstVisitor
import org.partiql.ast.v1.From
import org.partiql.ast.v1.FromExpr
import org.partiql.ast.v1.FromJoin
import org.partiql.ast.v1.FromTableRef
import org.partiql.ast.v1.FromType
import org.partiql.ast.v1.QueryBody
import org.partiql.ast.v1.Statement
import org.partiql.ast.v1.expr.Expr
import org.partiql.planner.internal.helpers.toBinder

/**
* Assign aliases to any FROM source which does not have one.
*/
internal object NormalizeFromSource : AstPass {

override fun apply(statement: Statement): Statement = Visitor.visitStatement(statement, 0) as Statement

private object Visitor : AstVisitor<AstNode, Int> {

// Each SFW starts the ctx count again.
override fun visitQueryBodySFW(node: QueryBody.SFW, ctx: Int): AstNode = super.visitQueryBodySFW(node, 0)

override fun visitFrom(node: From, ctx: Int) = super.visitFrom(node, ctx) as From

override fun visitFromJoin(node: FromJoin, ctx: Int): FromJoin {
val lhs = visitTableRef(node.lhs, ctx) as FromTableRef
val rhs = visitTableRef(node.rhs, ctx + 1) as FromTableRef
val condition = node.condition?.let { visitExpr(it, ctx) as Expr }
return if (lhs !== node.lhs || rhs !== node.rhs || condition !== node.condition) {
fromJoin(lhs, rhs, node.joinType, condition)
} else {
node
}
}

override fun visitFromExpr(node: FromExpr, ctx: Int): FromExpr {
val expr = visitExpr(node.expr, ctx) as Expr
var i = ctx
var asAlias = node.asAlias
var atAlias = node.atAlias
// derive AS alias
if (asAlias == null) {
asAlias = expr.toBinder(i++)
}
// derive AT binder
if (atAlias == null && node.fromType == FromType.UNPIVOT()) {
atAlias = expr.toBinder(i++)
}
return if (expr !== node.expr || asAlias !== node.asAlias || atAlias !== node.atAlias) {
fromExpr(expr = expr, fromType = node.fromType, asAlias = asAlias, atAlias = atAlias)
} else {
node
}
}

override fun defaultReturn(node: AstNode, ctx: Int) = node
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at:
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
* language governing permissions and limitations under the License.
*/

package org.partiql.planner.internal.normalize

import org.partiql.ast.v1.Ast.groupBy
import org.partiql.ast.v1.Ast.groupByKey
import org.partiql.ast.v1.AstNode
import org.partiql.ast.v1.AstVisitor
import org.partiql.ast.v1.GroupBy
import org.partiql.ast.v1.Statement
import org.partiql.ast.v1.expr.Expr
import org.partiql.planner.internal.helpers.toBinder

/**
* Adds a unique binder to each group key.
*/
internal object NormalizeGroupBy : AstPass {

override fun apply(statement: Statement) = Visitor.visitStatement(statement, 0) as Statement

private object Visitor : AstVisitor<AstNode, Int> {

override fun visitGroupBy(node: GroupBy, ctx: Int): AstNode {
val keys = node.keys.mapIndexed { index, key ->
visitGroupByKey(key, index + 1)
}
return groupBy(strategy = node.strategy, keys = keys, asAlias = node.asAlias)
}

override fun visitGroupByKey(node: GroupBy.Key, ctx: Int): GroupBy.Key {
val expr = visitExpr(node.expr, 0) as Expr
val alias = when (node.asAlias) {
null -> expr.toBinder(ctx)
else -> node.asAlias
}
return if (expr !== node.expr || alias !== node.asAlias) {
groupByKey(expr, alias)
} else {
node
}
}

override fun defaultReturn(node: AstNode, ctx: Int) = node
}
}

0 comments on commit 3688156

Please sign in to comment.