[Gluten-7750][VL] Move ColumnarBuildSideRelation's memory occupation to Spark off-heap #8127

Open · wants to merge 14 commits into base: main
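The new behavior is gated by a Gluten config flag, read in the diff below as GlutenConfig.getConf.enableBroadcastBuildRelationInOffheap. As a minimal sketch, this is how a session might be configured to exercise the off-heap path; the Gluten key name is an assumption (only the getter name appears in this diff), while the spark.memory.offHeap.* settings are standard Spark configs required for off-heap allocation:

import org.apache.spark.sql.SparkSession

// Hypothetical setup; the Gluten key name below is assumed, not confirmed by this diff.
val spark = SparkSession
  .builder()
  .appName("offheap-broadcast-build-relation")
  .config("spark.memory.offHeap.enabled", "true") // Spark's off-heap pool must be enabled
  .config("spark.memory.offHeap.size", "2g")      // and sized explicitly
  .config("spark.gluten.sql.columnar.broadcast.buildRelation.offheap.enabled", "true")
  .getOrCreate()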
@@ -45,12 +45,14 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBroadcastMode}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
import org.apache.spark.sql.execution.utils.ExecUtil
import org.apache.spark.sql.expression.{UDFExpression, UserDefinedAggregateFunction}
import org.apache.spark.sql.hive.VeloxHiveUDFTransformer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.task.TaskResources

import org.apache.commons.lang3.ClassUtils

@@ -621,6 +623,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
child: SparkPlan,
numOutputRows: SQLMetric,
dataSize: SQLMetric): BuildSideRelation = {
val useOffheapBroadcastBuildRelation =
GlutenConfig.getConf.enableBroadcastBuildRelationInOffheap
val serialized: Array[ColumnarBatchSerializeResult] = child
.executeColumnar()
.mapPartitions(itr => Iterator(BroadcastUtils.serializeStream(itr)))
@@ -633,7 +637,13 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
}
numOutputRows += serialized.map(_.getNumRows).sum
dataSize += rawSize
ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized), mode)
if (useOffheapBroadcastBuildRelation) {
TaskResources.runUnsafe {
new UnsafeColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized), mode)
}
} else {
ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized), mode)
}
}

override def doCanonicalizeForBroadcastMode(mode: BroadcastMode): BroadcastMode = {
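For readability, the gate introduced in the hunk above can be consolidated as follows. This is an illustrative sketch assembled from the diff, not code from the PR; the semantics of TaskResources.runUnsafe are assumed from its use here (it appears to provide a task-memory context so the MemoryConsumer-backed off-heap allocation can run outside a real Spark task, e.g. on the driver):

// `output` and `batches` stand in for child.output and serialized.map(_.getSerialized).
val relation: BuildSideRelation =
  if (GlutenConfig.getConf.enableBroadcastBuildRelationInOffheap) {
    TaskResources.runUnsafe {
      // Off-heap path: batch bytes are copied into Spark-managed off-heap memory.
      new UnsafeColumnarBuildSideRelation(output, batches, mode)
    }
  } else {
    // Default path: serialized batches stay on-heap inside the relation.
    ColumnarBuildSideRelation(output, batches, mode)
  }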
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.execution

import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.gluten.runtime.Runtimes
@@ -27,7 +28,8 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, IdentityBroadcastMode, Partitioning}
import org.apache.spark.sql.execution.joins.{HashedRelation, HashedRelationBroadcastMode, LongHashedRelation}
import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelation, HashedRelationBroadcastMode, LongHashedRelation}
import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.task.TaskResources
@@ -45,7 +47,7 @@ object BroadcastUtils {
mode match {
case HashedRelationBroadcastMode(_, _) =>
// ColumnarBuildSideRelation to HashedRelation.
val fromBroadcast = from.asInstanceOf[Broadcast[ColumnarBuildSideRelation]]
val fromBroadcast = from.asInstanceOf[Broadcast[BuildSideRelation]]
val fromRelation = fromBroadcast.value.asReadOnlyCopy()
var rowCount: Long = 0
val toRelation = TaskResources.runUnsafe {
@@ -60,7 +62,7 @@
context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
case IdentityBroadcastMode =>
// ColumnarBuildSideRelation to HashedRelation.
val fromBroadcast = from.asInstanceOf[Broadcast[ColumnarBuildSideRelation]]
val fromBroadcast = from.asInstanceOf[Broadcast[BuildSideRelation]]
val fromRelation = fromBroadcast.value.asReadOnlyCopy()
val toRelation = TaskResources.runUnsafe {
val rowIterator = fn(fromRelation.deserialized)
@@ -91,6 +93,7 @@
schema: StructType,
from: Broadcast[F],
fn: Iterator[InternalRow] => Iterator[ColumnarBatch]): Broadcast[T] = {
val useOffheapBuildRelation = GlutenConfig.getConf.enableBroadcastBuildRelationInOffheap
mode match {
case HashedRelationBroadcastMode(_, _) =>
// HashedRelation to ColumnarBuildSideRelation.
@@ -104,10 +107,17 @@
case result: ColumnarBatchSerializeResult =>
Array(result.getSerialized)
}
ColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized,
mode)
if (useOffheapBuildRelation) {
new UnsafeColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized,
mode)
} else {
ColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized,
mode)
}
}
// Rebroadcast Velox relation.
context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
@@ -123,10 +133,17 @@
case result: ColumnarBatchSerializeResult =>
Array(result.getSerialized)
}
ColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized,
mode)
if (useOffheapBuildRelation) {
new UnsafeColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized,
mode)
} else {
ColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized,
mode)
}
}
// Rebroadcast Velox relation.
context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
@@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.unsafe

import org.apache.spark.internal.Logging
import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.LongArray

import java.security.MessageDigest

/**
* Stores a broadcast variable's serialized bytes in off-heap memory. The underlying data
* structure is a LongArray allocated in off-heap memory.
*
* @param arraySize
*   number of byte buffers, i.e. the underlying Array[Array[Byte]]'s length
* @param bytesBufferLengths
*   length of each byte buffer in the underlying Array[Array[Byte]]
* @param totalBytes
*   sum of all byte buffers' lengths
* @param tmm
*   the TaskMemoryManager that accounts for this off-heap allocation
*/
// scalastyle:off no.finalize
case class UnsafeBytesBufferArray(
arraySize: Int,
bytesBufferLengths: Array[Int],
totalBytes: Long,
tmm: TaskMemoryManager)
extends MemoryConsumer(tmm, MemoryMode.OFF_HEAP)
with Logging {

/**
* A single LongArray that stores all byte buffers' values; it is initialized once, on first
* access.
*/
private var longArray: LongArray = _

/** Start offset of each byte buffer, relative to the underlying LongArray's base position. */
private val bytesBufferOffset = new Array[Int](arraySize)

{
assert(arraySize == bytesBufferLengths.length)
assert(totalBytes >= 0)
if (arraySize > 0) {
for (i <- 0 until arraySize) {
bytesBufferOffset(i) =
if (i > 0) bytesBufferLengths(i - 1) + bytesBufferOffset(i - 1) else 0
}
}
}
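// Example layout (illustrative): for bytesBufferLengths = Array(5, 8, 3),
// bytesBufferOffset = Array(0, 5, 13) and totalBytes = 16, so the backing
// LongArray takes (16 + 7) / 8 = 2 words, i.e. 16 bytes rounded up to 8-byte units.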

// This consumer never spills: the broadcast relation must stay resident while in use.
override def spill(l: Long, memoryConsumer: MemoryConsumer): Long = 0L

/**
* Put a byte buffer at the specified array index.
* @param index slot to write to
* @param bytesBuffer the bytes to copy into off-heap memory
*/
def putBytesBuffer(index: Int, bytesBuffer: Array[Byte]): Unit = this.synchronized {
assert(index < arraySize)
assert(bytesBuffer.length == bytesBufferLengths(index))
// Lazily allocate the underlying long array on the first put, rounded up to whole words.
if (null == longArray) {
log.debug(s"allocate array $totalBytes, actual longArray size ${(totalBytes + 7) / 8}")
longArray = allocateArray((totalBytes + 7) / 8)
}
if (log.isDebugEnabled) {
log.debug(s"put bytesBuffer at index $index bytesBuffer's length is ${bytesBuffer.length}")
log.debug(
s"bytesBuffer at index $index " +
s"digest ${calculateMD5(bytesBuffer).mkString("Array(", ", ", ")")}")
}
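// Copy the bytes into the backing LongArray. For an off-heap LongArray,
// getBaseObject returns null and getBaseOffset is the absolute native address,
// which is what Platform.copyMemory expects for off-heap destinations.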
Platform.copyMemory(
bytesBuffer,
Platform.BYTE_ARRAY_OFFSET,
longArray.getBaseObject,
longArray.getBaseOffset + bytesBufferOffset(index),
bytesBufferLengths(index))
}

/**
* Get the byte buffer at the specified index as an on-heap copy.
* @param index slot to read from
* @return a fresh Array[Byte] copy, or an empty array if nothing has been stored yet
*/
def getBytesBuffer(index: Int): Array[Byte] = {
assert(index < arraySize)
if (null == longArray) {
return new Array[Byte](0)
}
val bytes = new Array[Byte](bytesBufferLengths(index))
log.debug(s"get bytesBuffer at index $index bytesBuffer length ${bytes.length}")
Platform.copyMemory(
longArray.getBaseObject,
longArray.getBaseOffset + bytesBufferOffset(index),
bytes,
Platform.BYTE_ARRAY_OFFSET,
bytesBufferLengths(index))
if (log.isDebugEnabled) {
log.debug(
s"get bytesBuffer at index $index " +
s"digest ${calculateMD5(bytes).mkString("Array(", ", ", ")")}")
}
bytes
}

/**
* Get the memory address and length of the byte buffer at the specified index. Useful for
* reading the buffer directly from off-heap memory.
*
* @param index slot to locate
* @return the absolute off-heap address and the buffer's length
*/
def getBytesBufferOffsetAndLength(index: Int): (Long, Int) = {
assert(index < arraySize)
assert(longArray != null, "The broadcast data in offheap should not be null!")
val offset = longArray.getBaseOffset + bytesBufferOffset(index)
val length = bytesBufferLengths(index)
(offset, length)
}
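// Illustrative direct read (hypothetical caller code): off-heap accessors take a
// null base object and an absolute address, e.g.
//   val (addr, len) = getBytesBufferOffsetAndLength(i)
//   val firstByte = Platform.getByte(null, addr)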

/**
* Frees the underlying off-heap memory once the broadcast variable is garbage collected. For
* now, there is no more elegant way to release the allocation than finalize().
*/
override def finalize(): Unit = {
try {
if (longArray != null) {
log.debug(s"BytesArrayInOffheap finalize $arraySize")
freeArray(longArray)
longArray = null
}
} finally {
super.finalize()
}
}

/**
* Used to debug input/output bytes by logging an MD5 digest.
*
* @param bytesBuffer the bytes to hash
* @return the MD5 digest, or an empty array if hashing fails
*/
private def calculateMD5(bytesBuffer: Array[Byte]): Array[Byte] = try {
val md = MessageDigest.getInstance("MD5")
md.digest(bytesBuffer)
} catch {
case e: Throwable =>
log.warn("error when calculateMD5", e)
new Array[Byte](0)
}
}
// scalastyle:on no.finalize
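For illustration, a minimal usage sketch of UnsafeBytesBufferArray, assuming a TaskMemoryManager tmm is in scope (for example from the running task, or from a TaskResources.runUnsafe block on the driver as in the hunks above); the values are made up:

// Hypothetical usage (not part of the PR): store two buffers, read one back.
// `tmm` is an in-scope TaskMemoryManager (assumed available; see lead-in).
val buffers = Array(Array[Byte](1, 2, 3), Array[Byte](4, 5))
val arr = UnsafeBytesBufferArray(
  buffers.length,                   // arraySize
  buffers.map(_.length),            // bytesBufferLengths
  buffers.map(_.length.toLong).sum, // totalBytes
  tmm)
buffers.zipWithIndex.foreach { case (b, i) => arr.putBytesBuffer(i, b) }
assert(arr.getBytesBuffer(1).sameElements(Array[Byte](4, 5)))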