Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix UAF in shared UDF #54592

Merged
merged 2 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions be/src/exprs/java_function_call_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,9 @@ struct UDFFunctionCallHelper {
RETURN_IF_UNLIKELY(!st.ok(), ColumnHelper::create_const_null_column(size));

// call UDF method
jobject res = helper.batch_call(fn_desc->call_stub.get(), input_col_objs.data(), input_col_objs.size(), size);
// The ctx of the current function argument is not the same as the ctx of fn_desc->call_stub.
// The latter is created in JavaFunctionCallExpr::prepare and used in java udf, so we should
// use it to determine whether an exception has occurred.
FunctionContext* java_udf_ctx = fn_desc->call_stub.get()->ctx();
if (java_udf_ctx != nullptr && java_udf_ctx->has_error()) {
return Status::RuntimeError(java_udf_ctx->error_msg());
}
ASSIGN_OR_RETURN(auto res, helper.batch_call(fn_desc->call_stub.get(), input_col_objs.data(),
input_col_objs.size(), size));

RETURN_IF_UNLIKELY_NULL(res, ColumnHelper::create_const_null_column(size));
// get result
auto result_cols = get_boxed_result(ctx, res, size);
Expand Down Expand Up @@ -165,7 +160,7 @@ bool JavaFunctionCallExpr::is_constant() const {
}

StatusOr<std::shared_ptr<JavaUDFContext>> JavaFunctionCallExpr::_build_udf_func_desc(
ExprContext* context, FunctionContext::FunctionStateScope scope, const std::string& libpath) {
FunctionContext::FunctionStateScope scope, const std::string& libpath) {
auto desc = std::make_shared<JavaUDFContext>();
// init class loader and analyzer
desc->udf_classloader = std::make_unique<ClassLoader>(std::move(libpath));
Expand Down Expand Up @@ -208,9 +203,8 @@ StatusOr<std::shared_ptr<JavaUDFContext>> JavaFunctionCallExpr::_build_udf_func_
ASSIGN_OR_RETURN(auto update_stub_clazz, desc->udf_classloader->genCallStub(stub_clazz, udf_clazz, update_method,
ClassLoader::BATCH_EVALUATE));
ASSIGN_OR_RETURN(auto method, desc->analyzer->get_method_object(update_stub_clazz.clazz(), stub_method_name));
auto function_ctx = context->fn_context(_fn_context_index);
desc->call_stub = std::make_unique<BatchEvaluateStub>(
function_ctx, desc->udf_handle.handle(), std::move(update_stub_clazz), JavaGlobalRef(std::move(method)));
desc->call_stub = std::make_unique<BatchEvaluateStub>(desc->udf_handle.handle(), std::move(update_stub_clazz),
JavaGlobalRef(std::move(method)));

if (desc->prepare != nullptr) {
// we only support fragment local scope to call prepare
Expand Down Expand Up @@ -238,10 +232,10 @@ Status JavaFunctionCallExpr::open(RuntimeState* state, ExprContext* context,
}
// cacheable
if (scope == FunctionContext::FRAGMENT_LOCAL) {
auto get_func_desc = [this, scope, context, state](const std::string& lib) -> StatusOr<std::any> {
auto get_func_desc = [this, scope, state](const std::string& lib) -> StatusOr<std::any> {
std::any func_desc;
auto call = [&]() {
ASSIGN_OR_RETURN(func_desc, _build_udf_func_desc(context, scope, lib));
ASSIGN_OR_RETURN(func_desc, _build_udf_func_desc(scope, lib));
return Status::OK();
};
RETURN_IF_ERROR(call_function_in_pthread(state, call)->get_future().get());
Expand Down
3 changes: 1 addition & 2 deletions be/src/exprs/java_function_call_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ class JavaFunctionCallExpr final : public Expr {
bool is_constant() const override;

private:
StatusOr<std::shared_ptr<JavaUDFContext>> _build_udf_func_desc(ExprContext* context,
FunctionContext::FunctionStateScope scope,
StatusOr<std::shared_ptr<JavaUDFContext>> _build_udf_func_desc(FunctionContext::FunctionStateScope scope,
const std::string& libpath);
void _call_udf_close();
RuntimeState* _runtime_state = nullptr;
Expand Down
6 changes: 3 additions & 3 deletions be/src/udf/java/java_udf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ void JVMFunctionHelper::batch_update_if_not_null(FunctionContext* ctx, jobject u
CHECK_UDF_CALL_EXCEPTION(_env, ctx);
}

jobject JVMFunctionHelper::batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows) {
StatusOr<jobject> JVMFunctionHelper::batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows) {
return stub->batch_evaluate(rows, input, cols);
}

Expand Down Expand Up @@ -853,7 +853,7 @@ void AggBatchCallStub::batch_update_single(int num_rows, jobject state, jobject*
CHECK_UDF_CALL_EXCEPTION(env, this->_ctx);
}

jobject BatchEvaluateStub::batch_evaluate(int num_rows, jobject* input, int cols) {
StatusOr<jobject> BatchEvaluateStub::batch_evaluate(int num_rows, jobject* input, int cols) {
jvalue jni_inputs[2 + cols];
jni_inputs[0].i = num_rows;
jni_inputs[1].l = _caller;
Expand All @@ -863,7 +863,7 @@ jobject BatchEvaluateStub::batch_evaluate(int num_rows, jobject* input, int cols
auto* env = JVMFunctionHelper::getInstance().getEnv();
auto res = env->CallStaticObjectMethodA(_stub_clazz.clazz(), env->FromReflectedMethod(_stub_method.handle()),
jni_inputs);
CHECK_UDF_CALL_EXCEPTION(env, this->_ctx);
RETURN_ERROR_IF_JNI_EXCEPTION(env);
return res;
}

Expand Down
10 changes: 4 additions & 6 deletions be/src/udf/java/java_udf.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class JVMFunctionHelper {
void batch_update_state(FunctionContext* ctx, jobject udaf, jobject update, jobject* input, int cols);

// batch call evalute by callstub
jobject batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows);
StatusOr<jobject> batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows);
// batch call method by reflect
jobject batch_call(FunctionContext* ctx, jobject caller, jobject method, jobject* input, int cols, int rows);
// batch call no-args function by reflect
Expand Down Expand Up @@ -347,14 +347,12 @@ class BatchEvaluateStub {
static inline const char* stub_clazz_name = "com.starrocks.udf.gen.CallStub";
static inline const char* batch_evaluate_method_name = "batchCallV";

BatchEvaluateStub(FunctionContext* ctx, jobject caller, JVMClass&& clazz, JavaGlobalRef&& method)
: _ctx(ctx), _caller(caller), _stub_clazz(std::move(clazz)), _stub_method(std::move(method)) {}
BatchEvaluateStub(jobject caller, JVMClass&& clazz, JavaGlobalRef&& method)
: _caller(caller), _stub_clazz(std::move(clazz)), _stub_method(std::move(method)) {}

FunctionContext* ctx() { return _ctx; }
jobject batch_evaluate(int num_rows, jobject* input, int cols);
StatusOr<jobject> batch_evaluate(int num_rows, jobject* input, int cols);

private:
FunctionContext* _ctx;
jobject _caller;
JVMClass _stub_clazz;
JavaGlobalRef _stub_method;
Expand Down
21 changes: 20 additions & 1 deletion test/sql/test_udf/R/test_jvm_udf
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ PROPERTIES
);
-- result:
-- !result
CREATE FUNCTION shared_exception_test(string)
RETURNS string
PROPERTIES
(
"symbol" = "ExceptionUDF2",
"isolation"="shared",
"type" = "StarrocksJar",
"file" = "${udf_url}/starrocks-jdbc/ExceptionUDF2.jar"
);
-- result:
-- !result
CREATE TABLE `t0` (
`c0` int(11) NULL COMMENT "",
`c1` varchar(20) NULL COMMENT "",
Expand Down Expand Up @@ -152,4 +163,12 @@ select sum(delta), count(*), count(delta) from (select (sum(c3) - sumbigint(c3))
select count(*) from t0 where exception_test(c1) is null;
-- result:
[REGEX].*java.lang.NullPointerException.*
-- !result
-- !result
select count(*) from t0 where shared_exception_test(c1) is null;
-- result:
[REGEX].*java.lang.NullPointerException.*
-- !result
select count(*) from t0 where shared_exception_test(c1) is null;
-- result:
[REGEX].*java.lang.NullPointerException.*
-- !result
15 changes: 14 additions & 1 deletion test/sql/test_udf/T/test_jvm_udf
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ PROPERTIES
"file" = "${udf_url}/starrocks-jdbc/ExceptionUDF2.jar"
);

CREATE FUNCTION shared_exception_test(string)
RETURNS string
PROPERTIES
(
"symbol" = "ExceptionUDF2",
"isolation"="shared",
"type" = "StarrocksJar",
"file" = "${udf_url}/starrocks-jdbc/ExceptionUDF2.jar"
);

CREATE TABLE `t0` (
`c0` int(11) NULL COMMENT "",
`c1` varchar(20) NULL COMMENT "",
Expand Down Expand Up @@ -94,4 +104,7 @@ set spill_mode="force";

select sum(delta), count(*), count(delta) from (select (sum(c3) - sumbigint(c3)) as delta from t0 group by c0,c1) tb;
-- test udf exception case
select count(*) from t0 where exception_test(c1) is null;
select count(*) from t0 where exception_test(c1) is null;
-- run two times
select count(*) from t0 where shared_exception_test(c1) is null;
select count(*) from t0 where shared_exception_test(c1) is null;
Loading