From 912010155b982b95894ffa8cd1b1ed0ed8b225bb Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 15 Oct 2024 09:27:56 +0800 Subject: [PATCH] [GLUTEN-7526][VL] Scala code style for VeloxCollect (#7527) Closes #7526 --- .../expression/aggregate/VeloxCollect.scala | 37 +++++++++++-------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxCollect.scala b/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxCollect.scala index c12aeab26e70..c35020fab347 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxCollect.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxCollect.scala @@ -21,50 +21,55 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.types.{ArrayType, DataType} -abstract class VeloxCollect extends DeclarativeAggregate with UnaryLike[Expression] { +abstract class VeloxCollect(child: Expression) + extends DeclarativeAggregate + with UnaryLike[Expression] { + protected lazy val buffer: AttributeReference = AttributeReference("buffer", dataType)() override def dataType: DataType = ArrayType(child.dataType, false) - override def aggBufferAttributes: Seq[AttributeReference] = List(buffer) + override def aggBufferAttributes: Seq[AttributeReference] = Seq(buffer) - override lazy val initialValues: Seq[Expression] = List(Literal.create(Seq.empty, dataType)) + override lazy val initialValues: Seq[Expression] = Seq(Literal.create(Array(), dataType)) - override lazy val updateExpressions: Seq[Expression] = List( + override lazy val updateExpressions: Seq[Expression] = Seq( If( IsNull(child), buffer, - Concat(List(buffer, CreateArray(List(child), useStringTypeWhenEmpty = false)))) + Concat(Seq(buffer, CreateArray(Seq(child), useStringTypeWhenEmpty = false)))) ) - override lazy val mergeExpressions: Seq[Expression] = List( - Concat(List(buffer.left, buffer.right)) + override lazy val mergeExpressions: Seq[Expression] = Seq( + Concat(Seq(buffer.left, buffer.right)) ) override def defaultResult: Option[Literal] = Option(Literal.create(Array(), dataType)) } -case class VeloxCollectSet(override val child: Expression) extends VeloxCollect { - override def prettyName: String = "velox_collect_set" +case class VeloxCollectSet(child: Expression) extends VeloxCollect(child) { // Velox's collect_set implementation allows null output. Thus we usually wrap // the function to enforce non-null output. See CollectRewriteRule#ensureNonNull. override def nullable: Boolean = true - override protected def withNewChildInternal(newChild: Expression): Expression = - copy(child = newChild) - override lazy val evaluateExpression: Expression = ArrayDistinct(buffer) + + override def prettyName: String = "velox_collect_set" + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) } -case class VeloxCollectList(override val child: Expression) extends VeloxCollect { - override def prettyName: String = "velox_collect_list" +case class VeloxCollectList(child: Expression) extends VeloxCollect(child) { override def nullable: Boolean = false + override val evaluateExpression: Expression = buffer + + override def prettyName: String = "velox_collect_list" + override protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild) - - override val evaluateExpression: Expression = buffer }