From 2f96df0ee7d250a7a160258c761c03c11271f7a8 Mon Sep 17 00:00:00 2001 From: stdpain <34912776+stdpain@users.noreply.github.com> Date: Thu, 2 Jan 2025 21:42:07 +0800 Subject: [PATCH] [BugFix] Fix UAF in shared UDF (#54592) Signed-off-by: stdpain (cherry picked from commit a29b2b6b33943eaad553d8ec040fd083a4e7c355) --- be/src/exprs/java_function_call_expr.cpp | 22 ++++++++-------------- be/src/exprs/java_function_call_expr.h | 3 +-- be/src/udf/java/java_udf.cpp | 6 +++--- be/src/udf/java/java_udf.h | 10 ++++------ test/sql/test_udf/R/test_jvm_udf | 21 ++++++++++++++++++++- test/sql/test_udf/T/test_jvm_udf | 15 ++++++++++++++- 6 files changed, 50 insertions(+), 27 deletions(-) diff --git a/be/src/exprs/java_function_call_expr.cpp b/be/src/exprs/java_function_call_expr.cpp index 6eb0ab783ce03..282b582840a56 100644 --- a/be/src/exprs/java_function_call_expr.cpp +++ b/be/src/exprs/java_function_call_expr.cpp @@ -74,14 +74,9 @@ struct UDFFunctionCallHelper { RETURN_IF_UNLIKELY(!st.ok(), ColumnHelper::create_const_null_column(size)); // call UDF method - jobject res = helper.batch_call(fn_desc->call_stub.get(), input_col_objs.data(), input_col_objs.size(), size); - // The ctx of the current function argument is not the same as the ctx of fn_desc->call_stub. - // The latter is created in JavaFunctionCallExpr::prepare and used in java udf, so we should - // use it to determine whether an exception has occurred. 
- FunctionContext* java_udf_ctx = fn_desc->call_stub.get()->ctx(); - if (java_udf_ctx != nullptr && java_udf_ctx->has_error()) { - return Status::RuntimeError(java_udf_ctx->error_msg()); - } + ASSIGN_OR_RETURN(auto res, helper.batch_call(fn_desc->call_stub.get(), input_col_objs.data(), + input_col_objs.size(), size)); + RETURN_IF_UNLIKELY_NULL(res, ColumnHelper::create_const_null_column(size)); // get result auto result_cols = get_boxed_result(ctx, res, size); @@ -166,7 +161,7 @@ bool JavaFunctionCallExpr::is_constant() const { } StatusOr> JavaFunctionCallExpr::_build_udf_func_desc( - ExprContext* context, FunctionContext::FunctionStateScope scope, const std::string& libpath) { + FunctionContext::FunctionStateScope scope, const std::string& libpath) { auto desc = std::make_shared(); // init class loader and analyzer desc->udf_classloader = std::make_unique(std::move(libpath)); @@ -209,9 +204,8 @@ StatusOr> JavaFunctionCallExpr::_build_udf_func_ ASSIGN_OR_RETURN(auto update_stub_clazz, desc->udf_classloader->genCallStub(stub_clazz, udf_clazz, update_method, ClassLoader::BATCH_EVALUATE)); ASSIGN_OR_RETURN(auto method, desc->analyzer->get_method_object(update_stub_clazz.clazz(), stub_method_name)); - auto function_ctx = context->fn_context(_fn_context_index); - desc->call_stub = std::make_unique( - function_ctx, desc->udf_handle.handle(), std::move(update_stub_clazz), JavaGlobalRef(std::move(method))); + desc->call_stub = std::make_unique(desc->udf_handle.handle(), std::move(update_stub_clazz), + JavaGlobalRef(std::move(method))); if (desc->prepare != nullptr) { // we only support fragment local scope to call prepare @@ -239,10 +233,10 @@ Status JavaFunctionCallExpr::open(RuntimeState* state, ExprContext* context, } // cacheable if (scope == FunctionContext::FRAGMENT_LOCAL) { - auto get_func_desc = [this, scope, context, state](const std::string& lib) -> StatusOr { + auto get_func_desc = [this, scope, state](const std::string& lib) -> StatusOr { std::any func_desc; 
auto call = [&]() { - ASSIGN_OR_RETURN(func_desc, _build_udf_func_desc(context, scope, lib)); + ASSIGN_OR_RETURN(func_desc, _build_udf_func_desc(scope, lib)); return Status::OK(); }; RETURN_IF_ERROR(call_function_in_pthread(state, call)->get_future().get()); diff --git a/be/src/exprs/java_function_call_expr.h b/be/src/exprs/java_function_call_expr.h index 200e238b9d36c..8daac590da70d 100644 --- a/be/src/exprs/java_function_call_expr.h +++ b/be/src/exprs/java_function_call_expr.h @@ -37,8 +37,7 @@ class JavaFunctionCallExpr final : public Expr { bool is_constant() const override; private: - StatusOr> _build_udf_func_desc(ExprContext* context, - FunctionContext::FunctionStateScope scope, + StatusOr> _build_udf_func_desc(FunctionContext::FunctionStateScope scope, const std::string& libpath); void _call_udf_close(); RuntimeState* _runtime_state = nullptr; diff --git a/be/src/udf/java/java_udf.cpp b/be/src/udf/java/java_udf.cpp index 9f35977a54976..616b9941f1848 100644 --- a/be/src/udf/java/java_udf.cpp +++ b/be/src/udf/java/java_udf.cpp @@ -329,7 +329,7 @@ void JVMFunctionHelper::batch_update_if_not_null(FunctionContext* ctx, jobject u CHECK_UDF_CALL_EXCEPTION(_env, ctx); } -jobject JVMFunctionHelper::batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows) { +StatusOr<jobject> JVMFunctionHelper::batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows) { return stub->batch_evaluate(rows, input, cols); } @@ -853,7 +853,7 @@ void AggBatchCallStub::batch_update_single(int num_rows, jobject state, jobject* CHECK_UDF_CALL_EXCEPTION(env, this->_ctx); } -jobject BatchEvaluateStub::batch_evaluate(int num_rows, jobject* input, int cols) { +StatusOr<jobject> BatchEvaluateStub::batch_evaluate(int num_rows, jobject* input, int cols) { jvalue jni_inputs[2 + cols]; jni_inputs[0].i = num_rows; jni_inputs[1].l = _caller; @@ -863,7 +863,7 @@ jobject BatchEvaluateStub::batch_evaluate(int num_rows, jobject* input, int cols auto* env = 
JVMFunctionHelper::getInstance().getEnv(); auto res = env->CallStaticObjectMethodA(_stub_clazz.clazz(), env->FromReflectedMethod(_stub_method.handle()), jni_inputs); - CHECK_UDF_CALL_EXCEPTION(env, this->_ctx); + RETURN_ERROR_IF_JNI_EXCEPTION(env); return res; } diff --git a/be/src/udf/java/java_udf.h b/be/src/udf/java/java_udf.h index fb059382a1027..5df5f634db956 100644 --- a/be/src/udf/java/java_udf.h +++ b/be/src/udf/java/java_udf.h @@ -91,7 +91,7 @@ class JVMFunctionHelper { void batch_update_state(FunctionContext* ctx, jobject udaf, jobject update, jobject* input, int cols); // batch call evalute by callstub - jobject batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows); + StatusOr<jobject> batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows); // batch call method by reflect jobject batch_call(FunctionContext* ctx, jobject caller, jobject method, jobject* input, int cols, int rows); // batch call no-args function by reflect @@ -347,14 +347,12 @@ class BatchEvaluateStub { static inline const char* stub_clazz_name = "com.starrocks.udf.gen.CallStub"; static inline const char* batch_evaluate_method_name = "batchCallV"; - BatchEvaluateStub(FunctionContext* ctx, jobject caller, JVMClass&& clazz, JavaGlobalRef&& method) - : _ctx(ctx), _caller(caller), _stub_clazz(std::move(clazz)), _stub_method(std::move(method)) {} + BatchEvaluateStub(jobject caller, JVMClass&& clazz, JavaGlobalRef&& method) + : _caller(caller), _stub_clazz(std::move(clazz)), _stub_method(std::move(method)) {} - FunctionContext* ctx() { return _ctx; } - jobject batch_evaluate(int num_rows, jobject* input, int cols); + StatusOr<jobject> batch_evaluate(int num_rows, jobject* input, int cols); private: - FunctionContext* _ctx; jobject _caller; JVMClass _stub_clazz; JavaGlobalRef _stub_method; diff --git a/test/sql/test_udf/R/test_jvm_udf b/test/sql/test_udf/R/test_jvm_udf index 4c8e3c8f50aa1..0ca02ef83dbc0 100644 --- a/test/sql/test_udf/R/test_jvm_udf +++ 
b/test/sql/test_udf/R/test_jvm_udf @@ -61,6 +61,17 @@ PROPERTIES ); -- result: -- !result +CREATE FUNCTION shared_exception_test(string) +RETURNS string +PROPERTIES +( +"symbol" = "ExceptionUDF2", +"isolation"="shared", +"type" = "StarrocksJar", +"file" = "${udf_url}/starrocks-jdbc/ExceptionUDF2.jar" +); +-- result: +-- !result CREATE TABLE `t0` ( `c0` int(11) NULL COMMENT "", `c1` varchar(20) NULL COMMENT "", @@ -152,4 +163,12 @@ select sum(delta), count(*), count(delta) from (select (sum(c3) - sumbigint(c3)) select count(*) from t0 where exception_test(c1) is null; -- result: [REGEX].*java.lang.NullPointerException.* --- !result \ No newline at end of file +-- !result +select count(*) from t0 where shared_exception_test(c1) is null; +-- result: +[REGEX].*java.lang.NullPointerException.* +-- !result +select count(*) from t0 where shared_exception_test(c1) is null; +-- result: +[REGEX].*java.lang.NullPointerException.* +-- !result diff --git a/test/sql/test_udf/T/test_jvm_udf b/test/sql/test_udf/T/test_jvm_udf index 4138364c8f606..ca2240ec8c0ea 100644 --- a/test/sql/test_udf/T/test_jvm_udf +++ b/test/sql/test_udf/T/test_jvm_udf @@ -53,6 +53,16 @@ PROPERTIES "file" = "${udf_url}/starrocks-jdbc/ExceptionUDF2.jar" ); +CREATE FUNCTION shared_exception_test(string) +RETURNS string +PROPERTIES +( +"symbol" = "ExceptionUDF2", +"isolation"="shared", +"type" = "StarrocksJar", +"file" = "${udf_url}/starrocks-jdbc/ExceptionUDF2.jar" +); + CREATE TABLE `t0` ( `c0` int(11) NULL COMMENT "", `c1` varchar(20) NULL COMMENT "", @@ -94,4 +104,7 @@ set spill_mode="force"; select sum(delta), count(*), count(delta) from (select (sum(c3) - sumbigint(c3)) as delta from t0 group by c0,c1) tb; -- test udf exception case -select count(*) from t0 where exception_test(c1) is null; \ No newline at end of file +select count(*) from t0 where exception_test(c1) is null; +-- run two times +select count(*) from t0 where shared_exception_test(c1) is null; +select count(*) from t0 where 
shared_exception_test(c1) is null; \ No newline at end of file