Skip to content

Commit

Permalink
Add BatchWriteUtilTest
Browse files Browse the repository at this point in the history
Signed-off-by: PengFei Li <[email protected]>
  • Loading branch information
banmoy committed Nov 5, 2024
1 parent c4dc9d1 commit 74e915b
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 168 deletions.
2 changes: 1 addition & 1 deletion be/src/http/action/stream_load.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ Status StreamLoadAction::_handle(StreamLoadContext* ctx) {
}

Status StreamLoadAction::_handle_batch_write(starrocks::HttpRequest* http_req, StreamLoadContext* ctx) {
ASSIGN_OR_RETURN(ctx->load_parameters, get_batch_write_load_parameters(http_req, ctx));
ASSIGN_OR_RETURN(ctx->load_parameters, get_load_parameters_from_http(http_req));
return _exec_env->batch_write_mgr()->append_data(ctx);
}

Expand Down
2 changes: 1 addition & 1 deletion be/src/runtime/batch_write/batch_write_mgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ void BatchWriteMgr::receive_rpc_request(ExecEnv* exec_env, brpc::Controller* cnt
}
}

auto ret = get_batch_write_load_parameters(parameters);
auto ret = get_load_parameters_from_brpc(parameters);
if (!ret.ok()) {
ctx->status = ret.status();
return;
Expand Down
155 changes: 60 additions & 95 deletions be/src/runtime/batch_write/batch_write_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,113 +14,78 @@

#include "runtime/batch_write/batch_write_util.h"

#include <vector>

#include "http/http_common.h"
#include "http/http_request.h"
#include "runtime/stream_load/stream_load_context.h"

namespace starrocks {

// Copies the HTTP header named `param_name` from `http_req` into `load_params`
// when the header value is non-empty; an empty (or absent) header adds nothing.
// `emplace` is used, so a key that is already present in `load_params` is kept.
//
// The header is fetched once and bound to a local, the arguments are
// parenthesized, and the macro deliberately does NOT end with a semicolon:
// the caller supplies it, which keeps `if (c) MACRO(...); else ...` a single
// well-formed statement.
#define POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, param_name)  \
    do {                                                                      \
        const auto& populate_header_value_ = (http_req)->header(param_name);  \
        if (!populate_header_value_.empty()) {                                \
            (load_params).emplace(param_name, populate_header_value_);        \
        }                                                                     \
    } while (0)
const std::vector<std::string> LOAD_PARAMETER_NAMES = {HTTP_FORMAT_KEY,
HTTP_COLUMNS,
HTTP_WHERE,
HTTP_MAX_FILTER_RATIO,
HTTP_TIMEOUT,
HTTP_PARTITIONS,
HTTP_TEMP_PARTITIONS,
HTTP_NEGATIVE,
HTTP_STRICT_MODE,
HTTP_TIMEZONE,
HTTP_LOAD_MEM_LIMIT,
HTTP_EXEC_MEM_LIMIT,
HTTP_PARTIAL_UPDATE,
HTTP_PARTIAL_UPDATE_MODE,
HTTP_TRANSMISSION_COMPRESSION_TYPE,
HTTP_LOAD_DOP,
HTTP_MERGE_CONDITION,
HTTP_LOG_REJECTED_RECORD_NUM,
HTTP_COMPRESSION,
HTTP_WAREHOUSE,
HTTP_ENABLE_BATCH_WRITE,
HTTP_BATCH_WRITE_ASYNC,
HTTP_BATCH_WRITE_INTERVAL_MS,
HTTP_BATCH_WRITE_PARALLEL,
HTTP_COLUMN_SEPARATOR,
HTTP_ROW_DELIMITER,
HTTP_TRIM_SPACE,
HTTP_ENCLOSE,
HTTP_ESCAPE,
HTTP_JSONPATHS,
HTTP_JSONROOT,
HTTP_STRIP_OUTER_ARRAY};

StatusOr<LoadParams> get_batch_write_load_parameters(HttpRequest* http_req, StreamLoadContext* ctx) {
StatusOr<BatchWriteLoadParams> get_load_parameters(
const std::function<std::optional<std::string>(const std::string&)>& getter_func) {
std::map<std::string, std::string> load_params;
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_FORMAT_KEY);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_COLUMNS);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_WHERE);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_MAX_FILTER_RATIO);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_TIMEOUT);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_PARTITIONS);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_TEMP_PARTITIONS);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_NEGATIVE);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_STRICT_MODE);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_TIMEZONE);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_LOAD_MEM_LIMIT);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_EXEC_MEM_LIMIT);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_PARTIAL_UPDATE);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_PARTIAL_UPDATE_MODE);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_TRANSMISSION_COMPRESSION_TYPE);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_LOAD_DOP);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_MERGE_CONDITION);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_LOG_REJECTED_RECORD_NUM);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_COMPRESSION);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_WAREHOUSE);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_ENABLE_BATCH_WRITE);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_BATCH_WRITE_ASYNC);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_BATCH_WRITE_INTERVAL_MS);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_BATCH_WRITE_PARALLEL);

