diff --git a/src/pyoptinterface/_src/codegen_c.py b/src/pyoptinterface/_src/codegen_c.py index 8d61256..a62ca05 100644 --- a/src/pyoptinterface/_src/codegen_c.py +++ b/src/pyoptinterface/_src/codegen_c.py @@ -38,12 +38,30 @@ def generate_csrc_prelude(io: IO[str]): io.write( """// includes -#include "stddef.h" -#include "math.h" +#include // typedefs typedef double float_point_t; +// declare mathematical functions +#define UNARY(f) float_point_t f(float_point_t x) +#define BINARY(f) float_point_t f(float_point_t x, float_point_t y) + +// unary functions +UNARY(fabs); +UNARY(acos); +UNARY(asin); +UNARY(atan); +UNARY(cos); +UNARY(exp); +UNARY(log); +UNARY(sin); +UNARY(sqrt); +UNARY(tan); + +// binary functions +BINARY(pow); + // externals // azmul float_point_t azmul(float_point_t x, float_point_t y) diff --git a/src/pyoptinterface/_src/jit_c.py b/src/pyoptinterface/_src/jit_c.py index 7c9c391..e6a6251 100644 --- a/src/pyoptinterface/_src/jit_c.py +++ b/src/pyoptinterface/_src/jit_c.py @@ -20,11 +20,11 @@ # On Linux/Mac, tcc has lib/tcc/include/ and lib/tcc/libtcc1.a which must be included in compilation libtcc_extra_include_path = None libtcc_extra_lib_path = None -libtcc_extra_lib_name = None +libtcc_extra_lib_names = None if system in ["Linux", "Darwin"]: libtcc_extra_include_path = os.path.join(libtcc_dir, "tcc", "include") libtcc_extra_lib_path = os.path.join(libtcc_dir, "tcc") - libtcc_extra_lib_name = "libtcc1.a" + libtcc_extra_lib_names = [] # Define types TCCState = ctypes.c_void_p @@ -82,9 +82,10 @@ def create_state(self): == -1 ): raise Exception("Failed to add extra library path") - if libtcc_extra_lib_name: - if self.libtcc.tcc_add_library(state, libtcc_extra_lib_name.encode()) == -1: - raise Exception("Failed to add extra library") + if libtcc_extra_lib_names: + for name in libtcc_extra_lib_names: + if self.libtcc.tcc_add_library(state, name.encode()) == -1: + raise Exception("Failed to add extra library") self.states.append(state) diff --git a/tests/test_nlp.py b/tests/test_nlp.py index 83effed..c8e8e8b 100644 --- a/tests/test_nlp.py +++ b/tests/test_nlp.py @@ -157,6 +157,8 @@ def con(vars): if __name__ == "__main__": - test_ipopt(ipopt.Model) - test_nlp_param(ipopt.Model) - test_nlfunc_ifelse(ipopt.Model) + def c(): + return ipopt.Model(jit="C") + test_ipopt(c) + test_nlp_param(c) + test_nlfunc_ifelse(c)