diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index f7066958bb40..9e3789b9b633 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -24,7 +24,8 @@ import org.apache.gluten.expression._
 import org.apache.gluten.expression.aggregate.{HLLAdapter, VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet}
 import org.apache.gluten.extension.columnar.FallbackTags
 import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.gluten.vectorized.{ColumnarBatchSerializeResult, ColumnarBatchSerializer}
+import org.apache.gluten.vectorized.{ColumnarBatchSerializer, ColumnarBatchSerializeResult}
+
 import org.apache.spark.{ShuffleDependency, SparkException}
 import org.apache.spark.api.python.{ColumnarArrowEvalPythonExec, PullOutArrowEvalPythonPreProjectHelper}
 import org.apache.spark.rdd.RDD
@@ -44,17 +45,18 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
 import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBroadcastMode}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
+import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
 import org.apache.spark.sql.execution.utils.ExecUtil
 import org.apache.spark.sql.expression.{UDFExpression, UserDefinedAggregateFunction}
 import org.apache.spark.sql.hive.VeloxHiveUDFTransformer
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.commons.lang3.ClassUtils
-import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
 import org.apache.spark.task.TaskResources
 import org.apache.spark.util.TaskResources
 
+import org.apache.commons.lang3.ClassUtils
+
 import javax.ws.rs.core.UriBuilder
 
 class VeloxSparkPlanExecApi extends SparkPlanExecApi {
@@ -621,8 +623,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       child: SparkPlan,
       numOutputRows: SQLMetric,
       dataSize: SQLMetric): BuildSideRelation = {
-    val useOffheapBroadcastBuildRelation = GlutenConfig.getConf
-      .enableBroadcastBuildRelationInOffheap
+    val useOffheapBroadcastBuildRelation =
+      GlutenConfig.getConf.enableBroadcastBuildRelationInOffheap
     val serialized: Array[ColumnarBatchSerializeResult] = child
       .executeColumnar()
       .mapPartitions(itr => Iterator(BroadcastUtils.serializeStream(itr)))
@@ -635,7 +637,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
     }
     numOutputRows += serialized.map(_.getNumRows).sum
     dataSize += rawSize
-    if (useOffheapBroadcastBuildRelation){
+    if (useOffheapBroadcastBuildRelation) {
      TaskResources.runUnsafe {
        new UnsafeColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized))
      }
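Note on the change above: construction of the broadcast relation is now gated on `GlutenConfig.getConf.enableBroadcastBuildRelationInOffheap`, and the off-heap variant is built inside `TaskResources.runUnsafe` so task-level memory accounting exists on the driver thread. A minimal sketch of that gating pattern follows; the types and the `runUnsafe` helper here are simplified stand-ins, not Gluten's actual API.

```scala
// Sketch only: stand-in types for ColumnarBuildSideRelation (on-heap) and
// UnsafeColumnarBuildSideRelation (off-heap).
sealed trait Relation
case class OnHeapRelation(batches: Array[Array[Byte]]) extends Relation
case class OffHeapRelation(batches: Array[Array[Byte]]) extends Relation

object RelationBuilder {
  // Stand-in for TaskResources.runUnsafe: a scope that provides task-level
  // memory accounting outside of a running Spark task.
  def runUnsafe[T](body: => T): T = body

  def build(useOffheap: Boolean, serialized: Array[Array[Byte]]): Relation =
    if (useOffheap) {
      // The off-heap path must run inside the unsafe scope so a
      // TaskMemoryManager is available for the off-heap allocation.
      runUnsafe(OffHeapRelation(serialized))
    } else {
      OnHeapRelation(serialized)
    }
}
```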
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
index 71d8a6b2137e..9a3f6fcd0b70 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
@@ -21,6 +21,7 @@ import org.apache.gluten.columnarbatch.ColumnarBatches
 import org.apache.gluten.runtime.Runtimes
 import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.vectorized.{ColumnarBatchSerializeResult, ColumnarBatchSerializerJniWrapper}
+
 import org.apache.spark.SparkContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.InternalRow
@@ -134,9 +135,9 @@ object BroadcastUtils {
           SparkShimLoader.getSparkShims.attributesFromStruct(schema),
           serialized)
       } else {
-          ColumnarBuildSideRelation(
-            SparkShimLoader.getSparkShims.attributesFromStruct(schema),
-            serialized)
+        ColumnarBuildSideRelation(
+          SparkShimLoader.getSparkShims.attributesFromStruct(schema),
+          serialized)
       }
     }
     // Rebroadcast Velox relation.
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeBytesBufferArray.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeBytesBufferArray.scala
index db05d98e0408..5dc0e62c318c 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeBytesBufferArray.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeBytesBufferArray.scala
@@ -24,32 +24,31 @@ import org.apache.spark.unsafe.array.LongArray
 import java.security.MessageDigest
 
 /**
- * Used to store broadcast variable off-heap memory for broadcast variable.
- * The underlying data structure is a
+ * Used to store the broadcast variable in off-heap memory. The underlying data structure is a
+ * single off-heap LongArray that stores all serialized bytes buffers consecutively.
  *
  * @param arraySize
- *     underlying array[array[byte]]'s length
+ *   length of the underlying array[array[byte]]
  * @param bytesBufferLengths
- *     underlying array[array[byte]] per bytesBuffer length
+ *   per-bytesBuffer length of the underlying array[array[byte]]
  * @param totalBytes
- *     all bytesBuffer's length plus together
+ *   total length of all bytesBuffers added together
  */
 case class UnsafeBytesBufferArray(
-  arraySize: Int,
-  bytesBufferLengths: Array[Int],
-  totalBytes: Long,
-  tmm: TaskMemoryManager)
+    arraySize: Int,
+    bytesBufferLengths: Array[Int],
+    totalBytes: Long,
+    tmm: TaskMemoryManager)
   extends MemoryConsumer(tmm, MemoryMode.OFF_HEAP)
-    with Logging {
+  with Logging {
+
   /**
-   * A single array to store all bytesBufferArray's value, it's inited once
-   * when first time get accessed.
+   * A single array that stores the values of all bytesBuffers; it is initialized once, on
+   * first access.
    */
   private var longArray: LongArray = _
 
-  /**
-   * Index the start of each byteBuffer's offset to underlying LongArray's initial position.
-   */
+  /** Start offset of each bytesBuffer relative to the underlying LongArray's base. */
   private val bytesBufferOffset = new Array[Int](arraySize)
 
   {
@@ -137,17 +136,17 @@ case class UnsafeBytesBufferArray(
    * It's needed once the broadcast variable is garbage collected. Since now, we don't have an
    * elegant way to free the underlying memory in offheap.
    */
-    override def finalize(): Unit = {
-      try {
-        if (longArray != null) {
-          log.debug(s"BytesArrayInOffheap finalize $arraySize")
-          freeArray(longArray)
-          longArray = null
-        }
-      } finally {
-        super.finalize()
+  override def finalize(): Unit = {
+    try {
+      if (longArray != null) {
+        log.debug(s"BytesArrayInOffheap finalize $arraySize")
+        freeArray(longArray)
+        longArray = null
       }
+    } finally {
+      super.finalize()
     }
+  }
 
   /**
    * Used to debug input/output bytes.
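For readers following the UnsafeBytesBufferArray layout above: because all serialized buffers live back to back in one off-heap region, buffer i starts at the exclusive prefix sum of the lengths before it. A small self-contained sketch of that indexing scheme, with plain arrays standing in for the LongArray/off-heap primitives:

```scala
// Sketch of the offset bookkeeping; not Gluten code.
object OffsetIndex {
  // Exclusive prefix sums: buffer i starts at bufferOffsets(lengths)(i).
  def bufferOffsets(lengths: Array[Int]): Array[Long] =
    lengths.scanLeft(0L)(_ + _).init

  def main(args: Array[String]): Unit = {
    val lengths = Array(3, 5, 2)
    val offsets = bufferOffsets(lengths)
    // offsets = [0, 3, 8], totalBytes = 10. A lookup like
    // getBytesBufferOffsetAndLength(i) then amounts to
    // (baseAddress + offsets(i), lengths(i)).
    println(offsets.mkString(", "))
  }
}
```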
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
index e4b8a0c9dd87..beeab4188ef5 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
@@ -16,9 +16,6 @@
  */
 package org.apache.spark.sql.execution.unsafe
 
-import com.esotericsoftware.kryo.io.{Input, Output}
-import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
-import org.apache.arrow.c.ArrowSchema
 import org.apache.gluten.columnarbatch.ColumnarBatches
 import org.apache.gluten.iterator.Iterators
 import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
@@ -26,6 +23,7 @@ import org.apache.gluten.runtime.Runtimes
 import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.utils.ArrowAbiUtil
 import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper}
+
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
@@ -38,24 +36,29 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.task.TaskResources
 import org.apache.spark.util.Utils
 
+import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
+import com.esotericsoftware.kryo.io.{Input, Output}
+import org.apache.arrow.c.ArrowSchema
+
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
+
 import scala.collection.JavaConverters.asScalaIteratorConverter
 
 /**
- * UnsafeColumnarBuildSideRelation should backed by offheap to avoid on-heap oom.
- * Almost the same as ColumnarBuildSideRelation, we should remove ColumnarBuildSideRelation when
+ * UnsafeColumnarBuildSideRelation is backed by off-heap memory to avoid on-heap OOM. It is
+ * almost the same as ColumnarBuildSideRelation; we should remove ColumnarBuildSideRelation once
  * UnsafeColumnarBuildSideRelation gets mature.
  *
  * @param output
  * @param batches
  */
 case class UnsafeColumnarBuildSideRelation(
-  private var output: Seq[Attribute],
-  private var batches: UnsafeBytesBufferArray)
+    private var output: Seq[Attribute],
+    private var batches: UnsafeBytesBufferArray)
   extends BuildSideRelation
-    with Externalizable
-    with Logging
-    with KryoSerializable {
+  with Externalizable
+  with Logging
+  with KryoSerializable {
 
   def this(output: Seq[Attribute], bytesBufferArray: Array[Array[Byte]]) {
     // only used in driver side when broadcast the whole batches
@@ -113,11 +116,8 @@
         new UnifiedMemoryManager(SparkEnv.get.conf, Long.MaxValue, Long.MaxValue / 2, 1),
         0)
 
-    batches = UnsafeBytesBufferArray(
-      totalArraySize,
-      bytesBufferLengths,
-      totalBytes,
-      taskMemoryManager)
+    batches =
+      UnsafeBytesBufferArray(totalArraySize, bytesBufferLengths, totalBytes, taskMemoryManager)
 
     for (i <- 0 until totalArraySize) {
       val length = bytesBufferLengths(i)
@@ -138,11 +138,8 @@
         new UnifiedMemoryManager(SparkEnv.get.conf, Long.MaxValue, Long.MaxValue / 2, 1),
         0)
 
-    batches = UnsafeBytesBufferArray(
-      totalArraySize,
-      bytesBufferLengths,
-      totalBytes,
-      taskMemoryManager)
+    batches =
+      UnsafeBytesBufferArray(totalArraySize, bytesBufferLengths, totalBytes, taskMemoryManager)
 
     for (i <- 0 until totalArraySize) {
       val length = bytesBufferLengths(i)
@@ -153,7 +150,6 @@ case class UnsafeColumnarBuildSideRelation(
     }
   }
 
-
   override def deserialized: Iterator[ColumnarBatch] = {
     val runtime = Runtimes.contextInstance("UnsafeBuildSideRelation#deserialized")
     val jniWrapper = ColumnarBatchSerializerJniWrapper.create(runtime)
@@ -179,10 +175,11 @@ case class UnsafeColumnarBuildSideRelation(
       }
 
       override def next: ColumnarBatch = {
-        val handle =
-          jniWrapper
-            .deserialize(serializeHandle, batches.getBytesBuffer(batchId))
+        val (offset, length) =
+          batches.getBytesBufferOffsetAndLength(batchId)
         batchId += 1
+        val handle =
+          jniWrapper.deserialize(serializeHandle, offset, length)
         ColumnarBatches.create(handle)
       }
     })
@@ -247,10 +244,10 @@
       }
 
       override def next(): Iterator[InternalRow] = {
-        val batchBytes = batches.getBytesBuffer(batchId)
+        val (offset, length) = batches.getBytesBufferOffsetAndLength(batchId)
         batchId += 1
         val batchHandle =
-          serializerJniWrapper.deserialize(serializeHandle, batchBytes)
+          serializerJniWrapper.deserialize(serializeHandle, offset, length)
         val batch = ColumnarBatches.create(batchHandle)
         if (batch.numRows == 0) {
           batch.close()
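diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index 963440f6fc16..a94ad6bd9367 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -1128,6 +1128,22 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ColumnarBatchSerialize
   JNI_METHOD_END(kInvalidObjectHandle)
 }
 
+JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ColumnarBatchSerializerJniWrapper_deserialize( // NOLINT
+    JNIEnv* env,
+    jobject wrapper,
+    jlong serializerHandle,
+    jlong address,
+    jint size) {
+  JNI_METHOD_START
+  auto ctx = gluten::getRuntime(env, wrapper);
+
+  auto serializer = ctx->objectStore()->retrieve(serializerHandle);
+  GLUTEN_DCHECK(serializer != nullptr, "ColumnarBatchSerializer cannot be null");
+  auto batch = serializer->deserialize(reinterpret_cast<uint8_t*>(address), size);
+  return ctx->saveObject(batch);
+  JNI_METHOD_END(kInvalidObjectHandle)
+}
+
 JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_ColumnarBatchSerializerJniWrapper_close( // NOLINT
     JNIEnv* env,
     jobject wrapper,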
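The two hunks above are the core of the zero-copy change: the old `getBytesBuffer(batchId)` call materialized each serialized buffer back onto the JVM heap as an `Array[Byte]` before handing it to JNI, whereas now only an (address, length) pair crosses the boundary and native code reads the off-heap bytes in place. A hedged sketch of that iteration pattern, with `Slice` and `deserializeAt` as illustrative stand-ins rather than Gluten names:

```scala
// Stand-in for the (address, length) pair returned by
// getBytesBufferOffsetAndLength.
final case class Slice(address: Long, length: Int)

object ZeroCopyRead {
  // deserializeAt stands in for jniWrapper.deserialize(handle, offset, len).
  def readAll[T](slices: IndexedSeq[Slice])(deserializeAt: (Long, Int) => T): Iterator[T] =
    new Iterator[T] {
      private var batchId = 0
      override def hasNext: Boolean = batchId < slices.length
      override def next(): T = {
        val s = slices(batchId)
        batchId += 1
        // No on-heap copy of the payload; native code reads the bytes in place.
        deserializeAt(s.address, s.length)
      }
    }
}
```

The JniWrapper.cc hunk below adds the native counterpart. Note that the new function is corrected here to use the `Java_org_apache_gluten_...` symbol prefix: the Java declaration lives in `org.apache.gluten.vectorized.ColumnarBatchSerializerJniWrapper` (see the sibling `_close` function in the same hunk), so the `Java_io_glutenproject_...` name in the original patch would never be bound and the call would fail with `UnsatisfiedLinkError`. The return constant is likewise aligned with the sibling `deserialize` shown in the leading context, which ends with `JNI_METHOD_END(kInvalidObjectHandle)`.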
diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java
index 000e233d5a79..9f6ddfbc8ce6 100644
--- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java
+++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ColumnarBatchSerializerJniWrapper.java
@@ -42,5 +42,8 @@ public long rtHandle() {
 
   public native long deserialize(long serializerHandle, byte[] data);
 
+  // Returns the native ColumnarBatch handle; reads `len` bytes at off-heap address `offset`.
+  public native long deserialize(long serializerHandle, long offset, int len);
+
   public native void close(long serializerHandle);
 }
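An illustration of the JNI naming contract that the symbol fix above relies on (demo class, not Gluten code). The exported C symbol must encode the declaring class's fully qualified name with dots replaced by underscores. One caveat worth flagging, as an assumption about how the overload is resolved: when a native method is overloaded on the Java side, as `deserialize` now is, the JNI specification requires the mangled long form that appends the argument descriptor (e.g. `..._deserialize__JJI` for the `(long, long, int)` overload), since the short name alone is ambiguous.

```scala
package org.example.jni

// The JVM resolves this method against the exported symbol
// Java_org_example_jni_NativeDemo_deserialize (short form); if the class were
// org.apache.gluten.vectorized.ColumnarBatchSerializerJniWrapper, the symbol would
// have to start with Java_org_apache_gluten_vectorized_... accordingly.
class NativeDemo {
  @native def deserialize(serializerHandle: Long, address: Long, len: Int): Long
}
```

Finally, a hedged sketch of a call site for the new overload, mirroring the iterator in UnsafeColumnarBuildSideRelation above; serializer setup and batch wrapping are elided, and the handle values are assumed to come from the surrounding Gluten runtime:

```scala
import org.apache.gluten.vectorized.ColumnarBatchSerializerJniWrapper

object DeserializeAt {
  def deserializeOne(
      jniWrapper: ColumnarBatchSerializerJniWrapper,
      serializeHandle: Long,
      address: Long,
      length: Int): Long = {
    // Returns a native ColumnarBatch handle; the caller wraps it, e.g. via
    // ColumnarBatches.create(handle), exactly as the diff's iterators do.
    jniWrapper.deserialize(serializeHandle, address, length)
  }
}
```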