From 3e50ce140d0cc7e0adac71b05bdf3c698f3d677f Mon Sep 17 00:00:00 2001 From: XiaolongXie Date: Fri, 22 Oct 2021 15:13:25 +0800 Subject: [PATCH] Support complext expression for group by key and distinct (#262) --- .../edu/berkeley/cs/rise/opaque/Utils.scala | 33 ++++++++++++++++++- .../berkeley/cs/rise/opaque/strategies.scala | 19 +++++------ 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala index 87650985f8..dc83bc1f73 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala @@ -1065,6 +1065,27 @@ object Utils extends Logging { op(fromChildren, tree) } + def expressionTreeFold[BaseType <: TreeNode[Expression], B]( + tree: Expression, input: Seq[Attribute] + )(op: (Seq[B], Expression) => B): B = { + + /* + * When we find that the tree is a member of input, we stop the expression unfold. + */ + if (input.exists(_.semanticEquals(tree))) { + return op(Nil, tree) + } + + if (tree.isInstanceOf[Alias]) { + if (input.exists(_.semanticEquals(tree.asInstanceOf[Alias].toAttribute))) { + return op(Nil, tree) + } + } + + val fromChildren: Seq[B] = tree.children.map(c => expressionTreeFold(c, input)(op)) + op(fromChildren, tree) + } + def getColType(dataType: DataType) = { dataType match { case IntegerType => tuix.ColType.IntegerType @@ -1083,8 +1104,18 @@ object Utils extends Logging { expr: Expression, input: Seq[Attribute] ): Int = { - treeFold[Expression, Int](expr) { (childrenOffsets, expr) => + expressionTreeFold[Expression, Int](expr, input) { (childrenOffsets, expr) => (expr, childrenOffsets) match { + case (alias_expr: Alias, Nil) if input.exists(_.semanticEquals(alias_expr.toAttribute)) => + val alias_attr = alias_expr.toAttribute + val colNum = input.indexWhere(_.semanticEquals(alias_attr)) + assert(colNum != -1) + + tuix.Expr.createExpr( + builder, + tuix.ExprUnion.Col, + tuix.Col.createCol(builder, colNum)) + case (ar: AttributeReference, Nil) if input.exists(_.semanticEquals(ar)) => val colNum = input.indexWhere(_.semanticEquals(ar)) assert(colNum != -1) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala index e83f6235ca..d6a989119f 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala @@ -238,17 +238,14 @@ object OpaqueOperators extends Strategy with JoinSelectionHelper { // We need to extract named expressions from the children of the distinct aggregate functions // in order to group by those columns. - val namedDistinctExpressions = - functionsWithDistinct.head.aggregateFunction.children.flatMap { e => - e match { - case ne: NamedExpression => - Seq(ne) - case _ => - e.children - .filter(child => child.isInstanceOf[NamedExpression]) - .map(child => child.asInstanceOf[NamedExpression]) - } - } + val namedDistinctExpressions = functionsWithDistinct.head.aggregateFunction.children.flatMap { + e => { + val leaves = e.collectLeaves() + leaves.filter(leaf => leaf.isInstanceOf[NamedExpression]) + .map(leaf => leaf.asInstanceOf[NamedExpression]) + } + } + val combinedGroupingExpressions = groupingExpressions ++ namedDistinctExpressions // 1. Create an Aggregate operator for partial aggregations.