Implement CHMergeTreeWriterInjects::createNativeWrite
baibaichen committed Oct 9, 2024
1 parent 2c35134 commit ef03755
Showing 14 changed files with 114 additions and 89 deletions.
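A minimal sketch (not part of the commit) of how the pieces in the hunks below fit together after this change: the Scala writer builds a substrait WriteRel whose advanced extension carries the gluten Write message (now including storage_policy) and hands it to the new createMergeTreeWriter JNI signature, instead of passing a pre-serialized split-info byte array. Names are taken from the hunks below; partitionDir, bucketDir and nativeConf stand in for values the caller already has.

// Illustrative wiring only, assembled from the diff hunks below.
val writeRel: io.substrait.proto.WriteRel =
  createWriteRel(outputPath, dataSchema, context) // packs createNativeWrite(outputPath, context)

val datasourceJniWrapper = new CHDatasourceJniWrapper(
  context.getTaskAttemptID.getTaskID.getId.toString, // taskId
  partitionDir,
  bucketDir,
  writeRel, // replaces the old splitInfo / extension-table bytes
  ConfigUtil.serialize(nativeConf)
)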
@@ -27,8 +27,9 @@ public CHDatasourceJniWrapper(String filePath, WriteRel write) {
}

public CHDatasourceJniWrapper(
byte[] splitInfo, String taskId, String partition_dir, String bucket_dir, byte[] confArray) {
this.instance = createMergeTreeWriter(splitInfo, taskId, partition_dir, bucket_dir, confArray);
String taskId, String partition_dir, String bucket_dir, WriteRel write, byte[] confArray) {
this.instance =
createMergeTreeWriter(taskId, partition_dir, bucket_dir, write.toByteArray(), confArray);
}

public void write(long blockAddress) {
@@ -44,11 +45,11 @@ public String close() {
private native String close(long instanceId);

/// FileWriter
private native long createFilerWriter(String filePath, byte[] writeRelBytes);
private native long createFilerWriter(String filePath, byte[] writeRel);

/// MergeTreeWriter
private native long createMergeTreeWriter(
byte[] splitInfo, String taskId, String partition_dir, String bucket_dir, byte[] confArray);
String taskId, String partition_dir, String bucket_dir, byte[] writeRel, byte[] confArray);

public static native String nativeMergeMTParts(
byte[] splitInfo, String partition_dir, String bucket_dir);
@@ -24,6 +24,8 @@ message Write {
string primary_key = 9;
string relative_path = 10;
string absolute_path = 11;

string storage_policy = 12;
}

Common common = 1;
@@ -56,7 +56,7 @@ trait ClickHouseTableV2Base extends TablePropertiesReader {

def orderByKey(): String = orderByKeyOption match {
case Some(keys) => keys.map(normalizeColName).mkString(",")
case None => "tuple()"
case None => StorageMeta.DEFAULT_ORDER_BY_KEY
}

def lowCardKey(): String = MergeTreeDeltaUtil.columnsToStr(lowCardKeyOption)
@@ -28,7 +28,7 @@ import org.apache.spark.sql.delta._
import org.apache.spark.sql.execution.command.LeafRunnableCommand
import org.apache.spark.sql.execution.commands.GlutenCacheBase._
import org.apache.spark.sql.execution.datasources.clickhouse.{ClickhousePartSerializer, ExtensionTableBuilder}
import org.apache.spark.sql.execution.datasources.clickhouse.utils.MergeTreeDeltaUtil
import org.apache.spark.sql.execution.datasources.mergetree.StorageMeta
import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts
import org.apache.spark.sql.types.{BooleanType, StringType}

@@ -176,7 +176,7 @@ case class GlutenCHCacheDataCommand(
onePart.tablePath,
pathToCache.toString,
snapshot.metadata.configuration
.getOrElse("orderByKey", MergeTreeDeltaUtil.DEFAULT_ORDER_BY_KEY),
.getOrElse("orderByKey", StorageMeta.DEFAULT_ORDER_BY_KEY),
snapshot.metadata.configuration.getOrElse("lowCardKey", ""),
snapshot.metadata.configuration.getOrElse("minmaxIndexKey", ""),
snapshot.metadata.configuration.getOrElse("bloomfilterIndexKey", ""),
@@ -22,16 +22,12 @@ import org.apache.gluten.expression.ConverterUtils
import org.apache.spark.sql.execution.datasources.clickhouse.utils.MergeTreeDeltaUtil
import org.apache.spark.sql.execution.datasources.mergetree.StorageMeta
import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts
import org.apache.spark.sql.types.StructType

import com.fasterxml.jackson.databind.ObjectMapper
import io.substrait.proto.ReadRel

import java.net.URI
import java.util.{Map => jMap}

import scala.collection.JavaConverters._

case class ClickhousePartSerializer(
partList: Seq[String],
starts: Seq[Long],
@@ -81,37 +77,6 @@ object ClickhousePartSerializer {
}

object ClickhouseMetaSerializer {
def forWrite(
relativePath: String,
clickhouseTableConfigs: Map[String, String],
dataSchema: StructType): ReadRel.ExtensionTable = {

val orderByKey = clickhouseTableConfigs(StorageMeta.ORDER_BY_KEY)
val lowCardKey = clickhouseTableConfigs(StorageMeta.LOW_CARD_KEY)
val minmaxIndexKey = clickhouseTableConfigs(StorageMeta.MINMAX_INDEX_KEY)
val bfIndexKey = clickhouseTableConfigs(StorageMeta.BF_INDEX_KEY)
val setIndexKey = clickhouseTableConfigs(StorageMeta.SET_INDEX_KEY)
val primaryKey = clickhouseTableConfigs(StorageMeta.PRIMARY_KEY)

val result = apply(
clickhouseTableConfigs(StorageMeta.DB),
clickhouseTableConfigs(StorageMeta.TABLE),
clickhouseTableConfigs(StorageMeta.SNAPSHOT_ID),
relativePath,
"", // absolutePath
orderByKey,
lowCardKey,
minmaxIndexKey,
bfIndexKey,
setIndexKey,
primaryKey,
ClickhousePartSerializer.fromPartNames(Seq()),
ConverterUtils.convertNamedStructJson(dataSchema),
clickhouseTableConfigs.filter(_._1 == StorageMeta.POLICY).asJava
)
ExtensionTableNode.toProtobuf(result)

}
// scalastyle:off argcount
def apply1(
database: String,
@@ -193,29 +158,24 @@
.append(orderByKey)
.append("\n")

if (orderByKey.nonEmpty && !(orderByKey == "tuple()")) {
if (orderByKey.isEmpty || orderByKey == StorageMeta.DEFAULT_ORDER_BY_KEY) {
extensionTableStr.append("").append("\n")
} else {
extensionTableStr.append(primaryKey).append("\n")
}

extensionTableStr.append(lowCardKey).append("\n")
extensionTableStr.append(minmaxIndexKey).append("\n")
extensionTableStr.append(bfIndexKey).append("\n")
extensionTableStr.append(setIndexKey).append("\n")
extensionTableStr.append(normalizeRelativePath(relativePath)).append("\n")
extensionTableStr.append(StorageMeta.normalizeRelativePath(relativePath)).append("\n")
extensionTableStr.append(absolutePath).append("\n")
appendConfigs(extensionTableStr, clickhouseTableConfigs)
extensionTableStr.append(partSerializer())

extensionTableStr.toString()
}

private def normalizeRelativePath(relativePath: String): String = {
val table_uri = URI.create(relativePath)
if (table_uri.getPath.startsWith("/")) {
table_uri.getPath.substring(1)
} else table_uri.getPath
}

private def appendConfigs(
extensionTableStr: StringBuilder,
clickhouseTableConfigs: jMap[String, String]): Unit = {
@@ -18,9 +18,9 @@ package org.apache.spark.sql.execution.datasources.clickhouse.utils

import org.apache.gluten.expression.ConverterUtils.normalizeColName

object MergeTreeDeltaUtil {
import org.apache.spark.sql.execution.datasources.mergetree.StorageMeta.DEFAULT_ORDER_BY_KEY

val DEFAULT_ORDER_BY_KEY = "tuple()"
object MergeTreeDeltaUtil {

def genOrderByAndPrimaryKeyStr(
orderByKeyOption: Option[Seq[String]],
@@ -36,10 +36,7 @@
(orderByKey, primaryKey)
}

def columnsToStr(option: Option[Seq[String]]): String = option match {
case Some(keys) => keys.map(normalizeColName).mkString(",")
case None => ""
}
def columnsToStr(option: Option[Seq[String]]): String = option.map(columnsToStr).getOrElse("")

def columnsToStr(keys: Seq[String]): String = {
keys.map(normalizeColName).mkString(",")
@@ -24,12 +24,15 @@ import org.apache.spark.sql.execution.datasources.clickhouse.utils.MergeTreeDeltaUtil

import org.apache.hadoop.fs.Path

import java.net.URI

/** Reserved table property for MergeTree table. */
object StorageMeta {

// Storage properties
val DEFAULT_PATH_BASED_DATABASE: String = "clickhouse_db"
val DEFAULT_CREATE_TABLE_DATABASE: String = "default"
val DEFAULT_ORDER_BY_KEY = "tuple()"
val DB: String = "storage_db"
val TABLE: String = "storage_table"
val SNAPSHOT_ID: String = "storage_snapshot_id"
@@ -60,6 +63,13 @@
private def withMoreOptions(metadata: Metadata, newOptions: Seq[(String, String)]): Metadata = {
metadata.copy(configuration = metadata.configuration ++ newOptions)
}

def normalizeRelativePath(relativePath: String): String = {
val table_uri = URI.create(relativePath)
if (table_uri.getPath.startsWith("/")) {
table_uri.getPath.substring(1)
} else table_uri.getPath
}
}

trait WriteConfiguration {
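normalizeRelativePath moves from ClickhouseMetaSerializer into StorageMeta so that the MergeTree write path can reuse it. A self-contained sketch of its behaviour (only java.net.URI is needed; the sample paths are hypothetical):

import java.net.URI

// Strip any URI scheme/authority and a leading slash, leaving a relative path.
def normalizeRelativePath(relativePath: String): String = {
  val tableUri = URI.create(relativePath)
  if (tableUri.getPath.startsWith("/")) tableUri.getPath.substring(1)
  else tableUri.getPath
}

normalizeRelativePath("/warehouse/db/tbl")            // "warehouse/db/tbl"
normalizeRelativePath("hdfs://ns1/warehouse/db/tbl")  // "warehouse/db/tbl"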
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v1
import org.apache.gluten.execution.datasource.GlutenRowSplitter
import org.apache.gluten.expression.ConverterUtils
import org.apache.gluten.memory.CHThreadGroup
import org.apache.gluten.utils.SubstraitUtil.createNameStructBuilder
import org.apache.gluten.utils.SubstraitUtil
import org.apache.gluten.vectorized.CHColumnVector

import org.apache.spark.sql.SparkSession
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types.StructType

import com.google.protobuf.Any
import io.substrait.proto
import io.substrait.proto.{AdvancedExtension, NamedObjectWrite}
import io.substrait.proto.{AdvancedExtension, NamedObjectWrite, NamedStruct}
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.mapreduce.TaskAttemptContext

@@ -40,26 +40,35 @@ import scala.collection.JavaConverters.seqAsJavaListConverter

trait CHFormatWriterInjects extends GlutenFormatWriterInjectsBase {

private def createWriteRel(dataSchema: StructType): proto.WriteRel = {
// TODO: move to SubstraitUtil
private def toNameStruct(dataSchema: StructType): NamedStruct = {
SubstraitUtil
.createNameStructBuilder(
ConverterUtils.collectAttributeTypeNodes(dataSchema),
dataSchema.fieldNames.map(ConverterUtils.normalizeColName).toSeq.asJava,
ju.Collections.emptyList()
)
.build()
}
def createWriteRel(
outputPath: String,
dataSchema: StructType,
context: TaskAttemptContext): proto.WriteRel = {
proto.WriteRel
.newBuilder()
.setTableSchema(
createNameStructBuilder(
ConverterUtils.collectAttributeTypeNodes(dataSchema),
dataSchema.fieldNames.toSeq.asJava,
ju.Collections.emptyList()).build())
.setTableSchema(toNameStruct(dataSchema))
.setNamedTable(
NamedObjectWrite.newBuilder
.setAdvancedExtension(
AdvancedExtension
.newBuilder()
.setOptimization(Any.pack(createNativeWrite()))
.setOptimization(Any.pack(createNativeWrite(outputPath, context)))
.build())
.build())
.build()
}

def createNativeWrite(): Write
def createNativeWrite(outputPath: String, context: TaskAttemptContext): Write

override def createOutputWriter(
outputPath: String,
@@ -69,7 +78,7 @@ trait CHFormatWriterInjects extends GlutenFormatWriterInjectsBase {
CHThreadGroup.registerNewThreadGroup()

val datasourceJniWrapper =
new CHDatasourceJniWrapper(outputPath, createWriteRel(dataSchema))
new CHDatasourceJniWrapper(outputPath, createWriteRel(outputPath, dataSchema, context))

new OutputWriter {
override def write(row: InternalRow): Unit = {
@@ -62,31 +62,48 @@ class CHMergeTreeWriterInjects extends CHFormatWriterInjects {
options.asJava
}

override def createNativeWrite(): Write = {
throw new UnsupportedOperationException(
"createNativeWrite is not supported in CHMergeTreeWriterInjects")
override def createNativeWrite(outputPath: String, context: TaskAttemptContext): Write = {
val conf = HadoopConfReader(context.getConfiguration).writeConfiguration
Write
.newBuilder()
.setCommon(Write.Common.newBuilder().setFormat(formatName).build())
.setMergetree(
Write.MergeTreeWrite
.newBuilder()
.setDatabase(conf(StorageMeta.DB))
.setTable(conf(StorageMeta.TABLE))
.setSnapshotId(conf(StorageMeta.SNAPSHOT_ID))
.setOrderByKey(conf(StorageMeta.ORDER_BY_KEY))
.setLowCardKey(conf(StorageMeta.LOW_CARD_KEY))
.setMinmaxIndexKey(conf(StorageMeta.MINMAX_INDEX_KEY))
.setBfIndexKey(conf(StorageMeta.BF_INDEX_KEY))
.setSetIndexKey(conf(StorageMeta.SET_INDEX_KEY))
.setPrimaryKey(conf(StorageMeta.PRIMARY_KEY))
.setRelativePath(StorageMeta.normalizeRelativePath(outputPath))
.setAbsolutePath("")
.setStoragePolicy(conf(StorageMeta.POLICY))
.build())
.build()
}

override def createOutputWriter(
path: String,
outputPath: String,
dataSchema: StructType,
context: TaskAttemptContext,
nativeConf: ju.Map[String, String]): OutputWriter = {

val storage = HadoopConfReader(context.getConfiguration)
val database = storage.writeConfiguration(StorageMeta.DB)
val tableName = storage.writeConfiguration(StorageMeta.TABLE)
val extensionTable =
ClickhouseMetaSerializer.forWrite(path, storage.writeConfiguration, dataSchema)

val datasourceJniWrapper = new CHDatasourceJniWrapper(
extensionTable.toByteArray,
context.getTaskAttemptID.getTaskID.getId.toString,
context.getConfiguration.get("mapreduce.task.gluten.mergetree.partition.dir"),
context.getConfiguration.get("mapreduce.task.gluten.mergetree.bucketid.str"),
createWriteRel(outputPath, dataSchema, context),
ConfigUtil.serialize(nativeConf)
)
new MergeTreeOutputWriter(datasourceJniWrapper, database, tableName, path)
new MergeTreeOutputWriter(datasourceJniWrapper, database, tableName, outputPath)
}

override val formatName: String = "mergetree"
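The native side of the new createMergeTreeWriter entry point is not among the hunks rendered on this page; as a rough, hypothetical sketch of what it receives, the serialized WriteRel built above could be decoded like this (unpackMergeTreeWrite is an illustrative helper, not code from the commit; Write is the gluten message extended in the proto hunk above):

import com.google.protobuf.Any
import io.substrait.proto.WriteRel

// writeRelBytes is write.toByteArray() as passed through CHDatasourceJniWrapper.
def unpackMergeTreeWrite(writeRelBytes: Array[Byte]): Write.MergeTreeWrite = {
  val writeRel = WriteRel.parseFrom(writeRelBytes)
  // The Write message is Any-packed into the named_table advanced extension.
  val packed: Any = writeRel.getNamedTable.getAdvancedExtension.getOptimization
  packed.unpack(classOf[Write]).getMergetree // database, table, order_by_key, storage_policy, ...
}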
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.execution.datasources.v1

import org.apache.hadoop.mapreduce.TaskAttemptContext

import java.{util => ju}

class CHOrcWriterInjects extends CHFormatWriterInjects {
@@ -28,7 +30,7 @@ class CHOrcWriterInjects extends CHFormatWriterInjects {
ju.Collections.emptyMap()
}

override def createNativeWrite(): Write = Write
override def createNativeWrite(outputPath: String, context: TaskAttemptContext): Write = Write
.newBuilder()
.setCommon(Write.Common.newBuilder().setFormat(formatName).build())
.setOrc(Write.OrcWrite.newBuilder().build())
@@ -20,6 +20,8 @@ import org.apache.gluten.GlutenConfig

import org.apache.spark.sql.internal.SQLConf

import org.apache.hadoop.mapreduce.TaskAttemptContext

import java.{util => ju}

class CHParquetWriterInjects extends CHFormatWriterInjects {
@@ -41,7 +43,7 @@ class CHParquetWriterInjects extends CHFormatWriterInjects {
sparkOptions
}

override def createNativeWrite(): Write = Write
override def createNativeWrite(outputPath: String, context: TaskAttemptContext): Write = Write
.newBuilder()
.setCommon(Write.Common.newBuilder().setFormat(formatName).build())
.setParquet(Write.ParquetWrite.newBuilder().build())