Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-6920][CORE][VL] New APIs and refactors to allow different backends / components to be registered and used #8143

Merged
merged 29 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package org.apache.gluten.vectorized;

import org.apache.gluten.GlutenConfig;
import org.apache.gluten.backend.Backend;
import org.apache.gluten.backendsapi.BackendsApiManager;
import org.apache.gluten.execution.ColumnarNativeIterator;
import org.apache.gluten.memory.CHThreadGroup;
Expand All @@ -35,11 +34,12 @@ private CHNativeExpressionEvaluator() {}
// Used to initialize the native computing.
public static void initNative(scala.collection.Map<String, String> conf) {
Map<String, String> nativeConfMap =
GlutenConfig.getNativeBackendConf(Backend.get().name(), conf);
GlutenConfig.getNativeBackendConf(BackendsApiManager.getBackendName(), conf);

// Get the customer config from SparkConf for each backend
BackendsApiManager.getTransformerApiInstance()
.postProcessNativeConfig(nativeConfMap, GlutenConfig.prefixOf(Backend.get().name()));
.postProcessNativeConfig(
nativeConfMap, GlutenConfig.prefixOf(BackendsApiManager.getBackendName()));

nativeInitNative(ConfigUtil.serialize(nativeConfMap));
}
Expand All @@ -54,7 +54,8 @@ public static boolean doValidate(byte[] subPlan) {
}

/**
 * Returns the native backend configuration derived from the current session's SQLConf
 * entries, keyed by the name of the currently registered backend.
 */
private static Map<String, String> getNativeBackendConf() {
  return GlutenConfig.getNativeBackendConf(
      BackendsApiManager.getBackendName(), SQLConf.get().getAllConfs());
}

// Used by WholeStageTransform to create the native computing pipeline and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.gluten.backendsapi.clickhouse

import org.apache.gluten.GlutenBuildInfo._
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backend.Backend
import org.apache.gluten.backend.Component.BuildInfo
import org.apache.gluten.backendsapi._
import org.apache.gluten.columnarbatch.CHBatch
import org.apache.gluten.execution.WriteFilesExecTransformer
Expand Down Expand Up @@ -49,8 +49,8 @@ import scala.util.control.Breaks.{break, breakable}
class CHBackend extends SubstraitBackend {
import CHBackend._
override def name(): String = CHConf.BACKEND_NAME
/**
 * Build information reported for the ClickHouse backend: the branch and commit come from
 * GlutenBuildInfo; the last field is the literal "UNKNOWN" (presumably the revision time,
 * which is not recorded for this backend -- confirm against BuildInfo's definition).
 */
override def buildInfo(): BuildInfo =
  BuildInfo("ClickHouse", CH_BRANCH, CH_COMMIT, "UNKNOWN")
override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
override def iteratorApi(): IteratorApi = new CHIteratorApi
override def sparkPlanExecApi(): SparkPlanExecApi = new CHSparkPlanExecApi
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.gluten.columnarbatch;

import org.apache.gluten.backendsapi.BackendsApiManager;
import org.apache.gluten.runtime.Runtime;
import org.apache.gluten.runtime.Runtimes;

Expand Down Expand Up @@ -59,8 +60,10 @@ public static ColumnarBatch toVeloxBatch(ColumnarBatch input) {
return input;
}
Preconditions.checkArgument(!isVeloxBatch(input));
final Runtime runtime = Runtimes.contextInstance("VeloxColumnarBatches#toVeloxBatch");
final long handle = ColumnarBatches.getNativeHandle(input);
final Runtime runtime =
Runtimes.contextInstance(
BackendsApiManager.getBackendName(), "VeloxColumnarBatches#toVeloxBatch");
final long handle = ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName(), input);
final long outHandle = VeloxColumnarBatchJniWrapper.create(runtime).from(handle);
final ColumnarBatch output = ColumnarBatches.create(outHandle);

