Skip to content

Commit

Permalink
[BugFix] Fix UAF in shared UDF (#54592)
Browse files Browse the repository at this point in the history
Signed-off-by: stdpain <[email protected]>
(cherry picked from commit a29b2b6)

# Conflicts:
#	be/src/exprs/java_function_call_expr.cpp
#	be/src/exprs/java_function_call_expr.h
  • Loading branch information
stdpain authored and mergify[bot] committed Jan 2, 2025
1 parent 8b431fe commit 9e2a5d6
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 19 deletions.
84 changes: 76 additions & 8 deletions be/src/exprs/java_function_call_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,9 @@ struct UDFFunctionCallHelper {
RETURN_IF_UNLIKELY(!st.ok(), ColumnHelper::create_const_null_column(size));

// call UDF method
jobject res = helper.batch_call(fn_desc->call_stub.get(), input_col_objs.data(), input_col_objs.size(), size);
// The ctx of the current function argument is not the same as the ctx of fn_desc->call_stub.
// The latter is created in JavaFunctionCallExpr::prepare and used in java udf, so we should
// use it to determine whether an exception has occurred.
FunctionContext* java_udf_ctx = fn_desc->call_stub.get()->ctx();
if (java_udf_ctx != nullptr && java_udf_ctx->has_error()) {
return Status::RuntimeError(java_udf_ctx->error_msg());
}
ASSIGN_OR_RETURN(auto res, helper.batch_call(fn_desc->call_stub.get(), input_col_objs.data(),
input_col_objs.size(), size));

RETURN_IF_UNLIKELY_NULL(res, ColumnHelper::create_const_null_column(size));
// get result
auto result_cols = get_boxed_result(ctx, res, size);
Expand Down Expand Up @@ -165,6 +160,66 @@ bool JavaFunctionCallExpr::is_constant() const {
return Expr::is_constant();
}

<<<<<<< HEAD
=======
// Builds a JavaUDFContext for the scalar UDF whose class is named by `_fn.scalar_fn.symbol`.
//
// The context is populated in this order:
//   1. a ClassLoader created from `libpath` and initialized,
//   2. a ClassAnalyzer,
//   3. the UDF class and the descriptor of its `evaluate` method,
//   4. a fresh instance of the UDF class,
//   5. a generated BatchEvaluateStub that calls `evaluate` on that instance.
//
// @param scope   function state scope; currently only consulted by the (not yet implemented)
//                `prepare` hook at the end of the function.
// @param libpath path passed to the ClassLoader constructor.
// @return the populated context, or the first non-OK Status produced by any step. An error is
//         also returned when the UDF class does not expose an `evaluate` method.
StatusOr<std::shared_ptr<JavaUDFContext>> JavaFunctionCallExpr::_build_udf_func_desc(
        FunctionContext::FunctionStateScope scope, const std::string& libpath) {
    auto desc = std::make_shared<JavaUDFContext>();
    // init class loader and analyzer
    // `libpath` is a const reference: it can only be copied, so no std::move here.
    desc->udf_classloader = std::make_unique<ClassLoader>(libpath);
    RETURN_IF_ERROR(desc->udf_classloader->init());
    desc->analyzer = std::make_unique<ClassAnalyzer>();

    ASSIGN_OR_RETURN(desc->udf_class, desc->udf_classloader->getClass(_fn.scalar_fn.symbol));

    // Looks up method `name` on the UDF class. When the method exists, `*res` receives a filled
    // descriptor; when it does not, `*res` is left untouched and OK is still returned, so callers
    // that require the method must check `*res` themselves.
    auto add_method = [&](const std::string& name, std::unique_ptr<JavaMethodDescriptor>* res) {
        bool has_method = false;
        std::string method_name = name;
        std::string signature;
        std::vector<MethodTypeDescriptor> mtdesc;
        RETURN_IF_ERROR(desc->analyzer->has_method(desc->udf_class.clazz(), method_name, &has_method));
        if (has_method) {
            RETURN_IF_ERROR(desc->analyzer->get_signature(desc->udf_class.clazz(), method_name, &signature));
            RETURN_IF_ERROR(desc->analyzer->get_method_desc(signature, &mtdesc));
            *res = std::make_unique<JavaMethodDescriptor>();
            (*res)->name = std::move(method_name);
            (*res)->signature = std::move(signature);
            (*res)->method_desc = std::move(mtdesc);
            // `method_name` has been moved from above, so use `name` for the lookup.
            ASSIGN_OR_RETURN((*res)->method, desc->analyzer->get_method_object(desc->udf_class.clazz(), name));
        }
        return Status::OK();
    };

    // Now we don't support prepare/close for UDF
    // RETURN_IF_ERROR(add_method("prepare", &desc->prepare));
    // RETURN_IF_ERROR(add_method("method_close", &desc->close));
    RETURN_IF_ERROR(add_method("evaluate", &desc->evaluate));

    // add_method() succeeds even when the method is absent; `evaluate` is dereferenced below,
    // so a missing method must be reported as an error instead of crashing the process.
    if (desc->evaluate == nullptr) {
        return Status::RuntimeError(std::string("method 'evaluate' not found in Java UDF class: ") +
                                    _fn.scalar_fn.symbol);
    }

    // create UDF function instance
    ASSIGN_OR_RETURN(desc->udf_handle, desc->udf_class.newInstance());
    // BatchEvaluateStub
    auto* stub_clazz = BatchEvaluateStub::stub_clazz_name;
    auto* stub_method_name = BatchEvaluateStub::batch_evaluate_method_name;
    auto udf_clazz = desc->udf_class.clazz();
    auto update_method = desc->evaluate->method.handle();

    ASSIGN_OR_RETURN(auto update_stub_clazz, desc->udf_classloader->genCallStub(stub_clazz, udf_clazz, update_method,
                                                                                ClassLoader::BATCH_EVALUATE));
    ASSIGN_OR_RETURN(auto method, desc->analyzer->get_method_object(update_stub_clazz.clazz(), stub_method_name));
    desc->call_stub = std::make_unique<BatchEvaluateStub>(desc->udf_handle.handle(), std::move(update_stub_clazz),
                                                          JavaGlobalRef(std::move(method)));

    // `prepare` is never populated at the moment (its add_method call is commented out above),
    // so this branch is a placeholder for future support.
    if (desc->prepare != nullptr) {
        // we only support fragment local scope to call prepare
        if (scope == FunctionContext::FRAGMENT_LOCAL) {
            // TODO: handle prepare function
        }
    }

    return desc;
}

>>>>>>> a29b2b6b33 ([BugFix] Fix UAF in shared UDF (#54592))
Status JavaFunctionCallExpr::open(RuntimeState* state, ExprContext* context,
FunctionContext::FunctionStateScope scope) {
// init parent open
Expand All @@ -179,6 +234,7 @@ Status JavaFunctionCallExpr::open(RuntimeState* state, ExprContext* context,
const_columns.emplace_back(std::move(child_col));
}
}
<<<<<<< HEAD
auto open_state = [this, scope, context]() {
// init class loader and analyzer
std::string libpath;
Expand Down Expand Up @@ -208,6 +264,18 @@ Status JavaFunctionCallExpr::open(RuntimeState* state, ExprContext* context,
_func_desc->analyzer->get_method_object(_func_desc->udf_class.clazz(), name));
}
return Status::OK();
=======
// cacheable
if (scope == FunctionContext::FRAGMENT_LOCAL) {
auto get_func_desc = [this, scope, state](const std::string& lib) -> StatusOr<std::any> {
std::any func_desc;
auto call = [&]() {
ASSIGN_OR_RETURN(func_desc, _build_udf_func_desc(scope, lib));
return Status::OK();
};
RETURN_IF_ERROR(call_function_in_pthread(state, call)->get_future().get());
return func_desc;
>>>>>>> a29b2b6b33 ([BugFix] Fix UAF in shared UDF (#54592))
};

// Now we don't support prepare/close for UDF
Expand Down
5 changes: 5 additions & 0 deletions be/src/exprs/java_function_call_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ class JavaFunctionCallExpr final : public Expr {
bool is_constant() const override;

private:
<<<<<<< HEAD
=======
StatusOr<std::shared_ptr<JavaUDFContext>> _build_udf_func_desc(FunctionContext::FunctionStateScope scope,
const std::string& libpath);
>>>>>>> a29b2b6b33 ([BugFix] Fix UAF in shared UDF (#54592))
void _call_udf_close();
RuntimeState* _runtime_state = nullptr;
std::shared_ptr<JavaUDFContext> _func_desc;
Expand Down
6 changes: 3 additions & 3 deletions be/src/udf/java/java_udf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ void JVMFunctionHelper::batch_update_if_not_null(FunctionContext* ctx, jobject u
CHECK_UDF_CALL_EXCEPTION(_env, ctx);
}

jobject JVMFunctionHelper::batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows) {
// Evaluates the UDF through its generated call stub for `rows` rows over `cols` input column
// objects. Pure delegation: whatever the stub yields (result object or error) is handed back as is.
StatusOr<jobject> JVMFunctionHelper::batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows) {
    // batch_evaluate expects the row count first, then the inputs and their count.
    StatusOr<jobject> evaluated = stub->batch_evaluate(rows, input, cols);
    return evaluated;
}

Expand Down Expand Up @@ -860,7 +860,7 @@ void AggBatchCallStub::batch_update_single(int num_rows, jobject state, jobject*
CHECK_UDF_CALL_EXCEPTION(env, this->_ctx);
}

jobject BatchEvaluateStub::batch_evaluate(int num_rows, jobject* input, int cols) {
StatusOr<jobject> BatchEvaluateStub::batch_evaluate(int num_rows, jobject* input, int cols) {
jvalue jni_inputs[2 + cols];
jni_inputs[0].i = num_rows;
jni_inputs[1].l = _caller;
Expand All @@ -870,7 +870,7 @@ jobject BatchEvaluateStub::batch_evaluate(int num_rows, jobject* input, int cols
auto* env = JVMFunctionHelper::getInstance().getEnv();
auto res = env->CallStaticObjectMethodA(_stub_clazz.clazz(), env->FromReflectedMethod(_stub_method.handle()),
jni_inputs);
CHECK_UDF_CALL_EXCEPTION(env, this->_ctx);
RETURN_ERROR_IF_JNI_EXCEPTION(env);
return res;
}

Expand Down
10 changes: 4 additions & 6 deletions be/src/udf/java/java_udf.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class JVMFunctionHelper {
void batch_update_state(FunctionContext* ctx, jobject udaf, jobject update, jobject* input, int cols);

// batch call evaluate by callstub
jobject batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows);
StatusOr<jobject> batch_call(BatchEvaluateStub* stub, jobject* input, int cols, int rows);
// batch call method by reflect
jobject batch_call(FunctionContext* ctx, jobject caller, jobject method, jobject* input, int cols, int rows);
// batch call no-args function by reflect
Expand Down Expand Up @@ -346,14 +346,12 @@ class BatchEvaluateStub {
static inline const char* stub_clazz_name = "com.starrocks.udf.gen.CallStub";
static inline const char* batch_evaluate_method_name = "batchCallV";

BatchEvaluateStub(FunctionContext* ctx, jobject caller, JVMClass&& clazz, JavaGlobalRef&& method)
: _ctx(ctx), _caller(caller), _stub_clazz(std::move(clazz)), _stub_method(std::move(method)) {}
// Creates a stub bound to `caller`.
// `clazz` and `method` are moved into the stub's members; `caller` is copied as a plain handle.
// NOTE(review): no new reference to `caller` is created here, so the owner of that handle
// presumably has to keep it alive for as long as the stub is used — confirm against call sites.
BatchEvaluateStub(jobject caller, JVMClass&& clazz, JavaGlobalRef&& method)
: _caller(caller), _stub_clazz(std::move(clazz)), _stub_method(std::move(method)) {}

FunctionContext* ctx() { return _ctx; }
jobject batch_evaluate(int num_rows, jobject* input, int cols);
StatusOr<jobject> batch_evaluate(int num_rows, jobject* input, int cols);

private:
FunctionContext* _ctx;
jobject _caller;
JVMClass _stub_clazz;
JavaGlobalRef _stub_method;
Expand Down
21 changes: 20 additions & 1 deletion test/sql/test_udf/R/test_jvm_udf
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ PROPERTIES
);
-- result:
-- !result
CREATE FUNCTION shared_exception_test(string)
RETURNS string
PROPERTIES
(
"symbol" = "ExceptionUDF2",
"isolation"="shared",
"type" = "StarrocksJar",
"file" = "${udf_url}/starrocks-jdbc/ExceptionUDF2.jar"
);
-- result:
-- !result
CREATE TABLE `t0` (
`c0` int(11) NULL COMMENT "",
`c1` varchar(20) NULL COMMENT "",
Expand Down Expand Up @@ -152,4 +163,12 @@ select sum(delta), count(*), count(delta) from (select (sum(c3) - sumbigint(c3))
select count(*) from t0 where exception_test(c1) is null;
-- result:
[REGEX].*java.lang.NullPointerException.*
-- !result
-- !result
select count(*) from t0 where shared_exception_test(c1) is null;
-- result:
[REGEX].*java.lang.NullPointerException.*
-- !result
select count(*) from t0 where shared_exception_test(c1) is null;
-- result:
[REGEX].*java.lang.NullPointerException.*
-- !result
15 changes: 14 additions & 1 deletion test/sql/test_udf/T/test_jvm_udf
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ PROPERTIES
"file" = "${udf_url}/starrocks-jdbc/ExceptionUDF2.jar"
);

CREATE FUNCTION shared_exception_test(string)
RETURNS string
PROPERTIES
(
"symbol" = "ExceptionUDF2",
"isolation"="shared",
"type" = "StarrocksJar",
"file" = "${udf_url}/starrocks-jdbc/ExceptionUDF2.jar"
);

CREATE TABLE `t0` (
`c0` int(11) NULL COMMENT "",
`c1` varchar(20) NULL COMMENT "",
Expand Down Expand Up @@ -94,4 +104,7 @@ set spill_mode="force";

select sum(delta), count(*), count(delta) from (select (sum(c3) - sumbigint(c3)) as delta from t0 group by c0,c1) tb;
-- test udf exception case
select count(*) from t0 where exception_test(c1) is null;
select count(*) from t0 where exception_test(c1) is null;
-- run two times
select count(*) from t0 where shared_exception_test(c1) is null;
select count(*) from t0 where shared_exception_test(c1) is null;

0 comments on commit 9e2a5d6

Please sign in to comment.