// csv format parameters
if (ctx->format != TFileFormatType::FORMAT_JSON) {
if (!http_req->header(HTTP_SKIP_HEADER).empty()) {
return Status::NotSupported("Csv format not support skip header when enable batch write");
for (const auto& name : LOAD_PARAMETER_NAMES) {
auto value_opt = getter_func(name);
if (value_opt) {
load_params.emplace(name, *value_opt);
}
}
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_COLUMN_SEPARATOR);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_ROW_DELIMITER);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_TRIM_SPACE);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_ENCLOSE);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_ESCAPE);

// json format parameters
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_JSONPATHS);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_JSONROOT);
POPULATE_LOAD_PARAMETER_FROM_HTTP(http_req, load_params, HTTP_STRIP_OUTER_ARRAY);

return load_params;
}

// Copies the entry keyed by `param_name` from `input_params` into `load_params`
// when `input_params` contains that key; a missing key adds nothing.
// `emplace` is used, so a key that is already present in `load_params` is kept.
//
// The arguments are parenthesized, the iterator has a macro-specific name so it
// cannot capture a caller variable called `it`, and the macro deliberately does
// NOT end with a semicolon: the caller supplies it, which keeps
// `if (c) MACRO(...); else ...` a single well-formed statement.
#define POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, param_name) \
    do {                                                                        \
        const auto populate_param_it_ = (input_params).find(param_name);        \
        if (populate_param_it_ != (input_params).end()) {                       \
            (load_params).emplace(param_name, populate_param_it_->second);      \
        }                                                                       \
    } while (0)

StatusOr<LoadParams> get_batch_write_load_parameters(const std::map<std::string, std::string>& input_params) {
std::map<std::string, std::string> load_params;
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_FORMAT_KEY);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_COLUMNS);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_WHERE);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_MAX_FILTER_RATIO);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_TIMEOUT);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_PARTITIONS);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_TEMP_PARTITIONS);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_NEGATIVE);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_STRICT_MODE);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_TIMEZONE);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_LOAD_MEM_LIMIT);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_EXEC_MEM_LIMIT);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_PARTIAL_UPDATE);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_PARTIAL_UPDATE_MODE);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_TRANSMISSION_COMPRESSION_TYPE);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_LOAD_DOP);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_MERGE_CONDITION);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_LOG_REJECTED_RECORD_NUM);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_COMPRESSION);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_WAREHOUSE);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_ENABLE_BATCH_WRITE);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_BATCH_WRITE_ASYNC);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_BATCH_WRITE_INTERVAL_MS);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_BATCH_WRITE_PARALLEL);

// csv format parameters
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_COLUMN_SEPARATOR);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_ROW_DELIMITER);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_TRIM_SPACE);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_ENCLOSE);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_ESCAPE);