Expand Down Expand Up @@ -88,9 +91,13 @@ public static ColumnarBatch toVeloxBatch(ColumnarBatch input) {
* Otherwise {@link UnsupportedOperationException} will be thrown.
*/
public static ColumnarBatch compose(ColumnarBatch... batches) {
  // Obtain the context runtime of the current backend for this call site.
  final Runtime runtime =
      Runtimes.contextInstance(
          BackendsApiManager.getBackendName(), "VeloxColumnarBatches#compose");
  // Resolve every input batch to its native handle, in argument order.
  final long[] handles =
      Arrays.stream(batches)
          .mapToLong(b -> ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName(), b))
          .toArray();
  // The native side composes the handles and returns the handle of the resulting batch.
  final long handle = VeloxColumnarBatchJniWrapper.create(runtime).compose(handles);
  return ColumnarBatches.create(handle);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.gluten.metrics;

import org.apache.gluten.backendsapi.BackendsApiManager;
import org.apache.gluten.runtime.Runtime;
import org.apache.gluten.runtime.RuntimeAware;
import org.apache.gluten.runtime.Runtimes;
Expand All @@ -29,7 +30,8 @@ private IteratorMetricsJniWrapper(Runtime runtime) {
}

/**
 * Creates a wrapper bound to the current backend's context runtime, requested under the
 * name "IteratorMetrics".
 */
public static IteratorMetricsJniWrapper create() {
  final Runtime runtime =
      Runtimes.contextInstance(BackendsApiManager.getBackendName(), "IteratorMetrics");
  return new IteratorMetricsJniWrapper(runtime);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.gluten.utils;

import org.apache.gluten.backendsapi.BackendsApiManager;
import org.apache.gluten.runtime.Runtime;
import org.apache.gluten.runtime.Runtimes;
import org.apache.gluten.vectorized.ColumnarBatchInIterator;
Expand All @@ -28,10 +29,14 @@
public final class VeloxBatchResizer {
public static ColumnarBatchOutIterator create(
int minOutputBatchSize, int maxOutputBatchSize, Iterator<ColumnarBatch> in) {
final Runtime runtime = Runtimes.contextInstance("VeloxBatchResizer");
final Runtime runtime =
Runtimes.contextInstance(BackendsApiManager.getBackendName(), "VeloxBatchResizer");
long outHandle =
VeloxBatchResizerJniWrapper.create(runtime)
.create(minOutputBatchSize, maxOutputBatchSize, new ColumnarBatchInIterator(in));
.create(
minOutputBatchSize,
maxOutputBatchSize,
new ColumnarBatchInIterator(BackendsApiManager.getBackendName(), in));
return new ColumnarBatchOutIterator(runtime, outHandle);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.gluten.utils;

import org.apache.gluten.backendsapi.BackendsApiManager;
import org.apache.gluten.runtime.Runtimes;

import org.apache.commons.io.IOUtils;
Expand All @@ -30,7 +31,8 @@

public class VeloxBloomFilter extends BloomFilter {
private final VeloxBloomFilterJniWrapper jni =
VeloxBloomFilterJniWrapper.create(Runtimes.contextInstance("VeloxBloomFilter"));
VeloxBloomFilterJniWrapper.create(
Runtimes.contextInstance(BackendsApiManager.getBackendName(), "VeloxBloomFilter"));
private final long handle;

private VeloxBloomFilter(byte[] data) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.gluten.backendsapi.velox

import org.apache.gluten.GlutenBuildInfo._
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backend.Backend
import org.apache.gluten.backend.Component.BuildInfo
import org.apache.gluten.backendsapi._
import org.apache.gluten.columnarbatch.VeloxBatch
import org.apache.gluten.exception.GlutenNotSupportException
Expand Down Expand Up @@ -52,9 +52,10 @@ import scala.util.control.Breaks.breakable

class VeloxBackend extends SubstraitBackend {
import VeloxBackend._

override def name(): String = VeloxBackend.BACKEND_NAME
/**
 * Build information reported for the Velox backend; branch, revision and revision time all
 * come from GlutenBuildInfo.
 */
override def buildInfo(): BuildInfo =
  BuildInfo("Velox", VELOX_BRANCH, VELOX_REVISION, VELOX_REVISION_TIME)
override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
override def iteratorApi(): IteratorApi = new VeloxIteratorApi
override def sparkPlanExecApi(): SparkPlanExecApi = new VeloxSparkPlanExecApi
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,9 @@ class VeloxIteratorApi extends IteratorApi with Logging {

val columnarNativeIterators =
new JArrayList[ColumnarBatchInIterator](inputIterators.map {
iter => new ColumnarBatchInIterator(iter.asJava)
iter => new ColumnarBatchInIterator(BackendsApiManager.getBackendName, iter.asJava)
}.asJava)
val transKernel = NativePlanEvaluator.create()
val transKernel = NativePlanEvaluator.create(BackendsApiManager.getBackendName)

val splitInfoByteArray = inputPartition
.asInstanceOf[GlutenPartition]
Expand Down Expand Up @@ -235,10 +235,10 @@ class VeloxIteratorApi extends IteratorApi with Logging {

ExecutorManager.tryTaskSet(numaBindingInfo)

val transKernel = NativePlanEvaluator.create()
val transKernel = NativePlanEvaluator.create(BackendsApiManager.getBackendName)
val columnarNativeIterator =
new JArrayList[ColumnarBatchInIterator](inputIterators.map {
iter => new ColumnarBatchInIterator(iter.asJava)
iter => new ColumnarBatchInIterator(BackendsApiManager.getBackendName, iter.asJava)
}.asJava)
val spillDirPath = SparkDirectoryUtil
.get()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class VeloxListenerApi extends ListenerApi with Logging {
if (isDriver && !inLocalMode(conf)) {
parsed += (GlutenConfig.COLUMNAR_VELOX_CACHE_ENABLED.key -> "false")
}
NativeBackendInitializer.initializeBackend(parsed)
NativeBackendInitializer.forBackend(VeloxBackend.BACKEND_NAME).initialize(parsed)

// Inject backend-specific implementations to override spark classes.
GlutenFormatFactory.register(new VeloxParquetWriterInjects)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.apache.gluten.backendsapi.velox

import org.apache.gluten.backendsapi.TransformerApi
import org.apache.gluten.backendsapi.{BackendsApiManager, TransformerApi}
import org.apache.gluten.execution.WriteFilesExecTransformer
import org.apache.gluten.expression.ConverterUtils
import org.apache.gluten.runtime.Runtimes
Expand Down Expand Up @@ -87,7 +87,9 @@ class VeloxTransformerApi extends TransformerApi with Logging {
/**
 * Asks the native plan evaluator to render the given serialized Substrait plan as a string.
 *
 * @param substraitPlan serialized plan bytes, forwarded unchanged to the native call
 * @param details flag forwarded to the native call (presumably toggles a more verbose
 *                rendering -- confirm against the native side)
 * @return the string produced by `nativePlanString`
 */
override def getNativePlanString(substraitPlan: Array[Byte], details: Boolean): String = {
  TaskResources.runUnsafe {
    // The JNI wrapper is bound to the current backend's context runtime.
    val jniWrapper = PlanEvaluatorJniWrapper.create(
      Runtimes.contextInstance(
        BackendsApiManager.getBackendName,
        "VeloxTransformerApi#getNativePlanString"))
    jniWrapper.nativePlanString(substraitPlan, details)
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.apache.gluten.backendsapi.velox

import org.apache.gluten.backendsapi.ValidatorApi
import org.apache.gluten.backendsapi.{BackendsApiManager, ValidatorApi}
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.substrait.plan.PlanNode
import org.apache.gluten.validate.NativePlanValidationInfo
Expand All @@ -38,7 +38,7 @@ class VeloxValidatorApi extends ValidatorApi {

/**
 * Validates `plan` natively and converts the native outcome into a [[ValidationResult]].
 *
 * The plan is serialized to protobuf bytes and handed to a NativePlanEvaluator created for
 * the currently registered backend; the whole call runs inside `TaskResources.runUnsafe`.
 */
override def doNativeValidateWithFailureReason(plan: PlanNode): ValidationResult = {
  TaskResources.runUnsafe {
    val validator = NativePlanEvaluator.create(BackendsApiManager.getBackendName)
    asValidationResult(validator.doNativeValidateWithFailureReason(plan.toProtobuf.toByteArray))
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.gluten.datasource

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.runtime.Runtimes
import org.apache.gluten.utils.ArrowAbiUtil
Expand All @@ -38,7 +39,7 @@ object VeloxDataSourceUtil {

def readSchema(file: FileStatus): Option[StructType] = {
val allocator = ArrowBufferAllocators.contextInstance()
val runtime = Runtimes.contextInstance("VeloxWriter")
val runtime = Runtimes.contextInstance(BackendsApiManager.getBackendName, "VeloxWriter")
val datasourceJniWrapper = VeloxDataSourceJniWrapper.create(runtime)
val dsHandle =
datasourceJniWrapper.init(file.getPath.toString, -1, new util.HashMap[String, String]())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
Iterator.empty
} else {
val start = System.currentTimeMillis()
val childData = ColumnarBatches.select(batch, projectIndexInChild.toArray)
val childData = ColumnarBatches
.select(BackendsApiManager.getBackendName, batch, projectIndexInChild.toArray)
val projectedBatch = getProjectedBatchArrow(childData, c2a, a2c)

val batchIterator = projectedBatch.map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.gluten.execution

import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.gluten.iterator.Iterators
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
Expand Down Expand Up @@ -121,7 +122,7 @@ object RowToVeloxColumnarExec {

val arrowSchema =
SparkArrowUtil.toArrowSchema(schema, SQLConf.get.sessionLocalTimeZone)
val runtime = Runtimes.contextInstance("RowToColumnar")
val runtime = Runtimes.contextInstance(BackendsApiManager.getBackendName, "RowToColumnar")
val jniWrapper = NativeRowToColumnarJniWrapper.create(runtime)
val arrowAllocator = ArrowBufferAllocators.contextInstance()
val cSchema = ArrowSchema.allocateNew(arrowAllocator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.columnarbatch.{ColumnarBatches, VeloxColumnarBatches}
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.extension.ValidationResult
Expand Down Expand Up @@ -122,7 +123,7 @@ object VeloxColumnarToRowExec {
return Iterator.empty
}

val runtime = Runtimes.contextInstance("ColumnarToRow")
val runtime = Runtimes.contextInstance(BackendsApiManager.getBackendName, "ColumnarToRow")
// TODO: Pass the jniWrapper, arrowSchema and serializeSchema method by broadcast.
val jniWrapper = NativeColumnarToRowJniWrapper.create(runtime)
val c2rId = jniWrapper.nativeColumnarToRowInit()
Expand Down Expand Up @@ -156,7 +157,7 @@ object VeloxColumnarToRowExec {
val cols = batch.numCols()
val rows = batch.numRows()
val beforeConvert = System.currentTimeMillis()
val batchHandle = ColumnarBatches.getNativeHandle(batch)
val batchHandle = ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, batch)
var info =
jniWrapper.nativeColumnarToRowConvert(c2rId, batchHandle, 0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.gluten.vectorized

import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.iterator.ClosableIterator
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.runtime.Runtimes
Expand Down Expand Up @@ -99,7 +100,7 @@ private class ColumnarBatchSerializerInstance(
GlutenConfig.getConf.columnarShuffleCodecBackend.orNull
val batchSize = GlutenConfig.getConf.maxBatchSize
val bufferSize = GlutenConfig.getConf.columnarShuffleReaderBufferSize
val runtime = Runtimes.contextInstance("ShuffleReader")
val runtime = Runtimes.contextInstance(BackendsApiManager.getBackendName, "ShuffleReader")
val jniWrapper = ShuffleReaderJniWrapper.create(runtime)
val shuffleReaderHandle = jniWrapper.make(
cSchema.memoryAddress(),
Expand Down Expand Up @@ -135,7 +136,8 @@ private class ColumnarBatchSerializerInstance(
extends DeserializationStream
with TaskResource {
private val byteIn: JniByteInputStream = JniByteInputStreams.create(in)
private val runtime = Runtimes.contextInstance("ShuffleReader")
private val runtime =
Runtimes.contextInstance(BackendsApiManager.getBackendName, "ShuffleReader")
private val wrappedOut: ClosableIterator = new ColumnarBatchOutIterator(
runtime,
ShuffleReaderJniWrapper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.spark.shuffle

import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.gluten.memory.memtarget.{MemoryTarget, Spiller, Spillers}
import org.apache.gluten.runtime.Runtimes
Expand Down Expand Up @@ -99,7 +100,7 @@ class ColumnarShuffleWriter[K, V](

private val reallocThreshold = GlutenConfig.getConf.columnarShuffleReallocThreshold

private val runtime = Runtimes.contextInstance("ShuffleWriter")
private val runtime = Runtimes.contextInstance(BackendsApiManager.getBackendName, "ShuffleWriter")

private val jniWrapper = ShuffleWriterJniWrapper.create(runtime)

Expand Down Expand Up @@ -135,7 +136,7 @@ class ColumnarShuffleWriter[K, V](
logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} cols")
} else {
val rows = cb.numRows()
val handle = ColumnarBatches.getNativeHandle(cb)
val handle = ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, cb)
if (nativeShuffleWriter == -1L) {
nativeShuffleWriter = jniWrapper.make(
dep.nativePartitioning.getShortName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.gluten.runtime.Runtimes
import org.apache.gluten.sql.shims.SparkShimLoader
Expand Down Expand Up @@ -152,11 +153,14 @@ object BroadcastUtils {
if (filtered.isEmpty) {
return ColumnarBatchSerializeResult.EMPTY
}
val handleArray = filtered.map(ColumnarBatches.getNativeHandle)
val handleArray =
filtered.map(b => ColumnarBatches.getNativeHandle(BackendsApiManager.getBackendName, b))
val serializeResult =
try {
ColumnarBatchSerializerJniWrapper
.create(Runtimes.contextInstance("BroadcastUtils#serializeStream"))
.create(
Runtimes
.contextInstance(BackendsApiManager.getBackendName, "BroadcastUtils#serializeStream"))
.serialize(handleArray)
} finally {
filtered.foreach(ColumnarBatches.release)
Expand Down
Loading
Loading