diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 8984a9551b03..b04a6d567e85 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -45,12 +45,14 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBroadcastMode} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.python.ArrowEvalPythonExec +import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation import org.apache.spark.sql.execution.utils.ExecUtil import org.apache.spark.sql.expression.{UDFExpression, UserDefinedAggregateFunction} import org.apache.spark.sql.hive.VeloxHiveUDFTransformer import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.task.TaskResources import org.apache.commons.lang3.ClassUtils @@ -621,6 +623,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { child: SparkPlan, numOutputRows: SQLMetric, dataSize: SQLMetric): BuildSideRelation = { + val useOffheapBroadcastBuildRelation = + GlutenConfig.getConf.enableBroadcastBuildRelationInOffheap val serialized: Array[ColumnarBatchSerializeResult] = child .executeColumnar() .mapPartitions(itr => Iterator(BroadcastUtils.serializeStream(itr))) @@ -633,7 +637,13 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { } numOutputRows += serialized.map(_.getNumRows).sum dataSize += rawSize - ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized), mode) + if (useOffheapBroadcastBuildRelation) { + TaskResources.runUnsafe { + new UnsafeColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized), mode) + } + } else { + ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized), mode) + } } override def doCanonicalizeForBroadcastMode(mode: BroadcastMode): BroadcastMode = { diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala index c5323d4f8d50..491de4b220bd 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution +import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.columnarbatch.ColumnarBatches import org.apache.gluten.runtime.Runtimes @@ -27,7 +28,8 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, IdentityBroadcastMode, Partitioning} -import org.apache.spark.sql.execution.joins.{HashedRelation, HashedRelationBroadcastMode, LongHashedRelation} +import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelation, HashedRelationBroadcastMode, LongHashedRelation} +import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation import org.apache.spark.sql.types.StructType 
import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.task.TaskResources @@ -45,7 +47,7 @@ object BroadcastUtils { mode match { case HashedRelationBroadcastMode(_, _) => // ColumnarBuildSideRelation to HashedRelation. - val fromBroadcast = from.asInstanceOf[Broadcast[ColumnarBuildSideRelation]] + val fromBroadcast = from.asInstanceOf[Broadcast[BuildSideRelation]] val fromRelation = fromBroadcast.value.asReadOnlyCopy() var rowCount: Long = 0 val toRelation = TaskResources.runUnsafe { @@ -60,7 +62,7 @@ object BroadcastUtils { context.broadcast(toRelation).asInstanceOf[Broadcast[T]] case IdentityBroadcastMode => // ColumnarBuildSideRelation to HashedRelation. - val fromBroadcast = from.asInstanceOf[Broadcast[ColumnarBuildSideRelation]] + val fromBroadcast = from.asInstanceOf[Broadcast[BuildSideRelation]] val fromRelation = fromBroadcast.value.asReadOnlyCopy() val toRelation = TaskResources.runUnsafe { val rowIterator = fn(fromRelation.deserialized) @@ -91,6 +93,7 @@ object BroadcastUtils { schema: StructType, from: Broadcast[F], fn: Iterator[InternalRow] => Iterator[ColumnarBatch]): Broadcast[T] = { + val useOffheapBuildRelation = GlutenConfig.getConf.enableBroadcastBuildRelationInOffheap mode match { case HashedRelationBroadcastMode(_, _) => // HashedRelation to ColumnarBuildSideRelation. @@ -104,10 +107,17 @@ object BroadcastUtils { case result: ColumnarBatchSerializeResult => Array(result.getSerialized) } - ColumnarBuildSideRelation( - SparkShimLoader.getSparkShims.attributesFromStruct(schema), - serialized, - mode) + if (useOffheapBuildRelation) { + new UnsafeColumnarBuildSideRelation( + SparkShimLoader.getSparkShims.attributesFromStruct(schema), + serialized, + mode) + } else { + ColumnarBuildSideRelation( + SparkShimLoader.getSparkShims.attributesFromStruct(schema), + serialized, + mode) + } } // Rebroadcast Velox relation. context.broadcast(toRelation).asInstanceOf[Broadcast[T]] @@ -123,10 +133,17 @@ object BroadcastUtils { case result: ColumnarBatchSerializeResult => Array(result.getSerialized) } - ColumnarBuildSideRelation( - SparkShimLoader.getSparkShims.attributesFromStruct(schema), - serialized, - mode) + if (useOffheapBuildRelation) { + new UnsafeColumnarBuildSideRelation( + SparkShimLoader.getSparkShims.attributesFromStruct(schema), + serialized, + mode) + } else { + ColumnarBuildSideRelation( + SparkShimLoader.getSparkShims.attributesFromStruct(schema), + serialized, + mode) + } } // Rebroadcast Velox relation. context.broadcast(toRelation).asInstanceOf[Broadcast[T]] diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeBytesBufferArray.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeBytesBufferArray.scala new file mode 100644 index 000000000000..fb427cbaf37d --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeBytesBufferArray.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.unsafe
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager}
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.LongArray
+
+import java.security.MessageDigest
+
+/**
+ * Used to store a broadcast variable's data in off-heap memory. The underlying data structure is
+ * a LongArray allocated in off-heap memory.
+ *
+ * @param arraySize
+ *   length of the underlying array of byte buffers
+ * @param bytesBufferLengths
+ *   length of each byte buffer in the underlying array
+ * @param totalBytes
+ *   sum of all byte buffers' lengths
+ */
+// scalastyle:off no.finalize
+case class UnsafeBytesBufferArray(
+    arraySize: Int,
+    bytesBufferLengths: Array[Int],
+    totalBytes: Long,
+    tmm: TaskMemoryManager)
+  extends MemoryConsumer(tmm, MemoryMode.OFF_HEAP)
+  with Logging {
+
+  /**
+   * A single array that stores the contents of all byte buffers; it is initialized once, on the
+   * first write.
+   */
+  private var longArray: LongArray = _
+
+  /** Start offset of each byte buffer, relative to the underlying LongArray's base position. */
+  private val bytesBufferOffset = new Array[Int](arraySize)
+
+  {
+    assert(arraySize == bytesBufferLengths.length)
+    assert(totalBytes >= 0)
+    if (arraySize > 0) {
+      for (i <- 0 until arraySize) {
+        bytesBufferOffset(i) =
+          if (i > 0) bytesBufferLengths(i - 1) + bytesBufferOffset(i - 1) else 0
+      }
+    }
+  }
+
+  override def spill(l: Long, memoryConsumer: MemoryConsumer): Long = 0L
+
+  /**
+   * Put a bytesBuffer at the specified array index.
+   * @param index
+   * @param bytesBuffer
+   */
+  def putBytesBuffer(index: Int, bytesBuffer: Array[Byte]): Unit = this.synchronized {
+    assert(index < arraySize)
+    assert(bytesBuffer.length == bytesBufferLengths(index))
+    // First allocate the underlying long array.
+    if (null == longArray && index == 0) {
+      log.debug(s"allocate array $totalBytes, actual longArray size ${(totalBytes + 7) / 8}")
+      longArray = allocateArray((totalBytes + 7) / 8)
+    }
+    if (log.isDebugEnabled) {
+      log.debug(s"put bytesBuffer at index $index bytesBuffer's length is ${bytesBuffer.length}")
+      log.debug(
+        s"bytesBuffer at index $index " +
+          s"digest ${calculateMD5(bytesBuffer).mkString("Array(", ", ", ")")}")
+    }
+    Platform.copyMemory(
+      bytesBuffer,
+      Platform.BYTE_ARRAY_OFFSET,
+      longArray.getBaseObject,
+      longArray.getBaseOffset + bytesBufferOffset(index),
+      bytesBufferLengths(index))
+  }
+
+  /**
+   * Get the bytesBuffer at the specified index.
+   * @param index
+   * @return
+   */
+  def getBytesBuffer(index: Int): Array[Byte] = {
+    assert(index < arraySize)
+    if (null == longArray) {
+      return new Array[Byte](0)
+    }
+    val bytes = new Array[Byte](bytesBufferLengths(index))
+    log.debug(s"get bytesBuffer at index $index bytesBuffer length ${bytes.length}")
+    Platform.copyMemory(
+      longArray.getBaseObject,
+      longArray.getBaseOffset + bytesBufferOffset(index),
+      bytes,
+      Platform.BYTE_ARRAY_OFFSET,
+      bytesBufferLengths(index))
+    if (log.isDebugEnabled) {
+      log.debug(
+        s"get bytesBuffer at index $index " +
+          s"digest ${calculateMD5(bytes).mkString("Array(", ", ", ")")}")
+    }
+    bytes
+  }
+
+  /**
+   * Get the memory offset and length of the bytesBuffer at the specified index; usually used when
+   * reading memory directly from off-heap.
+   *
+   * @param index
+   * @return
+   */
+  def getBytesBufferOffsetAndLength(index: Int): (Long, Int) = {
+    assert(index < arraySize)
+    assert(longArray != null, "The broadcast data in offheap should not be null!")
+    val offset = longArray.getBaseOffset + bytesBufferOffset(index)
+    val length = bytesBufferLengths(index)
+    (offset, length)
+  }
+
+  /**
+   * Needed to free the underlying off-heap memory once the broadcast variable is garbage
+   * collected, since for now we don't have a more elegant way to release it.
+   */
+  override def finalize(): Unit = {
+    try {
+      if (longArray != null) {
+        log.debug(s"BytesArrayInOffheap finalize $arraySize")
+        freeArray(longArray)
+        longArray = null
+      }
+    } finally {
+      super.finalize()
+    }
+  }
+
+  /**
+   * Used to debug input/output bytes.
+   *
+   * @param bytesBuffer
+   * @return
+   */
+  private def calculateMD5(bytesBuffer: Array[Byte]): Array[Byte] = try {
+    val md = MessageDigest.getInstance("MD5")
+    md.digest(bytesBuffer)
+  } catch {
+    case e: Throwable =>
+      log.warn("error when calculateMD5", e)
+      new Array[Byte](0)
+  }
+}
+// scalastyle:on no.finalize
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
new file mode 100644
index 000000000000..d0b8f295d91d
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
@@ -0,0 +1,312 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.unsafe
+
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.columnarbatch.ColumnarBatches
+import org.apache.gluten.iterator.Iterators
+import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
+import org.apache.gluten.runtime.Runtimes
+import org.apache.gluten.sql.shims.SparkShimLoader
+import org.apache.gluten.utils.ArrowAbiUtil
+import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper}
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode}
+import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBroadcastMode}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.utils.SparkArrowUtil
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.task.TaskResources
+import org.apache.spark.util.Utils
+
+import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
+import com.esotericsoftware.kryo.io.{Input, Output}
+import org.apache.arrow.c.ArrowSchema
+
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
+
+import scala.collection.JavaConverters.asScalaIteratorConverter
+
+/**
+ * UnsafeColumnarBuildSideRelation is backed by off-heap memory to avoid on-heap OOM. It is almost
+ * the same as ColumnarBuildSideRelation; ColumnarBuildSideRelation should be removed once
+ * UnsafeColumnarBuildSideRelation matures.
+ *
+ * @param output
+ * @param batches
+ */
+case class UnsafeColumnarBuildSideRelation(
+    private var output: Seq[Attribute],
+    private var batches: UnsafeBytesBufferArray,
+    var mode: BroadcastMode)
+  extends BuildSideRelation
+  with Externalizable
+  with Logging
+  with KryoSerializable {
+
+  // Needed for serialization
+  def this() = {
+    this(null, null.asInstanceOf[UnsafeBytesBufferArray], null)
+  }
+
+  def this(output: Seq[Attribute], bytesBufferArray: Array[Array[Byte]], mode: BroadcastMode) = {
+    // Only used on the driver side when broadcasting the whole set of batches.
+    this(
+      output,
+      UnsafeBytesBufferArray(
+        bytesBufferArray.length,
+        bytesBufferArray.map(_.length),
+        bytesBufferArray.map(_.length.toLong).sum,
+        TaskContext.get().taskMemoryManager
+      ),
+      mode
+    )
+    val batchesSize = bytesBufferArray.length
+    for (i <- 0 until batchesSize) {
+      val length = bytesBufferArray(i).length
+      log.debug(s"this $i--- $length")
+      batches.putBytesBuffer(i, bytesBufferArray(i))
+    }
+  }
+
+  // Should only be used on the driver to serialize this relation.
+  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
+    out.writeObject(output)
+    out.writeObject(mode)
+    out.writeInt(batches.arraySize)
+    out.writeObject(batches.bytesBufferLengths)
+    out.writeLong(batches.totalBytes)
+    for (i <- 0 until batches.arraySize) {
+      val bytes = batches.getBytesBuffer(i)
+      out.write(bytes)
+      log.debug(s"writeExternal index $i with length ${bytes.length}")
+    }
+  }
+
+  // Should only be used on the driver to serialize this relation.
+  override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException {
+    kryo.writeObject(out, output.toList)
+    kryo.writeObject(out, mode)
+    out.writeInt(batches.arraySize)
+    kryo.writeObject(out, batches.bytesBufferLengths)
+
out.writeLong(batches.totalBytes) + for (i <- 0 until batches.arraySize) { + val bytes = batches.getBytesBuffer(i) + out.write(bytes) + log.debug(s"write index $i with length ${bytes.length}") + } + } + + // should only be used on executor to deserialize this relation + override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { + output = in.readObject().asInstanceOf[Seq[Attribute]] + mode = in.readObject().asInstanceOf[BroadcastMode] + val totalArraySize = in.readInt() + val bytesBufferLengths = in.readObject().asInstanceOf[Array[Int]] + val totalBytes = in.readLong() + + val taskMemoryManager = new TaskMemoryManager( + new UnifiedMemoryManager(SparkEnv.get.conf, Long.MaxValue, Long.MaxValue / 2, 1), + 0) + + batches = + UnsafeBytesBufferArray(totalArraySize, bytesBufferLengths, totalBytes, taskMemoryManager) + + for (i <- 0 until totalArraySize) { + val length = bytesBufferLengths(i) + log.debug(s"readExternal $i--- ${bytesBufferLengths(i)}") + val tmpBuffer = new Array[Byte](length) + in.read(tmpBuffer) + batches.putBytesBuffer(i, tmpBuffer) + } + } + + override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException { + output = kryo.readObject(in, classOf[List[_]]).asInstanceOf[Seq[Attribute]] + mode = kryo.readObject(in, classOf[BroadcastMode]) + val totalArraySize = in.readInt() + val bytesBufferLengths = kryo.readObject(in, classOf[Array[Int]]) + val totalBytes = in.readLong() + + val taskMemoryManager = new TaskMemoryManager( + new UnifiedMemoryManager(SparkEnv.get.conf, Long.MaxValue, Long.MaxValue / 2, 1), + 0) + + batches = + UnsafeBytesBufferArray(totalArraySize, bytesBufferLengths, totalBytes, taskMemoryManager) + + for (i <- 0 until totalArraySize) { + val length = bytesBufferLengths(i) + log.debug(s"readExternal $i--- $length") + val tmpBuffer = new Array[Byte](length) + in.read(tmpBuffer) + batches.putBytesBuffer(i, tmpBuffer) + } + } + + private def transformProjection: UnsafeProjection = { + mode match { + case HashedRelationBroadcastMode(k, _) => UnsafeProjection.create(k) + case IdentityBroadcastMode => UnsafeProjection.create(output, output) + } + } + + override def deserialized: Iterator[ColumnarBatch] = { + val runtime = + Runtimes.contextInstance(BackendsApiManager.getBackendName, "BuildSideRelation#transform") + val jniWrapper = ColumnarBatchSerializerJniWrapper.create(runtime) + val serializeHandle: Long = { + val allocator = ArrowBufferAllocators.contextInstance() + val cSchema = ArrowSchema.allocateNew(allocator) + val arrowSchema = SparkArrowUtil.toArrowSchema( + SparkShimLoader.getSparkShims.structFromAttributes(output), + SQLConf.get.sessionLocalTimeZone) + ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema) + val handle = jniWrapper + .init(cSchema.memoryAddress()) + cSchema.close() + handle + } + + Iterators + .wrap(new Iterator[ColumnarBatch] { + var batchId = 0 + + override def hasNext: Boolean = { + batchId < batches.arraySize + } + + override def next: ColumnarBatch = { + val (offset, length) = + batches.getBytesBufferOffsetAndLength(batchId) + batchId += 1 + val handle = + jniWrapper.deserializeDirectAddress(serializeHandle, offset, length) + ColumnarBatches.create(handle) + } + }) + .protectInvocationFlow() + .recycleIterator { + jniWrapper.close(serializeHandle) + } + .recyclePayload(ColumnarBatches.forceClose) // FIXME why force close? + .create() + } + + override def asReadOnlyCopy(): UnsafeColumnarBuildSideRelation = this + + /** + * Transform columnar broadcast value to Array[InternalRow] by key and distinct. 
NOTE: This method + * was called in Spark Driver, should manage resources carefully. + */ + override def transform(key: Expression): Array[InternalRow] = TaskResources.runUnsafe { + val runtime = + Runtimes.contextInstance(BackendsApiManager.getBackendName, "BuildSideRelation#transform") + // This transformation happens in Spark driver, thus resources can not be managed automatically. + val serializerJniWrapper = ColumnarBatchSerializerJniWrapper.create(runtime) + val serializeHandle = { + val allocator = ArrowBufferAllocators.contextInstance() + val cSchema = ArrowSchema.allocateNew(allocator) + val arrowSchema = SparkArrowUtil.toArrowSchema( + SparkShimLoader.getSparkShims.structFromAttributes(output), + SQLConf.get.sessionLocalTimeZone) + ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema) + val handle = serializerJniWrapper.init(cSchema.memoryAddress()) + cSchema.close() + handle + } + + var closed = false + + val proj = UnsafeProjection.create(Seq(key)) + + // Convert columnar to Row. + val jniWrapper = NativeColumnarToRowJniWrapper.create(runtime) + val c2rId = jniWrapper.nativeColumnarToRowInit() + var batchId = 0 + val iterator = if (batches.arraySize > 0) { + val res: Iterator[Iterator[InternalRow]] = new Iterator[Iterator[InternalRow]] { + override def hasNext: Boolean = { + val itHasNext = batchId < batches.arraySize + if (!itHasNext && !closed) { + jniWrapper.nativeClose(c2rId) + serializerJniWrapper.close(serializeHandle) + closed = true + } + itHasNext + } + + override def next(): Iterator[InternalRow] = { + val (offset, length) = batches.getBytesBufferOffsetAndLength(batchId) + batchId += 1 + val batchHandle = + serializerJniWrapper.deserializeDirectAddress(serializeHandle, offset, length) + val batch = ColumnarBatches.create(batchHandle) + if (batch.numRows == 0) { + batch.close() + Iterator.empty + } else if (output.isEmpty) { + val rows = ColumnarBatches.emptyRowIterator(batch.numRows()).asScala + batch.close() + rows + } else { + val cols = batch.numCols() + val rows = batch.numRows() + var info = + jniWrapper.nativeColumnarToRowConvert( + c2rId, + ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, batch), + 0) + batch.close() + + new Iterator[InternalRow] { + var rowId = 0 + var baseLength = 0 + val row = new UnsafeRow(cols) + + override def hasNext: Boolean = { + rowId < rows + } + + override def next: UnsafeRow = { + if (rowId >= rows) throw new NoSuchElementException + if (rowId == baseLength + info.lengths.length) { + baseLength += info.lengths.length + info = jniWrapper.nativeColumnarToRowConvert(batchHandle, c2rId, rowId) + } + val (offset, length) = + (info.offsets(rowId - baseLength), info.lengths(rowId - baseLength)) + row.pointTo(null, info.memoryAddress + offset, length.toInt) + rowId += 1 + row + } + }.map(transformProjection).map(proj).map(_.copy()) + } + } + } + res.flatten + } else { + Iterator.empty + } + iterator.toArray + } +} diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala index b58ea4d3974f..80706ff46f9f 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxHashJoinSuite.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.execution +import org.apache.gluten.GlutenConfig import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.SparkConf @@ -114,71 +115,84 @@ class 
VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite { } test("Reuse broadcast exchange for different build keys with same table") { - withTable("t1", "t2") { - spark.sql(""" - |CREATE TABLE t1 USING PARQUET - |AS SELECT id as c1, id as c2 FROM range(10) - |""".stripMargin) - - spark.sql(""" - |CREATE TABLE t2 USING PARQUET - |AS SELECT id as c1, id as c2 FROM range(3) - |""".stripMargin) - - val df = spark.sql(""" - |SELECT * FROM t1 - |JOIN t2 as tmp1 ON t1.c1 = tmp1.c1 and tmp1.c1 = tmp1.c2 - |JOIN t2 as tmp2 on t1.c2 = tmp2.c2 and tmp2.c1 = tmp2.c2 - |""".stripMargin) - - assert(collect(df.queryExecution.executedPlan) { - case b: BroadcastExchangeExec => b - }.size == 2) - - checkAnswer( - df, - Row(2, 2, 2, 2, 2, 2) :: Row(1, 1, 1, 1, 1, 1) :: Row(0, 0, 0, 0, 0, 0) :: Nil) - - assert(collect(df.queryExecution.executedPlan) { - case b: ColumnarBroadcastExchangeExec => b - }.size == 1) - assert(collect(df.queryExecution.executedPlan) { - case r @ ReusedExchangeExec(_, _: ColumnarBroadcastExchangeExec) => r - }.size == 1) + for (enabledOffheapBroadcast <- Seq("true", "false")) { + withSQLConf( + GlutenConfig.VELOX_BROADCAST_BUILD_RELATION_USE_OFFHEAP.key -> enabledOffheapBroadcast) { + withTable("t1", "t2") { + spark.sql(""" + |CREATE TABLE t1 USING PARQUET + |AS SELECT id as c1, id as c2 FROM range(10) + |""".stripMargin) + + spark.sql(""" + |CREATE TABLE t2 USING PARQUET + |AS SELECT id as c1, id as c2 FROM range(3) + |""".stripMargin) + + val df = spark.sql(""" + |SELECT * FROM t1 + |JOIN t2 as tmp1 ON t1.c1 = tmp1.c1 and tmp1.c1 = tmp1.c2 + |JOIN t2 as tmp2 on t1.c2 = tmp2.c2 and tmp2.c1 = tmp2.c2 + |""".stripMargin) + + assert(collect(df.queryExecution.executedPlan) { + case b: BroadcastExchangeExec => b + }.size == 2) + + checkAnswer( + df, + Row(2, 2, 2, 2, 2, 2) :: Row(1, 1, 1, 1, 1, 1) :: Row(0, 0, 0, 0, 0, 0) :: Nil) + + assert(collect(df.queryExecution.executedPlan) { + case b: ColumnarBroadcastExchangeExec => b + }.size == 1) + assert(collect(df.queryExecution.executedPlan) { + case r @ ReusedExchangeExec(_, _: ColumnarBroadcastExchangeExec) => r + }.size == 1) + } + } } } test("ColumnarBuildSideRelation transform support multiple key columns") { - withTable("t1", "t2") { - val df1 = - (0 until 50).map(i => (i % 2, i % 3, s"${i % 25}")).toDF("t1_c1", "t1_c2", "date").as("df1") - val df2 = (0 until 50) - .map(i => (i % 11, i % 13, s"${i % 10}")) - .toDF("t2_c1", "t2_c2", "date") - .as("df2") - df1.write.partitionBy("date").saveAsTable("t1") - df2.write.partitionBy("date").saveAsTable("t2") - - val df = sql(""" - |SELECT t1.date, t1.t1_c1, t2.t2_c2 - |FROM t1 - |JOIN t2 ON t1.date = t2.date - |WHERE t1.date=if(3 <= t2.t2_c2, if(3 < t2.t2_c1, 3, t2.t2_c1), t2.t2_c2) - |ORDER BY t1.date DESC, t1.t1_c1 DESC, t2.t2_c2 DESC - |LIMIT 1 - |""".stripMargin) - - checkAnswer(df, Row("3", 1, 4) :: Nil) - // collect the DPP plan. 
- val subqueryBroadcastExecs = collectWithSubqueries(df.queryExecution.executedPlan) { - case subqueryBroadcast: ColumnarSubqueryBroadcastExec => subqueryBroadcast + for (enabledOffheapBroadcast <- Seq("true", "false")) { + withSQLConf( + GlutenConfig.VELOX_BROADCAST_BUILD_RELATION_USE_OFFHEAP.key -> enabledOffheapBroadcast) { + withTable("t1", "t2") { + val df1 = + (0 until 50) + .map(i => (i % 2, i % 3, s"${i % 25}")) + .toDF("t1_c1", "t1_c2", "date") + .as("df1") + val df2 = (0 until 50) + .map(i => (i % 11, i % 13, s"${i % 10}")) + .toDF("t2_c1", "t2_c2", "date") + .as("df2") + df1.write.partitionBy("date").saveAsTable("t1") + df2.write.partitionBy("date").saveAsTable("t2") + + val df = sql(""" + |SELECT t1.date, t1.t1_c1, t2.t2_c2 + |FROM t1 + |JOIN t2 ON t1.date = t2.date + |WHERE t1.date=if(3 <= t2.t2_c2, if(3 < t2.t2_c1, 3, t2.t2_c1), t2.t2_c2) + |ORDER BY t1.date DESC, t1.t1_c1 DESC, t2.t2_c2 DESC + |LIMIT 1 + |""".stripMargin) + + checkAnswer(df, Row("3", 1, 4) :: Nil) + // collect the DPP plan. + val subqueryBroadcastExecs = collectWithSubqueries(df.queryExecution.executedPlan) { + case subqueryBroadcast: ColumnarSubqueryBroadcastExec => subqueryBroadcast + } + assert(subqueryBroadcastExecs.size == 2) + val buildKeysAttrs = subqueryBroadcastExecs + .flatMap(_.buildKeys) + .map(e => e.collect { case a: AttributeReference => a }) + // the buildKeys function can accept expressions with multiple columns. + assert(buildKeysAttrs.exists(_.size > 1)) + } } - assert(subqueryBroadcastExecs.size == 2) - val buildKeysAttrs = subqueryBroadcastExecs - .flatMap(_.buildKeys) - .map(e => e.collect { case a: AttributeReference => a }) - // the buildKeys function can accept expressions with multiple columns. - assert(buildKeysAttrs.exists(_.size > 1)) } } } diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala index 44ffae45ad6f..99cccf1a6d42 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala @@ -305,6 +305,18 @@ class VeloxTPCHV1BhjSuite extends VeloxTPCHSuite { } } +/** BroadcastBuildSideRelation use offheap. 
*/ +class VeloxTPCHV1BhjOffheapSuite extends VeloxTPCHSuite { + override def subType(): String = "v1-bhj-offheap" + + override protected def sparkConf: SparkConf = { + super.sparkConf + .set("spark.sql.sources.useV1SourceList", "parquet") + .set("spark.sql.autoBroadcastJoinThreshold", "30M") + .set("spark.gluten.velox.BroadcastBuildRelationUseOffheap.enabled", "true") + } +} + class VeloxTPCHV2Suite extends VeloxTPCHSuite { override def subType(): String = "v2" diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc index 61cedf85763e..df96db350100 100644 --- a/cpp/core/jni/JniWrapper.cc +++ b/cpp/core/jni/JniWrapper.cc @@ -1194,6 +1194,23 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ColumnarBatchSerialize JNI_METHOD_END(kInvalidObjectHandle) } +JNIEXPORT jlong JNICALL +Java_org_apache_gluten_vectorized_ColumnarBatchSerializerJniWrapper_deserializeDirectAddress( // NOLINT + JNIEnv* env, + jobject wrapper, + jlong serializerHandle, + jlong address, + jint size) { + JNI_METHOD_START + auto ctx = gluten::getRuntime(env, wrapper); + + auto serializer = ObjectStore::retrieve(serializerHandle); + GLUTEN_DCHECK(serializer != nullptr, "ColumnarBatchSerializer cannot be null"); + auto batch = serializer->deserialize((uint8_t*)address, size); + return ctx->saveObject(batch); + JNI_METHOD_END(kInvalidObjectHandle) +} + JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_ColumnarBatchSerializerJniWrapper_close( // NOLINT JNIEnv* env, jobject wrapper, diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java index 000e233d5a79..995e68c120c8 100644 --- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java +++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java @@ -42,5 +42,8 @@ public long rtHandle() { public native long deserialize(long serializerHandle, byte[] data); + // Return the native ColumnarBatch handle using memory address and length + public native long deserializeDirectAddress(long serializerHandle, long offset, int len); + public native void close(long serializerHandle); } diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala index e9dbeb560c68..bb3c79570340 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala @@ -41,5 +41,5 @@ trait BuildSideRelation extends Serializable { * * Post-processed relation transforms can use this mode to obtain the desired format. 
 */
-  val mode: BroadcastMode
+  def mode: BroadcastMode
 }
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index 15704f1450ee..536ac48f92d5 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -485,6 +485,9 @@ class GlutenConfig(conf: SQLConf) extends Logging {
   def enableCelebornFallback: Boolean = conf.getConf(CELEBORN_FALLBACK_ENABLED)
 
   def enableHdfsViewfs: Boolean = conf.getConf(HDFS_VIEWFS_ENABLED)
+
+  def enableBroadcastBuildRelationInOffheap: Boolean =
+    conf.getConf(VELOX_BROADCAST_BUILD_RELATION_USE_OFFHEAP)
 }
 
 object GlutenConfig {
@@ -2246,4 +2249,12 @@ object GlutenConfig {
       .doc("If enabled, gluten will convert the viewfs path to hdfs path in scala side")
       .booleanConf
       .createWithDefault(false)
+
+  val VELOX_BROADCAST_BUILD_RELATION_USE_OFFHEAP =
+    buildConf("spark.gluten.velox.BroadcastBuildRelationUseOffheap.enabled")
+      .internal()
+      .doc("If enabled, the broadcast build relation is stored in off-heap memory. " +
+        "Otherwise, it is stored in on-heap memory.")
+      .booleanConf
+      .createWithDefault(true)
 }
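
The buffer layout used by UnsafeBytesBufferArray can be illustrated with a small self-contained sketch. It mirrors the prefix-sum offset computation and the word-rounded allocation from the new file above; the names OffheapLayoutSketch, lengths, offsets and wordsNeeded are illustrative only and not part of the patch.

// Illustrative sketch (not part of the patch): how UnsafeBytesBufferArray lays out
// N serialized batches inside one contiguous off-heap allocation.
object OffheapLayoutSketch {
  def main(args: Array[String]): Unit = {
    // Per-batch serialized sizes, as passed to the UnsafeBytesBufferArray constructor.
    val lengths = Array(100, 250, 75)

    // Prefix-sum offsets: batch i starts offsets(i) bytes after the base address,
    // matching the initialization block of UnsafeBytesBufferArray.
    val offsets = lengths.scanLeft(0)(_ + _).dropRight(1) // Array(0, 100, 350)

    val totalBytes = lengths.map(_.toLong).sum
    // The backing LongArray is allocated in 8-byte words, rounded up,
    // matching allocateArray((totalBytes + 7) / 8).
    val wordsNeeded = (totalBytes + 7) / 8

    println(s"offsets = ${offsets.mkString(", ")}, totalBytes = $totalBytes, words = $wordsNeeded")
    // getBytesBufferOffsetAndLength(i) then resolves to (baseOffset + offsets(i), lengths(i)),
    // which deserializeDirectAddress reads without an intermediate on-heap copy.
  }
}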
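The new flag can be toggled per session. Below is a minimal usage sketch, assuming a Spark session already wired up with the Gluten plugin and off-heap memory; the memory sizes and the table names large/small are hypothetical and only meant to trigger a broadcast hash join.

// Illustrative only: enabling the off-heap broadcast build relation for a session.
import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .appName("offheap-broadcast-sketch")
  .config("spark.plugins", "org.apache.gluten.GlutenPlugin")
  .config("spark.memory.offHeap.enabled", "true")
  .config("spark.memory.offHeap.size", "2g")
  // New flag introduced by this patch (defaults to true).
  .config("spark.gluten.velox.BroadcastBuildRelationUseOffheap.enabled", "true")
  .getOrCreate()

// A broadcast hash join should now build an UnsafeColumnarBuildSideRelation
// instead of ColumnarBuildSideRelation; `large` and `small` are hypothetical tables.
val joined = spark.sql(
  "SELECT /*+ BROADCAST(small) */ * FROM large JOIN small ON large.id = small.id")
joined.explain()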