Skip to content

Commit

Permalink
[BugFix] Fix UAF in shared UDF (#54592)
Browse files Browse the repository at this point in the history
Signed-off-by: stdpain <[email protected]>
(cherry picked from commit a29b2b6)
  • Loading branch information
stdpain authored and mergify[bot] committed Jan 2, 2025
1 parent 971f2a8 commit e19145a
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 27 deletions.
22 changes: 8 additions & 14 deletions be/src/exprs/java_function_call_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,9 @@ struct UDFFunctionCallHelper {
RETURN_IF_UNLIKELY(!st.ok(), ColumnHelper::create_const_null_column(size));

// call UDF method
jobject res = helper.batch_call(fn_desc->call_stub.get(), input_col_objs.data(), input_col_objs.size(), size);
// The ctx of the current function argument is not the same as the ctx of fn_desc->call_stub.
// The latter is created in JavaFunctionCallExpr::prepare and used in java udf, so we should
// use it to determine whether an exception has occurred.
FunctionContext* java_udf_ctx = fn_desc->call_stub.get()->ctx();
if (java_udf_ctx != nullptr && java_udf_ctx->has_error()) {
return Status::RuntimeError(java_udf_ctx->error_msg());
}
ASSIGN_OR_RETURN(auto res, helper.batch_call(fn_desc->call_stub.get(), input_col_objs.data(),
input_col_objs.size(), size));

RETURN_IF_UNLIKELY_NULL(res, ColumnHelper::create_const_null_column(size));
// get result
auto result_cols = get_boxed_result(ctx, res, size);
Expand Down Expand Up @@ -166,7 +161,7 @@ bool JavaFunctionCallExpr::is_constant() const {
}

StatusOr<std::shared_ptr<JavaUDFContext>> JavaFunctionCallExpr::_build_udf_func_desc(
ExprContext* context, FunctionContext::FunctionStateScope scope, const std::string& libpath) {
FunctionContext::FunctionStateScope scope, const std::string& libpath) {
auto desc = std::make_shared<JavaUDFContext>();
// init class loader and analyzer
desc->udf_classloader = std::make_unique<ClassLoader>(std::move(libpath));
Expand Down Expand Up @@ -209,9 +204,8 @@ StatusOr<std::shared_ptr<JavaUDFContext>> JavaFunctionCallExpr::_build_udf_func_
ASSIGN_OR_RETURN(auto update_stub_clazz, desc->udf_classloader->genCallStub(stub_clazz, udf_clazz, update_method,
ClassLoader::BATCH_EVALUATE));
ASSIGN_OR_RETURN(auto method, desc->analyzer->get_method_object(update_stub_clazz.clazz(), stub_method_name));
auto function_ctx = context->fn_context(_fn_context_index);
desc->call_stub = std::make_unique<BatchEvaluateStub>(
function_ctx, desc->udf_handle.handle(), std::move(update_stub_clazz), JavaGlobalRef(std::move(method)));
desc->call_stub = std::make_unique<BatchEvaluateStub>(desc->udf_handle.handle(), std::move(update_stub_clazz),
JavaGlobalRef(std::move(method)));

if (desc->prepare != nullptr) {
// we only support fragment local scope to call prepare
Expand Down Expand Up @@ -239,10 +233,10 @@ Status JavaFunctionCallExpr::open(RuntimeState* state, ExprContext* context,
}
// cacheable
if (scope == FunctionContext::FRAGMENT_LOCAL) {
auto get_func_desc = [this, scope, context, state](const std::string& lib) -> StatusOr<std::any> {
auto get_func_desc = [this, scope, state](const std::string& lib) -> StatusOr<std::any> {
std::any func_desc;
auto call = [&]() {
ASSIGN_OR_RETURN(func_desc, _build_udf_func_desc(context, scope, lib));
ASSIGN_OR_RETURN(func_desc, _build_udf_func_desc(scope, lib));
return Status::OK();
};
RETURN_IF_ERROR(call_function_in_pthread(state, call)->get_future().get());
Expand Down
3 changes: 1 addition & 2 deletions be/src/exprs/java_function_call_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ class JavaFunctionCallExpr final : public Expr {
bool is_constant() const override;

private:
StatusOr<std::shared_ptr<JavaUDFContext>> _build_udf_func_desc(ExprContext* context,
FunctionContext::FunctionStateScope scope,
StatusOr<std::shared_ptr<JavaUDFContext>> _build_udf_func_desc(FunctionContext::FunctionStateScope scope,
const std::string& libpath);
void _call_udf_close();
RuntimeState* _runtime_state = nullptr;
Expand Down
6 changes: 3 additions & 3 deletions be/src/udf/java/java_udf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ void JVMFunctionHelper::batch_update_if_not_null(FunctionContext* ctx, jobject u
CHECK_UDF_CALL_EXCEPTION(_env, ctx);
}

jobject JVMFunctionHelper::batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows) {
StatusOr<jobject> JVMFunctionHelper::batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows) {
return stub->batch_evaluate(rows, input, cols);
}

Expand Down Expand Up @@ -853,7 +853,7 @@ void AggBatchCallStub::batch_update_single(int num_rows, jobject state, jobject*
CHECK_UDF_CALL_EXCEPTION(env, this->_ctx);
}

jobject BatchEvaluateStub::batch_evaluate(int num_rows, jobject* input, int cols) {
StatusOr<jobject> BatchEvaluateStub::batch_evaluate(int num_rows, jobject* input, int cols) {
jvalue jni_inputs[2 + cols];
jni_inputs[0].i = num_rows;
jni_inputs[1].l = _caller;
Expand All @@ -863,7 +863,7 @@ jobject BatchEvaluateStub::batch_evaluate(int num_rows, jobject* input, int cols
auto* env = JVMFunctionHelper::getInstance().getEnv();
auto res = env->CallStaticObjectMethodA(_stub_clazz.clazz(), env->FromReflectedMethod(_stub_method.handle()),
jni_inputs);
CHECK_UDF_CALL_EXCEPTION(env, this->_ctx);
RETURN_ERROR_IF_JNI_EXCEPTION(env);
return res;
}

Expand Down
10 changes: 4 additions & 6 deletions be/src/udf/java/java_udf.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class JVMFunctionHelper {
void batch_update_state(FunctionContext* ctx, jobject udaf, jobject update, jobject* input, int cols);

// batch call evalute by callstub
jobject batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows);
StatusOr<jobject> batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows);
// batch call method by reflect
jobject batch_call(FunctionContext* ctx, jobject caller, jobject method, jobject* input, int cols, int rows);
// batch call no-args function by reflect
Expand Down Expand Up @@ -347,14 +347,12 @@ class BatchEvaluateStub {
static inline const char* stub_clazz_name = "com.starrocks.udf.gen.CallStub";
static inline const char* batch_evaluate_method_name = "batchCallV";

BatchEvaluateStub(FunctionContext* ctx, jobject caller, JVMClass&& clazz, JavaGlobalRef&& method)
: _ctx(ctx), _caller(caller), _stub_clazz(std::move(clazz)), _stub_method(std::move(method)) {}
BatchEvaluateStub(jobject caller, JVMClass&& clazz, JavaGlobalRef&& method)
: _caller(caller), _stub_clazz(std::move(clazz)), _stub_method(std::move(method)) {}

FunctionContext* ctx() { return _ctx; }
jobject batch_evaluate(int num_rows, jobject* input, int cols);
StatusOr<jobject> batch_evaluate(int num_rows, jobject* input, int cols);

private:
FunctionContext* _ctx;
jobject _caller;
JVMClass _stub_clazz;
JavaGlobalRef _stub_method;
Expand Down
21 changes: 20 additions & 1 deletion test/sql/test_udf/R/test_jvm_udf
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ PROPERTIES
);
-- result:
-- !result
CREATE FUNCTION shared_exception_test(string)
RETURNS string
PROPERTIES
(
"symbol" = "ExceptionUDF2",
"isolation"="shared",
"type" = "StarrocksJar",
"file" = "${udf_url}/starrocks-jdbc/ExceptionUDF2.jar"
);
-- result:
-- !result
CREATE TABLE `t0` (
`c0` int(11) NULL COMMENT "",
`c1` varchar(20) NULL COMMENT "",
Expand Down Expand Up @@ -152,4 +163,12 @@ select sum(delta), count(*), count(delta) from (select (sum(c3) - sumbigint(c3))
select count(*) from t0 where exception_test(c1) is null;
-- result:
[REGEX].*java.lang.NullPointerException.*
-- !result
-- !result
select count(*) from t0 where shared_exception_test(c1) is null;
-- result:
[REGEX].*java.lang.NullPointerException.*
-- !result
select count(*) from t0 where shared_exception_test(c1) is null;
-- result:
[REGEX].*java.lang.NullPointerException.*
-- !result
15 changes: 14 additions & 1 deletion test/sql/test_udf/T/test_jvm_udf
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ PROPERTIES
"file" = "${udf_url}/starrocks-jdbc/ExceptionUDF2.jar"
);

CREATE FUNCTION shared_exception_test(string)
RETURNS string
PROPERTIES
(
"symbol" = "ExceptionUDF2",
"isolation"="shared",
"type" = "StarrocksJar",
"file" = "${udf_url}/starrocks-jdbc/ExceptionUDF2.jar"
);

CREATE TABLE `t0` (
`c0` int(11) NULL COMMENT "",
`c1` varchar(20) NULL COMMENT "",
Expand Down Expand Up @@ -94,4 +104,7 @@ set spill_mode="force";

select sum(delta), count(*), count(delta) from (select (sum(c3) - sumbigint(c3)) as delta from t0 group by c0,c1) tb;
-- test udf exception case
select count(*) from t0 where exception_test(c1) is null;
select count(*) from t0 where exception_test(c1) is null;
-- run two times
select count(*) from t0 where shared_exception_test(c1) is null;
select count(*) from t0 where shared_exception_test(c1) is null;

0 comments on commit e19145a

Please sign in to comment.