Add NativeOutputWriter
baibaichen committed Oct 8, 2024
1 parent 3dceeb8 · commit 6aa83c8
Showing 26 changed files with 824 additions and 411 deletions.
@@ -118,7 +118,6 @@ object OptimizeTableCommandOverwrites extends Logging {
val datasourceJniWrapper = new CHDatasourceJniWrapper()
val returnedMetrics =
datasourceJniWrapper.nativeMergeMTParts(
planWithSplitInfo.plan,
planWithSplitInfo.splitInfo,
uuid,
taskId.getId.toString,
@@ -118,7 +118,6 @@ object OptimizeTableCommandOverwrites extends Logging {
val datasourceJniWrapper = new CHDatasourceJniWrapper()
val returnedMetrics =
datasourceJniWrapper.nativeMergeMTParts(
planWithSplitInfo.plan,
planWithSplitInfo.splitInfo,
uuid,
taskId.getId.toString,
@@ -120,7 +120,6 @@ object OptimizeTableCommandOverwrites extends Logging {
val datasourceJniWrapper = new CHDatasourceJniWrapper()
val returnedMetrics =
datasourceJniWrapper.nativeMergeMTParts(
planWithSplitInfo.plan,
planWithSplitInfo.splitInfo,
uuid,
taskId.getId.toString,
@@ -172,7 +171,7 @@ object OptimizeTableCommandOverwrites extends Logging {
bucketNum: String,
bin: Seq[AddFile],
maxFileSize: Long): Seq[FileAction] = {
val tableV2 = ClickHouseTableV2.getTable(txn.deltaLog);
val tableV2 = ClickHouseTableV2.getTable(txn.deltaLog)

val sparkSession = SparkSession.getActiveSession.get

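The Scala call sites above drop the serialized plan argument that nativeMergeMTParts no longer takes. A minimal sketch of the updated call, with the trailing arguments (cut off in the diff above) named after the Java declaration further down; the local names partitionDir and bucketDir are placeholders, not taken from the commit:

val returnedMetrics =
  datasourceJniWrapper.nativeMergeMTParts(
    planWithSplitInfo.splitInfo, // only the split info is passed now; the plan argument is gone
    uuid,
    taskId.getId.toString,
    partitionDir, // placeholder for the partition_dir argument
    bucketDir     // placeholder for the bucket_dir argument
  )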
@@ -18,11 +18,15 @@

public class CHDatasourceJniWrapper {

public native long nativeInitFileWriterWrapper(
String filePath, byte[] preferredSchema, String formatHint);
public native void write(long instanceId, long blockAddress);

public native String close(long instanceId);

/// FileWriter
public native long createFilerWriter(String filePath, byte[] preferredSchema, String formatHint);

public native long nativeInitMergeTreeWriterWrapper(
byte[] plan,
/// MergeTreeWriter
public native long createMergeTreeWriter(
byte[] splitInfo,
String uuid,
String taskId,
@@ -31,43 +35,28 @@ public native long nativeInitMergeTreeWriterWrapper(
byte[] confArray);

public native String nativeMergeMTParts(
byte[] plan,
byte[] splitInfo,
String uuid,
String taskId,
String partition_dir,
String bucket_dir);
byte[] splitInfo, String uuid, String taskId, String partition_dir, String bucket_dir);

public static native String filterRangesOnDriver(byte[] plan, byte[] read);

public native void write(long instanceId, long blockAddress);

public native void writeToMergeTree(long instanceId, long blockAddress);

public native void close(long instanceId);

public native String closeMergeTreeWriter(long instanceId);

/*-
/**
* The input block is already sorted by partition columns + bucket expressions. (check
* org.apache.spark.sql.execution.datasources.FileFormatWriter#write)
* However, the input block may contain parts(we call it stripe here) belonging to
* different partition/buckets.
* org.apache.spark.sql.execution.datasources.FileFormatWriter#write) However, the input block may
* contain parts(we call it stripe here) belonging to different partition/buckets.
*
* If bucketing is enabled, the input block's last column is guaranteed to be _bucket_value_.
* <p>If bucketing is enabled, the input block's last column is guaranteed to be _bucket_value_.
*
* This function splits the input block in to several blocks, each of which belonging
* to the same partition/bucket. Notice the stripe will NOT contain partition columns
* <p>This function splits the input block in to several blocks, each of which belonging to the
* same partition/bucket. Notice the stripe will NOT contain partition columns
*
* Since all rows in a stripe share the same partition/bucket,
* we only need to check the heading row.
* So, for each stripe, the native code also returns each stripe's first row's index.
* Caller can use these indice to get UnsafeRows from the input block,
* to help FileFormatDataWriter to aware partition/bucket changes.
* <p>Since all rows in a stripe share the same partition/bucket, we only need to check the
* heading row. So, for each stripe, the native code also returns each stripe's first row's index.
* Caller can use these indices to get UnsafeRows from the input block, to help
* FileFormatDataWriter to aware partition/bucket changes.
*/
public static native BlockStripes splitBlockByPartitionAndBucket(
long blockAddress,
int[] partitionColIndice,
int[] partitionColIndices,
boolean hasBucket,
boolean reserve_partition_columns);
}
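Taken together, the renamed entry points give both writers one lifecycle: create an instance, feed it native block addresses, and close it to obtain the writer's metrics. A minimal sketch of that flow, assuming CHDatasourceJniWrapper is on the classpath (the import below is an assumption; adjust it to wherever the class lives in the tree). The "parquet" format hint, the schema bytes, and the block addresses are placeholders rather than values from this commit:

import org.apache.gluten.vectorized.CHDatasourceJniWrapper // assumed package; adjust to the actual source tree

def writeBlocks(path: String, schemaBytes: Array[Byte], blockAddresses: Seq[Long]): String = {
  val jni = new CHDatasourceJniWrapper()
  // createFilerWriter takes the output path, a serialized preferred schema and a format hint.
  val instance = jni.createFilerWriter(path, schemaBytes, "parquet")
  // The same write/close pair now drives both the plain file writer and the MergeTree writer.
  blockAddresses.foreach(addr => jni.write(instance, addr))
  jni.close(instance) // returns the writer's metrics (e.g. part metrics for MergeTree)
}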
@@ -52,7 +52,7 @@ trait CHFormatWriterInjects extends GlutenFormatWriterInjectsBase {
val namedStruct = namedStructBuilder.build

val instance =
datasourceJniWrapper.nativeInitFileWriterWrapper(path, namedStruct.toByteArray, formatName)
datasourceJniWrapper.createFilerWriter(path, namedStruct.toByteArray, formatName)

new OutputWriter {
override def write(row: InternalRow): Unit = {
@@ -67,8 +67,7 @@ class CHMergeTreeWriterInjects extends CHFormatWriterInjects {
splitInfo: Array[Byte]): OutputWriter = {
val datasourceJniWrapper = new CHDatasourceJniWrapper()
val instance =
datasourceJniWrapper.nativeInitMergeTreeWriterWrapper(
null,
datasourceJniWrapper.createMergeTreeWriter(
splitInfo,
UUID.randomUUID.toString,
context.getTaskAttemptID.getTaskID.getId.toString,
@@ -59,7 +59,7 @@ abstract class MergeTreeFileFormatDataWriter(
protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]()
protected var currentWriter: OutputWriter = _

protected val returnedMetrics = mutable.HashMap[String, AddFile]()
protected val returnedMetrics: mutable.Map[String, AddFile] = mutable.HashMap[String, AddFile]()

/** Trackers for computing various statistics on the data as it's being written out. */
protected val statsTrackers: Seq[WriteTaskStatsTracker] =
@@ -71,10 +71,10 @@
try {
currentWriter.close()
statsTrackers.foreach(_.closeFile(currentWriter.path()))
val ret = currentWriter.asInstanceOf[MergeTreeOutputWriter].getAddFiles
if (ret.nonEmpty) {
ret.foreach(addFile => returnedMetrics.put(addFile.path, addFile))
}
currentWriter
.asInstanceOf[MergeTreeOutputWriter]
.getAddFiles
.foreach(addFile => returnedMetrics.put(addFile.path, addFile))
} finally {
currentWriter = null
}
@@ -117,12 +117,7 @@ abstract class MergeTreeFileFormatDataWriter(
releaseResources()
val (taskCommitMessage, taskCommitTime) = Utils.timeTakenMs {
// committer.commitTask(taskAttemptContext)
val statuses = returnedMetrics
.map(
v => {
v._2
})
.toSeq
val statuses = returnedMetrics.map(_._2).toSeq
new TaskCommitMessage(statuses)
}

@@ -142,7 +137,7 @@

override def close(): Unit = {}

def getReturnedMetrics(): mutable.Map[String, AddFile] = returnedMetrics
def getReturnedMetrics: mutable.Map[String, AddFile] = returnedMetrics
}

/** FileFormatWriteTask for empty partitions */
@@ -443,7 +438,11 @@ class MergeTreeDynamicPartitionDataSingleWriter(
case fakeRow: FakeRow =>
if (fakeRow.batch.numRows() > 0) {
val blockStripes = GlutenRowSplitter.getInstance
.splitBlockByPartitionAndBucket(fakeRow, partitionColIndice, isBucketed, true)
.splitBlockByPartitionAndBucket(
fakeRow,
partitionColIndice,
isBucketed,
reserve_partition_columns = true)

val iter = blockStripes.iterator()
while (iter.hasNext) {
@@ -526,10 +525,10 @@ class MergeTreeDynamicPartitionDataConcurrentWriter(
if (status.outputWriter != null) {
try {
status.outputWriter.close()
val ret = status.outputWriter.asInstanceOf[MergeTreeOutputWriter].getAddFiles
if (ret.nonEmpty) {
ret.foreach(addFile => returnedMetrics.put(addFile.path, addFile))
}
status.outputWriter
.asInstanceOf[MergeTreeOutputWriter]
.getAddFiles
.foreach(addFile => returnedMetrics.put(addFile.path, addFile))
} finally {
status.outputWriter = null
}
@@ -42,12 +42,12 @@ class MergeTreeOutputWriter(

if (nextBatch.numRows > 0) {
val col = nextBatch.column(0).asInstanceOf[CHColumnVector]
datasourceJniWrapper.writeToMergeTree(instance, col.getBlockAddress)
datasourceJniWrapper.write(instance, col.getBlockAddress)
} // else ignore this empty block
}

override def close(): Unit = {
val returnedMetrics = datasourceJniWrapper.closeMergeTreeWriter(instance)
val returnedMetrics = datasourceJniWrapper.close(instance)
if (returnedMetrics != null && returnedMetrics.nonEmpty) {
addFiles.appendAll(
AddFileTags.partsMetricsToAddFile(
@@ -55,6 +55,7 @@ class GlutenClickHouseMergeTreeWriteSuite
.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
.set("spark.sql.adaptive.enabled", "true")
.set("spark.sql.files.maxPartitionBytes", "20000000")
.set("spark.gluten.sql.native.writer.enabled", "true")
.setCHSettings("min_insert_block_size_rows", 100000)
.setCHSettings("mergetree.merge_after_insert", false)
.setCHSettings("input_format_parquet_max_block_size", 8192)
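The new line in the suite configuration turns on the native writer path that this commit targets. A short sketch of setting the same flag on an application-level SparkConf; illustrative only, the surrounding Gluten setup is omitted:

import org.apache.spark.SparkConf

val conf = new SparkConf()
  // enable Gluten's native writer for the ClickHouse backend, as in the test suite above
  .set("spark.gluten.sql.native.writer.enabled", "true")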
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/CMakeLists.txt
@@ -128,9 +128,9 @@ foreach(child ${children})
add_headers_and_sources(function_parsers ${child})
endforeach()

# Notice: soures files under Parser/*_udf subdirectories must be built into
# Notice: sources files under Parser/*_udf subdirectories must be built into
# target ${LOCALENGINE_SHARED_LIB} directly to make sure all function parsers
# are registered successly.
# are registered successfully.
add_library(
${LOCALENGINE_SHARED_LIB} SHARED
local_engine_jni.cpp ${local_udfs_sources} ${function_parsers_sources}
3 changes: 1 addition & 2 deletions cpp-ch/local-engine/Common/CHUtil.cpp
@@ -49,7 +49,7 @@
#include <IO/SharedThreadPools.h>
#include <Interpreters/JIT/CompiledExpressionCache.h>
#include <Parser/RelParsers/RelParser.h>
#include <Parser/SerializedPlanParser.h>
#include <Parser/SubstraitParserUtils.h>
#include <Planner/PlannerActionsVisitor.h>
#include <Processors/Chunk.h>
#include <Processors/QueryPlan/ExpressionStep.h>
@@ -62,7 +62,6 @@
#include <Storages/SubstraitSource/ReadBufferBuilder.h>
#include <boost/algorithm/string/case_conv.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <google/protobuf/util/json_util.h>
#include <sys/resource.h>
#include <Poco/Logger.h>
#include <Poco/Util/MapConfiguration.h>
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.cpp
@@ -23,7 +23,7 @@
#include <Parser/TypeParser.h>
#include <Processors/Transforms/ExpressionTransform.h>
#include <QueryPipeline/QueryPipelineBuilder.h>
#include <Storages/Output/FileWriterWrappers.h>
#include <Storages/Output/NormalFileWriter.h>
#include <google/protobuf/wrappers.pb.h>
#include <substrait/algebra.pb.h>
#include <substrait/type.pb.h>
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -65,12 +65,12 @@
#include <Processors/QueryPlan/QueryPlan.h>
#include <QueryPipeline/printPipeline.h>
#include <Storages/MergeTree/MergeTreeData.h>
#include <Storages/Output/FileWriterWrappers.h>
#include <Storages/SubstraitSource/SubstraitFileSource.h>
#include <Storages/SubstraitSource/SubstraitFileSourceStep.h>
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/wrappers.pb.h>
#include <Common/BlockTypeUtils.h>
#include <Common/CHUtil.h>
#include <Common/Exception.h>
#include <Common/GlutenConfig.h>
#include <Common/JNIUtils.h>
6 changes: 2 additions & 4 deletions cpp-ch/local-engine/Parser/SubstraitParserUtils.cpp
@@ -26,14 +26,12 @@ namespace local_engine
{
void logDebugMessage(const google::protobuf::Message & message, const char * type)
{
auto * logger = &Poco::Logger::get("SubstraitPlan");
if (logger->debug())
if (auto * logger = &Poco::Logger::get("SubstraitPlan"); logger->debug())
{
namespace pb_util = google::protobuf::util;
pb_util::JsonOptions options;
std::string json;
auto s = pb_util::MessageToJsonString(message, &json, options);
if (!s.ok())
if (auto s = pb_util::MessageToJsonString(message, &json, options); !s.ok())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Can not convert {} to Json", type);
LOG_DEBUG(logger, "{}:\n{}", type, json);
}
16 changes: 5 additions & 11 deletions cpp-ch/local-engine/Shuffle/NativeSplitter.cpp
@@ -21,16 +21,10 @@
#include <string>
#include <Core/Block.h>
#include <Functions/FunctionFactory.h>
#include <Parser/SerializedPlanParser.h>
#include <base/types.h>
#include <boost/asio/detail/eventfd_select_interrupter.hpp>
#include <jni/jni_common.h>
#include <Poco/Logger.h>
#include <Poco/StringTokenizer.h>
#include <Common/Exception.h>
#include <Common/JNIUtils.h>
#include <Common/logger_useful.h>
#include <Storages/IO/AggregateSerializationUtils.h>

namespace local_engine
{
@@ -86,7 +80,7 @@ void NativeSplitter::split(DB::Block & block)
{
if (partition_buffer[i]->size() >= options.buffer_size)
{
output_buffer.emplace(std::pair(i, std::make_unique<Block>(partition_buffer[i]->releaseColumns())));
output_buffer.emplace(std::pair(i, std::make_unique<DB::Block>(partition_buffer[i]->releaseColumns())));
}
}
}
@@ -116,7 +110,7 @@ bool NativeSplitter::hasNext()
{
if (inputHasNext())
{
split(*reinterpret_cast<Block *>(inputNext()));
split(*reinterpret_cast<DB::Block *>(inputNext()));
}
else
{
@@ -125,7 +119,7 @@
auto buffer = partition_buffer.at(i);
if (buffer->size() > 0)
{
output_buffer.emplace(std::pair(i, new Block(buffer->releaseColumns())));
output_buffer.emplace(std::pair(i, new DB::Block(buffer->releaseColumns())));
}
}
break;
@@ -214,7 +208,7 @@ HashNativeSplitter::HashNativeSplitter(NativeSplitter::Options options_, jobject
selector_builder = std::make_unique<HashSelectorBuilder>(options.partition_num, hash_fields, options_.hash_algorithm);
}

void HashNativeSplitter::computePartitionId(Block & block)
void HashNativeSplitter::computePartitionId(DB::Block & block)
{
partition_info = selector_builder->build(block);
}
@@ -229,7 +223,7 @@ RoundRobinNativeSplitter::RoundRobinNativeSplitter(NativeSplitter::Options optio
selector_builder = std::make_unique<RoundRobinSelectorBuilder>(options_.partition_num);
}

void RoundRobinNativeSplitter::computePartitionId(Block & block)
void RoundRobinNativeSplitter::computePartitionId(DB::Block & block)
{
partition_info = selector_builder->build(block);
}