refactor CHDatasourceJniWrapper
baibaichen committed Oct 8, 2024
1 parent 6aa83c8 commit def1e47
Showing 13 changed files with 194 additions and 123 deletions.
22 changes: 22 additions & 0 deletions backends-clickhouse/pom.xml
@@ -253,6 +253,28 @@
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
<testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
<plugins>
<!-- compile proto buffer files using copied protoc binary -->
<plugin>
<groupId>org.xolstice.maven.plugins</groupId>
<artifactId>protobuf-maven-plugin</artifactId>
<executions>
<execution>
<id>compile-gluten-proto</id>
<phase>generate-sources</phase>
<goals>
<goal>compile</goal>
<goal>test-compile</goal>
</goals>
<configuration>
<protocArtifact>
com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}
</protocArtifact>
<protoSourceRoot>src/main/resources/org/apache/spark/sql/execution/datasources/v1</protoSourceRoot>
<clearOutputDirectory>false</clearOutputDirectory>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-resources-plugin</artifactId>
@@ -42,7 +42,7 @@ import org.apache.hadoop.fs.{FileAlreadyExistsException, Path}
import org.apache.hadoop.mapreduce.{TaskAttemptContext, TaskAttemptID, TaskID, TaskType}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl

import java.util.{Date, UUID}
import java.util.Date
import scala.collection.mutable.ArrayBuffer

object OptimizeTableCommandOverwrites extends Logging {
@@ -95,8 +95,6 @@ object OptimizeTableCommandOverwrites extends Logging {
try {
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {

val uuid = UUID.randomUUID.toString

val planWithSplitInfo = CHMergeTreeWriterInjects.genMergeTreeWriteRel(
description.path,
description.database,
@@ -115,12 +113,9 @@
description.tableSchema.toAttributes
)

val datasourceJniWrapper = new CHDatasourceJniWrapper()
val returnedMetrics =
datasourceJniWrapper.nativeMergeMTParts(
CHDatasourceJniWrapper.nativeMergeMTParts(
planWithSplitInfo.splitInfo,
uuid,
taskId.getId.toString,
description.partitionDir.getOrElse(""),
description.bucketDir.getOrElse("")
)
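The change in this OptimizeTableCommandOverwrites copy (and in the two near-identical shim copies that follow) is that the locally generated UUID and the task id are no longer passed to the native merge, and the call goes through a static method rather than a wrapper instance. Restated as a minimal Scala sketch of the new call shape shown above:

    // After the refactor: static call, no uuid/taskId arguments.
    val returnedMetrics = CHDatasourceJniWrapper.nativeMergeMTParts(
      planWithSplitInfo.splitInfo,
      description.partitionDir.getOrElse(""),
      description.bucketDir.getOrElse(""))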
@@ -42,7 +42,7 @@ import org.apache.hadoop.fs.{FileAlreadyExistsException, Path}
import org.apache.hadoop.mapreduce.{TaskAttemptContext, TaskAttemptID, TaskID, TaskType}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl

import java.util.{Date, UUID}
import java.util.Date
import scala.collection.mutable.ArrayBuffer

object OptimizeTableCommandOverwrites extends Logging {
@@ -95,8 +95,6 @@ object OptimizeTableCommandOverwrites extends Logging {
try {
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {

val uuid = UUID.randomUUID.toString

val planWithSplitInfo = CHMergeTreeWriterInjects.genMergeTreeWriteRel(
description.path,
description.database,
@@ -115,12 +113,9 @@
description.tableSchema.toAttributes
)

val datasourceJniWrapper = new CHDatasourceJniWrapper()
val returnedMetrics =
datasourceJniWrapper.nativeMergeMTParts(
CHDatasourceJniWrapper.nativeMergeMTParts(
planWithSplitInfo.splitInfo,
uuid,
taskId.getId.toString,
description.partitionDir.getOrElse(""),
description.bucketDir.getOrElse("")
)
@@ -169,7 +164,7 @@
bucketNum: String,
bin: Seq[AddFile],
maxFileSize: Long): Seq[FileAction] = {
val tableV2 = ClickHouseTableV2.getTable(txn.deltaLog);
val tableV2 = ClickHouseTableV2.getTable(txn.deltaLog)

val sparkSession = SparkSession.getActiveSession.get

@@ -44,7 +44,7 @@ import org.apache.hadoop.fs.{FileAlreadyExistsException, Path}
import org.apache.hadoop.mapreduce.{TaskAttemptContext, TaskAttemptID, TaskID, TaskType}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl

import java.util.{Date, UUID}
import java.util.Date
import scala.collection.mutable.ArrayBuffer

object OptimizeTableCommandOverwrites extends Logging {
@@ -97,8 +97,6 @@ object OptimizeTableCommandOverwrites extends Logging {
try {
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {

val uuid = UUID.randomUUID.toString

val planWithSplitInfo = CHMergeTreeWriterInjects.genMergeTreeWriteRel(
description.path,
description.database,
@@ -117,12 +115,9 @@
DataTypeUtils.toAttributes(description.tableSchema)
)

val datasourceJniWrapper = new CHDatasourceJniWrapper()
val returnedMetrics =
datasourceJniWrapper.nativeMergeMTParts(
CHDatasourceJniWrapper.nativeMergeMTParts(
planWithSplitInfo.splitInfo,
uuid,
taskId.getId.toString,
description.partitionDir.getOrElse(""),
description.bucketDir.getOrElse("")
)
@@ -16,26 +16,42 @@
*/
package org.apache.spark.sql.execution.datasources;

import io.substrait.proto.WriteRel;

public class CHDatasourceJniWrapper {

public native void write(long instanceId, long blockAddress);
private final long instance;

public CHDatasourceJniWrapper(String filePath, WriteRel write) {
this.instance = createFilerWriter(filePath, write.toByteArray());
}

public CHDatasourceJniWrapper(
byte[] splitInfo, String taskId, String partition_dir, String bucket_dir, byte[] confArray) {
this.instance = createMergeTreeWriter(splitInfo, taskId, partition_dir, bucket_dir, confArray);
}

public void write(long blockAddress) {
write(instance, blockAddress);
}

public native String close(long instanceId);
public String close() {
return close(instance);
}

private native void write(long instanceId, long blockAddress);

private native String close(long instanceId);

/// FileWriter
public native long createFilerWriter(String filePath, byte[] preferredSchema, String formatHint);
private native long createFilerWriter(String filePath, byte[] writeRelBytes);

/// MergeTreeWriter
public native long createMergeTreeWriter(
byte[] splitInfo,
String uuid,
String taskId,
String partition_dir,
String bucket_dir,
byte[] confArray);

public native String nativeMergeMTParts(
byte[] splitInfo, String uuid, String taskId, String partition_dir, String bucket_dir);
private native long createMergeTreeWriter(
byte[] splitInfo, String taskId, String partition_dir, String bucket_dir, byte[] confArray);

public static native String nativeMergeMTParts(
byte[] splitInfo, String partition_dir, String bucket_dir);

public static native String filterRangesOnDriver(byte[] plan, byte[] read);
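A rough Scala usage sketch of the refactored wrapper, based only on the signatures above: the native instance handle is now private state created in the constructor, callers pick the constructor for their writer kind, and merging MergeTree parts is a static call. Names such as outputPath, writeRel, columnVector (a CHColumnVector), splitInfo, taskId, partitionDir, bucketDir and nativeConf are assumed to be in scope.

    // Plain file writer: pass the output path and the Substrait WriteRel.
    val fileWriter = new CHDatasourceJniWrapper(outputPath, writeRel)
    fileWriter.write(columnVector.getBlockAddress) // push one native block
    val metrics = fileWriter.close()               // returns native write metrics as a string

    // MergeTree writer: split info plus task/partition/bucket context and serialized settings.
    val mergeTreeWriter = new CHDatasourceJniWrapper(
      splitInfo, taskId, partitionDir, bucketDir, ConfigUtil.serialize(nativeConf))

    // Merging existing MergeTree parts no longer needs a wrapper instance.
    val mergeMetrics =
      CHDatasourceJniWrapper.nativeMergeMTParts(splitInfo, partitionDir, bucketDir)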

@@ -0,0 +1,35 @@
// SPDX-License-Identifier: Apache-2.0
syntax = "proto3";

package local_engine;

option java_package = "org.apache.spark.sql.execution.datasources.v1";
option java_multiple_files = true;

message Write {
message Common {
string format = 1;
}
message ParquetWrite{}
message OrcWrite{}
message MergeTreeWrite{
string database = 1;
string table = 2;
string snapshot_id = 3;
string order_by_key = 4;
string low_card_key = 5;
string minmax_index_key = 6;
string bf_index_key = 7;
string set_index_key = 8;
string primary_key = 9;
string relative_path = 10;
string absolute_path = 11;
}

Common common = 1;
oneof file_format {
ParquetWrite parquet = 2;
OrcWrite orc = 3;
MergeTreeWrite mergetree = 4;
}
}
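Because of java_multiple_files and the java_package option above, protoc (wired up by the new execution in backends-clickhouse/pom.xml) generates a Write class with nested Common, ParquetWrite, OrcWrite and MergeTreeWrite types under org.apache.spark.sql.execution.datasources.v1. A hedged sketch of building the mergetree variant from Scala, assuming standard protobuf-java builder naming; the field values are placeholders, not what the commit actually passes:

    import org.apache.spark.sql.execution.datasources.v1.Write

    val write: Write = Write
      .newBuilder()
      .setCommon(Write.Common.newBuilder().setFormat("mergetree"))
      .setMergetree(
        Write.MergeTreeWrite
          .newBuilder()
          .setDatabase("default")            // placeholder database
          .setTable("some_table")            // placeholder table
          .setRelativePath("parts/rel/path")) // placeholder relative_path
      .build()

    // The message is then carried to the native side by packing it into the WriteRel's
    // AdvancedExtension, as done in CHFormatWriterInjects.createWriteRel below.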
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.v1
import org.apache.gluten.execution.datasource.GlutenRowSplitter
import org.apache.gluten.expression.ConverterUtils
import org.apache.gluten.memory.CHThreadGroup
import org.apache.gluten.utils.SubstraitUtil.createNameStructBuilder
import org.apache.gluten.vectorized.CHColumnVector

import org.apache.spark.sql.SparkSession
@@ -27,32 +28,48 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.orc.OrcUtils
import org.apache.spark.sql.types.StructType

import io.substrait.proto.{NamedStruct, Type}
import com.google.protobuf.Any
import io.substrait.proto
import io.substrait.proto.{AdvancedExtension, NamedObjectWrite}
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.mapreduce.TaskAttemptContext

import java.{util => ju}

import scala.collection.JavaConverters.seqAsJavaListConverter

trait CHFormatWriterInjects extends GlutenFormatWriterInjectsBase {

private def createWriteRel(dataSchema: StructType): proto.WriteRel = {
proto.WriteRel
.newBuilder()
.setTableSchema(
createNameStructBuilder(
ConverterUtils.collectAttributeTypeNodes(dataSchema),
dataSchema.fieldNames.toSeq.asJava,
ju.Collections.emptyList()).build())
.setNamedTable(
NamedObjectWrite.newBuilder
.setAdvancedExtension(
AdvancedExtension
.newBuilder()
.setOptimization(Any.pack(createNativeWrite()))
.build())
.build())
.build()
}

def createNativeWrite(): Write

override def createOutputWriter(
path: String,
outputPath: String,
dataSchema: StructType,
context: TaskAttemptContext,
nativeConf: java.util.Map[String, String]): OutputWriter = {
val originPath = path
val datasourceJniWrapper = new CHDatasourceJniWrapper()
nativeConf: ju.Map[String, String]): OutputWriter = {
CHThreadGroup.registerNewThreadGroup()

val namedStructBuilder = NamedStruct.newBuilder
val structBuilder = Type.Struct.newBuilder
for (field <- dataSchema.fields) {
namedStructBuilder.addNames(field.name)
structBuilder.addTypes(ConverterUtils.getTypeNode(field.dataType, field.nullable).toProtobuf)
}
namedStructBuilder.setStruct(structBuilder.build)
val namedStruct = namedStructBuilder.build

val instance =
datasourceJniWrapper.createFilerWriter(path, namedStruct.toByteArray, formatName)
val datasourceJniWrapper =
new CHDatasourceJniWrapper(outputPath, createWriteRel(dataSchema))

new OutputWriter {
override def write(row: InternalRow): Unit = {
@@ -61,17 +78,17 @@ trait CHFormatWriterInjects extends GlutenFormatWriterInjectsBase {

if (nextBatch.numRows > 0) {
val col = nextBatch.column(0).asInstanceOf[CHColumnVector]
datasourceJniWrapper.write(instance, col.getBlockAddress)
datasourceJniWrapper.write(col.getBlockAddress)
} // else just ignore this empty block
}

override def close(): Unit = {
datasourceJniWrapper.close(instance)
datasourceJniWrapper.close()
}

// Do NOT add override keyword for compatibility on spark 3.1.
def path(): String = {
originPath
outputPath
}
}
}
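createWriteRel attaches the format-specific Write message to the Substrait WriteRel through Any.pack on the named table's AdvancedExtension, so a concrete writer only has to supply createNativeWrite() plus the usual formatName and nativeConf. A hypothetical Parquet-flavoured subclass might look roughly like the sketch below; the class name is invented, not one of the commit's real writer-injects classes, and it assumes these three members are the only abstract ones a concrete format must provide.

    import java.{util => ju}
    import scala.collection.JavaConverters._

    import org.apache.spark.sql.execution.datasources.v1.{CHFormatWriterInjects, Write}

    // Illustrative only.
    class ExampleParquetWriterInjects extends CHFormatWriterInjects {
      override val formatName: String = "parquet"

      override def createNativeWrite(): Write = Write
        .newBuilder()
        .setCommon(Write.Common.newBuilder().setFormat(formatName))
        .setParquet(Write.ParquetWrite.newBuilder())
        .build()

      // Same signature as in CHMergeTreeWriterInjects below; passes Spark options through as-is.
      override def nativeConf(
          options: Map[String, String],
          compressionCodec: String): ju.Map[String, String] = options.asJava
    }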
@@ -35,7 +35,7 @@ import com.google.protobuf.{Any, StringValue}
import io.substrait.proto.NamedStruct
import org.apache.hadoop.mapreduce.TaskAttemptContext

import java.util.{Map => JMap, UUID}
import java.{util => ju}

import scala.collection.JavaConverters._

@@ -45,38 +45,40 @@ class CHMergeTreeWriterInjects extends CHFormatWriterInjects {

override def nativeConf(
options: Map[String, String],
compressionCodec: String): JMap[String, String] = {
compressionCodec: String): ju.Map[String, String] = {
options.asJava
}

override def createNativeWrite(): Write = {
throw new UnsupportedOperationException(
"createNativeWrite is not supported in CHMergeTreeWriterInjects")
}

override def createOutputWriter(
path: String,
dataSchema: StructType,
context: TaskAttemptContext,
nativeConf: JMap[String, String]): OutputWriter = null
nativeConf: ju.Map[String, String]): OutputWriter = null

override val formatName: String = "mergetree"

def createOutputWriter(
path: String,
dataSchema: StructType,
context: TaskAttemptContext,
nativeConf: JMap[String, String],
nativeConf: ju.Map[String, String],
database: String,
tableName: String,
splitInfo: Array[Byte]): OutputWriter = {
val datasourceJniWrapper = new CHDatasourceJniWrapper()
val instance =
datasourceJniWrapper.createMergeTreeWriter(
splitInfo,
UUID.randomUUID.toString,
context.getTaskAttemptID.getTaskID.getId.toString,
context.getConfiguration.get("mapreduce.task.gluten.mergetree.partition.dir"),
context.getConfiguration.get("mapreduce.task.gluten.mergetree.bucketid.str"),
ConfigUtil.serialize(nativeConf)
)

new MergeTreeOutputWriter(database, tableName, datasourceJniWrapper, instance, path)
val datasourceJniWrapper = new CHDatasourceJniWrapper(
splitInfo,
context.getTaskAttemptID.getTaskID.getId.toString,
context.getConfiguration.get("mapreduce.task.gluten.mergetree.partition.dir"),
context.getConfiguration.get("mapreduce.task.gluten.mergetree.bucketid.str"),
ConfigUtil.serialize(nativeConf)
)

new MergeTreeOutputWriter(datasourceJniWrapper, database, tableName, path)
}
}