// json format parameters
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_JSONPATHS);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_JSONROOT);
POPULATE_LOAD_PARAMETER_FROM_MAP(input_params, load_params, HTTP_STRIP_OUTER_ARRAY);
// Collects the load parameters out of a plain key/value map (the form used by
// the brpc request path). Each name is resolved by get_load_parameters()
// through the lookup below; a name absent from `input_params` yields
// std::nullopt.
StatusOr<BatchWriteLoadParams> get_load_parameters_from_brpc(const std::map<std::string, std::string>& input_params) {
    const auto lookup = [&input_params](const std::string& param_name) -> std::optional<std::string> {
        const auto found = input_params.find(param_name);
        if (found == input_params.end()) {
            return std::nullopt;
        }
        return found->second;
    };
    return get_load_parameters(lookup);
}

return load_params;
// Collects the load parameters out of the headers of an HTTP request. Each
// name is resolved by get_load_parameters() through the reader below; a header
// whose value is empty is reported as std::nullopt.
StatusOr<BatchWriteLoadParams> get_load_parameters_from_http(HttpRequest* http_req) {
    const auto read_header = [http_req](const std::string& param_name) -> std::optional<std::string> {
        std::string header_value = http_req->header(param_name);
        if (header_value.empty()) {
            return std::nullopt;
        }
        return header_value;
    };
    return get_load_parameters(read_header);
}

} // namespace starrocks
55 changes: 50 additions & 5 deletions be/src/runtime/batch_write/batch_write_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,60 @@

#pragma once

#include "runtime/batch_write/isomorphic_batch_write.h"
#include <map>
#include <optional>

#include "common/statusor.h"

namespace starrocks {

class HttpRequest;
class StreamLoadContext;
using BatchWriteLoadParams = std::map<std::string, std::string>;

// Identifies a batch write by database, table and load parameters.
// Plain aggregate: the tests initialize it positionally as
// {db, table, load_params}, so the member order is part of its interface.
// BatchWriteIdHash and BatchWriteIdEqual provide hashing and equality for it.
struct BatchWriteId {
// Database name.
std::string db;
// Table name.
std::string table;
// Load parameters as name -> value pairs.
BatchWriteLoadParams load_params;
};

// Hash function for BatchWriteId.
// Starts from the hash of `db`, then XORs in the hash of `table` and of every
// load parameter key and value, each shifted left by one bit.
// NOTE(review): every term after `db` is combined with the same `^ (h << 1)`
// step, so the result does not depend on which string appears where; a
// parameter whose key equals its value contributes two identical terms that
// cancel each other out. Still consistent with BatchWriteIdEqual (equal ids
// hash equally), but worth replacing with an order-sensitive combiner.
struct BatchWriteIdHash {
std::size_t operator()(const BatchWriteId& id) const {
std::size_t hash = std::hash<std::string>{}(id.db);
hash ^= std::hash<std::string>{}(id.table) << 1;

// load_params is an ordered map, so equal maps are visited in the same order.
for (const auto& param : id.load_params) {
hash ^= std::hash<std::string>{}(param.first) << 1;
hash ^= std::hash<std::string>{}(param.second) << 1;
}

StatusOr<LoadParams> get_batch_write_load_parameters(HttpRequest* http_req, StreamLoadContext* ctx);
return hash;
}
};

// Equality predicate for BatchWriteId: two ids match only when the database,
// the table and the complete load parameter map are all equal.
struct BatchWriteIdEqual {
    bool operator()(const BatchWriteId& lhs, const BatchWriteId& rhs) const {
        if (lhs.db != rhs.db) {
            return false;
        }
        if (lhs.table != rhs.table) {
            return false;
        }
        return lhs.load_params == rhs.load_params;
    }
};

// Writes a human-readable description of `id` to `out` and returns `out`:
//   db: <db>, table: <table>, load_params: {k1:v1,k2:v2}
//
// Marked `inline` because this definition lives in a header that is included
// from several translation units (batch_write_util.cpp, isomorphic_batch_write.h
// and the unit test). A non-inline definition here violates the one-definition
// rule and produces duplicate-symbol errors at link time.
inline std::ostream& operator<<(std::ostream& out, const BatchWriteId& id) {
    out << "db: " << id.db << ", table: " << id.table << ", load_params: {";
    bool first = true;
    for (const auto& [key, value] : id.load_params) {
        // Separator goes before every entry except the first one.
        if (!first) {
            out << ",";
        }
        first = false;
        out << key << ":" << value;
    }
    out << "}";
    return out;
}

class HttpRequest;

StatusOr<LoadParams> get_batch_write_load_parameters(const std::map<std::string, std::string>& input_params);
StatusOr<BatchWriteLoadParams> get_load_parameters_from_http(HttpRequest* http_req);
StatusOr<BatchWriteLoadParams> get_load_parameters_from_brpc(const std::map<std::string, std::string>& input_params);

} // namespace starrocks
14 changes: 0 additions & 14 deletions be/src/runtime/batch_write/isomorphic_batch_write.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,6 @@

namespace starrocks {

// Prints `id` as: db: <db>, table: <table>, load_params: {k1:v1,k2:v2}
std::ostream& operator<<(std::ostream& out, const BatchWriteId& id) {
    out << "db: " << id.db << ", table: " << id.table << ", load_params: {";
    const auto params_begin = id.load_params.begin();
    const auto params_end = id.load_params.end();
    for (auto entry = params_begin; entry != params_end; ++entry) {
        // Entries after the first are preceded by a comma.
        if (entry != params_begin) {
            out << ",";
        }
        out << entry->first << ":" << entry->second;
    }
    out << "}";
    return out;
}

class AsyncAppendDataContext {
public:
AsyncAppendDataContext(StreamLoadContext* data_ctx) : _data_ctx(data_ctx), _latch(1) { data_ctx->ref(); }
Expand Down
33 changes: 1 addition & 32 deletions be/src/runtime/batch_write/isomorphic_batch_write.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <unordered_set>

#include "common/statusor.h"
#include "runtime/batch_write/batch_write_util.h"
#include "util/countdown_latch.h"

namespace starrocks {
Expand All @@ -36,38 +37,6 @@ class ThreadPoolExecutor;

class StreamLoadContext;

using LoadParams = std::map<std::string, std::string>;

// Identifies a batch write by database, table and load parameters.
// Plain aggregate; BatchWriteIdHash and BatchWriteIdEqual provide hashing and
// equality for it.
struct BatchWriteId {
// Database name.
std::string db;
// Table name.
std::string table;
// Load parameters as name -> value pairs.
LoadParams load_params;
};

std::ostream& operator<<(std::ostream& out, const BatchWriteId& id);

// Hash function for BatchWriteId.
//
// Folds the hashes of `db`, `table` and every load parameter key and value
// into one seed with an order-sensitive combiner. The previous
// `hash ^= h(x) << 1` scheme applied the same step to every term, so swapping
// a key with its value produced the same hash, and a parameter whose key
// equalled its value contributed nothing at all.
//
// Equal ids still hash equally: `load_params` is an ordered map, so equal maps
// are visited in the same order.
struct BatchWriteIdHash {
    std::size_t operator()(const BatchWriteId& id) const {
        std::size_t seed = std::hash<std::string>{}(id.db);
        combine(seed, std::hash<std::string>{}(id.table));

        for (const auto& param : id.load_params) {
            combine(seed, std::hash<std::string>{}(param.first));
            combine(seed, std::hash<std::string>{}(param.second));
        }

        return seed;
    }

private:
    // Mixes `value` into `seed` (boost::hash_combine-style); the result depends
    // on the order in which values are combined.
    static void combine(std::size_t& seed, std::size_t value) {
        seed ^= value + static_cast<std::size_t>(0x9e3779b97f4a7c15ULL) + (seed << 6) + (seed >> 2);
    }
};

// Equality predicate for BatchWriteId; all three members must compare equal.
struct BatchWriteIdEqual {
    bool operator()(const BatchWriteId& lhs, const BatchWriteId& rhs) const {
        const bool same_db = lhs.db == rhs.db;
        if (!same_db) {
            return false;
        }
        const bool same_table = lhs.table == rhs.table;
        if (!same_table) {
            return false;
        }
        return lhs.load_params == rhs.load_params;
    }
};

using BThreadCountDownLatch = GenericCountDownLatch<bthread::Mutex, bthread::ConditionVariable>;

class AsyncAppendDataContext;
Expand Down
69 changes: 69 additions & 0 deletions be/test/runtime/batch_write/batch_write_util_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/batch_write/batch_write_util.h"

#include <gtest/gtest.h>

#include "http/http_common.h"
#include "http/http_request.h"

namespace starrocks {

// Equal ids must hash to the same value; a different id is expected to hash
// to a different one.
TEST(BatchWriteUtilTest, batch_write_id_hash) {
    const BatchWriteId base{"db1", "table1", {{"param1", "value1"}, {"param2", "value2"}}};
    const BatchWriteId same_as_base{"db1", "table1", {{"param1", "value1"}, {"param2", "value2"}}};
    const BatchWriteId other{"db2", "table2", {{"param3", "value3"}}};

    BatchWriteIdHash hasher;
    const std::size_t base_hash = hasher(base);
    ASSERT_EQ(base_hash, hasher(same_as_base));
    ASSERT_NE(base_hash, hasher(other));
}

// Ids with identical members compare equal; ids that differ do not.
TEST(BatchWriteUtilTest, batch_write_id_equal) {
    const BatchWriteId base{"db1", "table1", {{"param1", "value1"}, {"param2", "value2"}}};
    const BatchWriteId same_as_base{"db1", "table1", {{"param1", "value1"}, {"param2", "value2"}}};
    const BatchWriteId other{"db2", "table2", {{"param3", "value3"}}};

    BatchWriteIdEqual is_equal;
    const bool matches_copy = is_equal(base, same_as_base);
    ASSERT_TRUE(matches_copy);
    const bool matches_other = is_equal(base, other);
    ASSERT_FALSE(matches_other);
}

// Parameters present in the brpc input map must appear unchanged in the result.
TEST(BatchWriteUtilTest, get_load_parameters_from_brpc) {
    std::map<std::string, std::string> request_params;
    request_params.emplace(HTTP_FORMAT_KEY, "json");
    request_params.emplace(HTTP_COLUMNS, "col1,col2");
    request_params.emplace(HTTP_TIMEOUT, "30");

    auto params_or = get_load_parameters_from_brpc(request_params);
    ASSERT_TRUE(params_or.ok());

    auto extracted = params_or.value();
    ASSERT_EQ(extracted[HTTP_FORMAT_KEY], "json");
    ASSERT_EQ(extracted[HTTP_COLUMNS], "col1,col2");
    ASSERT_EQ(extracted[HTTP_TIMEOUT], "30");
}

// Headers set on the HTTP request must appear unchanged in the result.
TEST(BatchWriteUtilTest, get_load_parameters_from_http) {
    HttpRequest request(nullptr);
    auto& headers = request._headers;
    headers.emplace(HTTP_FORMAT_KEY, "json");
    headers.emplace(HTTP_COLUMNS, "col1,col2");
    headers.emplace(HTTP_TIMEOUT, "30");

    auto params_or = get_load_parameters_from_http(&request);
    ASSERT_TRUE(params_or.ok());

    auto extracted = params_or.value();
    ASSERT_EQ(extracted[HTTP_FORMAT_KEY], "json");
    ASSERT_EQ(extracted[HTTP_COLUMNS], "col1,col2");
    ASSERT_EQ(extracted[HTTP_TIMEOUT], "30");
}

} // namespace starrocks
Loading

0 comments on commit 74e915b

Please sign in to comment.