diff --git a/CMakeLists.txt b/CMakeLists.txt index 377b371..51dc191 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -57,27 +57,6 @@ if(${PYBAMM_IDAKLU_EXPR_CASADI} STREQUAL "ON" ) ) endif() -# Check IREE build flag -if(NOT DEFINED PYBAMM_IDAKLU_EXPR_IREE) - set(PYBAMM_IDAKLU_EXPR_IREE OFF) -endif() -message("PYBAMM_IDAKLU_EXPR_IREE: ${PYBAMM_IDAKLU_EXPR_IREE}") - -# IREE (MLIR expression evaluation) PyBaMM source files -set(IDAKLU_EXPR_IREE_SOURCE_FILES "") -if(${PYBAMM_IDAKLU_EXPR_IREE} STREQUAL "ON" ) - add_compile_definitions(IREE_ENABLE) - # Source file list - set(IDAKLU_EXPR_IREE_SOURCE_FILES - src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.cpp - src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.hpp - src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.cpp - src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.hpp - src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.cpp - src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.hpp - ) -endif() - # The complete (all dependencies) sources list should be mirrored in setup.py pybind11_add_module(idaklu # pybind11 interface @@ -113,7 +92,6 @@ pybind11_add_module(idaklu src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionTypes.hpp # IDAKLU expressions - concrete implementations ${IDAKLU_EXPR_CASADI_SOURCE_FILES} - ${IDAKLU_EXPR_IREE_SOURCE_FILES} ) if (NOT DEFINED USE_PYTHON_CASADI) @@ -184,16 +162,3 @@ else() endif() include_directories(${SuiteSparse_INCLUDE_DIRS}) target_link_libraries(idaklu PRIVATE ${SuiteSparse_LIBRARIES}) - -# IREE (MLIR compiler and runtime library) build settings -if(${PYBAMM_IDAKLU_EXPR_IREE} STREQUAL "ON" ) - set(IREE_BUILD_COMPILER ON) - set(IREE_BUILD_TESTS OFF) - set(IREE_BUILD_SAMPLES OFF) - add_subdirectory(iree EXCLUDE_FROM_ALL) - set(IREE_COMPILER_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler") - target_include_directories(idaklu SYSTEM PRIVATE "${IREE_COMPILER_ROOT}/bindings/c/iree/compiler") - target_compile_options(idaklu PRIVATE ${IREE_DEFAULT_COPTS}) - target_link_libraries(idaklu PRIVATE iree_compiler_bindings_c_loader) - target_link_libraries(idaklu PRIVATE iree_runtime_runtime) -endif() diff --git a/noxfile.py b/noxfile.py index 6ebd6e6..0efac38 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,8 +1,6 @@ import nox import os import sys -import warnings -import platform from pathlib import Path @@ -13,44 +11,6 @@ else: nox.options.sessions = ["unit"] - -def set_iree_state(): - """ - Check if IREE is enabled and set the environment variable accordingly. - - Returns - ------- - str - "ON" if IREE is enabled, "OFF" otherwise. - - """ - state = "ON" if os.getenv("PYBAMM_IDAKLU_EXPR_IREE", "OFF") == "ON" else "OFF" - if state == "ON": - if sys.platform == "win32": - warnings.warn( - ( - "IREE is not enabled on Windows yet. " - "Setting PYBAMM_IDAKLU_EXPR_IREE=OFF." - ), - stacklevel=2, - ) - return "OFF" - if sys.platform == "darwin": - # iree-compiler is currently only available as a wheel on macOS 13 (or - # higher) and Python version 3.11 - mac_ver = int(platform.mac_ver()[0].split(".")[0]) - if (not sys.version_info[:2] == (3, 11)) or mac_ver < 13: - warnings.warn( - ( - "IREE is only supported on MacOS 13 (or higher) and Python" - "version 3.11. Setting PYBAMM_IDAKLU_EXPR_IREE=OFF." - ), - stacklevel=2, - ) - return "OFF" - return state - - homedir = Path(__file__) PYBAMM_ENV = { "LD_LIBRARY_PATH": f"{homedir}/.idaklu/lib", @@ -58,10 +18,6 @@ def set_iree_state(): "MPLBACKEND": "Agg", # Expression evaluators (...EXPR_CASADI cannot be fully disabled at this time) "PYBAMM_IDAKLU_EXPR_CASADI": os.getenv("PYBAMM_IDAKLU_EXPR_CASADI", "ON"), - "PYBAMM_IDAKLU_EXPR_IREE": set_iree_state(), - "IREE_INDEX_URL": os.getenv( - "IREE_INDEX_URL", "https://iree.dev/pip-release-links.html" - ), } VENV_DIR = Path("./venv").resolve() @@ -88,29 +44,6 @@ def run_pybamm_requires(session): set_environment_variables(PYBAMM_ENV, session=session) if sys.platform != "win32": session.run("python", "install_KLU_Sundials.py", *session.posargs) - if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON" and not os.path.exists( - "./iree" - ): - session.run( - "git", - "clone", - "--depth=1", - "--recurse-submodules", - "--shallow-submodules", - "--branch=candidate-20240507.886", - "https://github.com/openxla/iree", - "iree/", - external=True, - ) - with session.chdir("iree"): - session.run( - "git", - "submodule", - "update", - "--init", - "--recursive", - external=True, - ) else: session.error("nox -s idaklu-requires is only available on Linux & macOS.") @@ -122,13 +55,4 @@ def run_unit(session): session.install("setuptools", silent=False) session.install("casadi", silent=False) session.install("-e", ".[dev]", silent=False) - if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON": - # See comments in 'dev' session - session.install( - "-e", - ".[iree]", - "--find-links", - PYBAMM_ENV.get("IREE_INDEX_URL"), - silent=False, - ) session.run("pytest", "tests") diff --git a/setup.py b/setup.py index 65e3e6c..5075675 100644 --- a/setup.py +++ b/setup.py @@ -92,13 +92,11 @@ def run(self): build_type = os.getenv("PYBAMM_CPP_BUILD_TYPE", "RELEASE") idaklu_expr_casadi = os.getenv("PYBAMM_IDAKLU_EXPR_CASADI", "ON") - idaklu_expr_iree = os.getenv("PYBAMM_IDAKLU_EXPR_IREE", "OFF") cmake_args = [ f"-DCMAKE_BUILD_TYPE={build_type}", f"-DPYTHON_EXECUTABLE={sys.executable}", "-DUSE_PYTHON_CASADI={}".format("TRUE" if use_python_casadi else "FALSE"), f"-DPYBAMM_IDAKLU_EXPR_CASADI={idaklu_expr_casadi}", - f"-DPYBAMM_IDAKLU_EXPR_IREE={idaklu_expr_iree}", ] if self.suitesparse_root: cmake_args.append( @@ -244,14 +242,6 @@ def run(self): "src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionSparsity.hpp", "src/pybammsolvers/idaklu_source/Expressions/Casadi/CasadiFunctions.cpp", "src/pybammsolvers/idaklu_source/Expressions/Casadi/CasadiFunctions.hpp", - "src/pybammsolvers/idaklu_source/Expressions/IREE/IREEBaseFunction.hpp", - "src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunction.hpp", - "src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.cpp", - "src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.hpp", - "src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.cpp", - "src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.hpp", - "src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.cpp", - "src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.hpp", "src/pybammsolvers/idaklu_source/idaklu_solver.hpp", "src/pybammsolvers/idaklu_source/IDAKLUSolver.cpp", "src/pybammsolvers/idaklu_source/IDAKLUSolver.hpp", diff --git a/src/pybammsolvers/idaklu.cpp b/src/pybammsolvers/idaklu.cpp index b4ebc13..313ab0b 100644 --- a/src/pybammsolvers/idaklu.cpp +++ b/src/pybammsolvers/idaklu.cpp @@ -15,10 +15,6 @@ #include "idaklu_source/common.hpp" #include "idaklu_source/Expressions/Casadi/CasadiFunctions.hpp" -#ifdef IREE_ENABLE -#include "idaklu_source/Expressions/IREE/IREEFunctions.hpp" -#endif - casadi::Function generate_casadi_function(const std::string &data) { @@ -96,34 +92,6 @@ PYBIND11_MODULE(idaklu, m) py::arg("shape"), py::return_value_policy::take_ownership); -#ifdef IREE_ENABLE - m.def("create_iree_solver_group", &create_idaklu_solver_group, - "Create a group of iree idaklu solver objects", - py::arg("number_of_states"), - py::arg("number_of_parameters"), - py::arg("rhs_alg"), - py::arg("jac_times_cjmass"), - py::arg("jac_times_cjmass_colptrs"), - py::arg("jac_times_cjmass_rowvals"), - py::arg("jac_times_cjmass_nnz"), - py::arg("jac_bandwidth_lower"), - py::arg("jac_bandwidth_upper"), - py::arg("jac_action"), - py::arg("mass_action"), - py::arg("sens"), - py::arg("events"), - py::arg("number_of_events"), - py::arg("rhs_alg_id"), - py::arg("atol"), - py::arg("rtol"), - py::arg("inputs"), - py::arg("var_fcns"), - py::arg("dvar_dy_fcns"), - py::arg("dvar_dp_fcns"), - py::arg("options"), - py::return_value_policy::take_ownership); -#endif - m.def("generate_function", &generate_casadi_function, "Generate a casadi function", py::arg("string"), @@ -174,20 +142,6 @@ PYBIND11_MODULE(idaklu, m) py::class_(m, "Function"); -#ifdef IREE_ENABLE - py::class_(m, "IREEBaseFunctionType") - .def(py::init<>()) - .def_readwrite("mlir", &IREEBaseFunctionType::mlir) - .def_readwrite("kept_var_idx", &IREEBaseFunctionType::kept_var_idx) - .def_readwrite("nnz", &IREEBaseFunctionType::nnz) - .def_readwrite("numel", &IREEBaseFunctionType::numel) - .def_readwrite("col", &IREEBaseFunctionType::col) - .def_readwrite("row", &IREEBaseFunctionType::row) - .def_readwrite("pytree_shape", &IREEBaseFunctionType::pytree_shape) - .def_readwrite("pytree_sizes", &IREEBaseFunctionType::pytree_sizes) - .def_readwrite("n_args", &IREEBaseFunctionType::n_args); -#endif - py::class_(m, "solution") .def_readwrite("t", &Solution::t) .def_readwrite("y", &Solution::y) diff --git a/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEBaseFunction.hpp b/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEBaseFunction.hpp deleted file mode 100644 index d2ba7e4..0000000 --- a/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEBaseFunction.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef PYBAMM_IDAKLU_IREE_BASE_FUNCTION_HPP -#define PYBAMM_IDAKLU_IREE_BASE_FUNCTION_HPP - -#include -#include - -/* - * @brief Function definition passed from PyBaMM - */ -class IREEBaseFunctionType -{ -public: // methods - const std::string& get_mlir() const { return mlir; } - -public: // data members - std::string mlir; // cppcheck-suppress unusedStructMember - std::vector kept_var_idx; // cppcheck-suppress unusedStructMember - expr_int nnz; // cppcheck-suppress unusedStructMember - expr_int numel; // cppcheck-suppress unusedStructMember - std::vector col; // cppcheck-suppress unusedStructMember - std::vector row; // cppcheck-suppress unusedStructMember - std::vector pytree_shape; // cppcheck-suppress unusedStructMember - std::vector pytree_sizes; // cppcheck-suppress unusedStructMember - expr_int n_args; // cppcheck-suppress unusedStructMember -}; - -#endif // PYBAMM_IDAKLU_IREE_BASE_FUNCTION_HPP diff --git a/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunction.hpp b/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunction.hpp deleted file mode 100644 index bcdae5e..0000000 --- a/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunction.hpp +++ /dev/null @@ -1,59 +0,0 @@ -#ifndef PYBAMM_IDAKLU_IREE_FUNCTION_HPP -#define PYBAMM_IDAKLU_IREE_FUNCTION_HPP - -#include "../../Options.hpp" -#include "../Expressions.hpp" -#include -#include "iree_jit.hpp" -#include "IREEBaseFunction.hpp" - -/** - * @brief Class for handling individual iree functions - */ -class IREEFunction : public Expression -{ -public: - typedef IREEBaseFunctionType BaseFunctionType; - - /* - * @brief Constructor - */ - explicit IREEFunction(const BaseFunctionType &f); - - // Method overrides - void operator()() override; - void operator()(const std::vector& inputs, - const std::vector& results) override; - expr_int out_shape(int k) override; - expr_int nnz() override; - expr_int nnz_out() override; - const std::vector& get_col() override; - const std::vector& get_row() override; - - /* - * @brief Evaluate the MLIR function - */ - void evaluate(); - - /* - * @brief Evaluate the MLIR function - * @param n_outputs The number of outputs to return - */ - void evaluate(int n_outputs); - -public: - std::unique_ptr session; - std::vector> result; // cppcheck-suppress unusedStructMember - std::vector> input_shape; // cppcheck-suppress unusedStructMember - std::vector> output_shape; // cppcheck-suppress unusedStructMember - std::vector> input_data; // cppcheck-suppress unusedStructMember - - BaseFunctionType m_func; // cppcheck-suppress unusedStructMember - std::string module_name; // cppcheck-suppress unusedStructMember - std::string function_name; // cppcheck-suppress unusedStructMember - std::vector m_arg_argno; // cppcheck-suppress unusedStructMember - std::vector m_arg_argix; // cppcheck-suppress unusedStructMember - std::vector numel; // cppcheck-suppress unusedStructMember -}; - -#endif // PYBAMM_IDAKLU_IREE_FUNCTION_HPP diff --git a/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.cpp b/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.cpp deleted file mode 100644 index 3bde647..0000000 --- a/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.cpp +++ /dev/null @@ -1,230 +0,0 @@ -#include -#include -#include -#include -#include - -#include "IREEFunctions.hpp" -#include "iree_jit.hpp" -#include "ModuleParser.hpp" - -IREEFunction::IREEFunction(const BaseFunctionType &f) : Expression(), m_func(f) -{ - DEBUG("IreeFunction constructor"); - const std::string& mlir = f.get_mlir(); - - // Parse IREE (MLIR) function string - if (mlir.size() == 0) { - DEBUG("Empty function --- skipping..."); - return; - } - - // Parse MLIR for module name, input and output shapes - ModuleParser parser(mlir); - module_name = parser.getModuleName(); - function_name = parser.getFunctionName(); - input_shape = parser.getInputShape(); - output_shape = parser.getOutputShape(); - - DEBUG("Compiling module: '" << module_name << "'"); - const char* device_uri = "local-sync"; - session = std::make_unique(device_uri, mlir); - DEBUG("compile complete."); - // Create index vectors into m_arg - // This is required since Jax expands input arguments through PyTrees, which need to - // be remapped to the corresponding expression call. For example: - // fcn(t, y, inputs, cj) with inputs = [[in1], [in2], [in3]] - // will produce a function with six inputs; we therefore need to be able to map - // arguments to their 1) corresponding input argument, and 2) the correct position - // within that argument. - m_arg_argno.clear(); - m_arg_argix.clear(); - int current_element = 0; - for (int i=0; i 2) || - ((input_shape[j].size() == 2) && (input_shape[j][1] > 1)) - ) { - std::cerr << "Unsupported input shape: " << input_shape[j].size() << " ["; - for (int k=0; k {res0} signature (i.e. x and z are reduced out) - // with kept_var_idx = [1] - // - // *********************************************************************************** - - DEBUG("Copying inputs, shape " << input_shape.size() << " - " << m_func.kept_var_idx.size()); - for (int j=0; j 1) { - // Index into argument using appropriate shape - for(int k=0; k(m_arg[m_arg_from][m_arg_argix[mlir_arg]+k]); - } - } else { - // Copy the entire vector - for(int k=0; k(m_arg[m_arg_from][k]); - } - } - } - - // Call the 'main' function of the module - const std::string mlir = m_func.get_mlir(); - DEBUG("Calling function '" << function_name << "'"); - auto status = session->iree_runtime_exec(function_name, input_shape, input_data, result); - if (!iree_status_is_ok(status)) { - iree_status_fprint(stderr, status); - std::cerr << "MLIR: " << mlir.substr(0,1000) << std::endl; - throw std::runtime_error("Execution failed"); - } - - // Copy results to output array - for(size_t k=0; k(result[k][j]); - } - } - - DEBUG("IreeFunction operator() complete"); -} - -expr_int IREEFunction::out_shape(int k) { - DEBUG("IreeFunction nnz(" << k << "): " << m_func.nnz); - auto elements = 1; - for (auto i : output_shape[k]) { - elements *= i; - } - return elements; -} - -expr_int IREEFunction::nnz() { - DEBUG("IreeFunction nnz: " << m_func.nnz); - return nnz_out(); -} - -expr_int IREEFunction::nnz_out() { - DEBUG("IreeFunction nnz_out" << m_func.nnz); - return m_func.nnz; -} - -const std::vector& IREEFunction::get_row() { - DEBUG("IreeFunction get_row" << m_func.row.size()); - return m_func.row; -} - -const std::vector& IREEFunction::get_col() { - DEBUG("IreeFunction get_col" << m_func.col.size()); - return m_func.col; -} - -void IREEFunction::operator()(const std::vector& inputs, - const std::vector& results) -{ - DEBUG("IreeFunction operator() with inputs and results"); - // Set-up input arguments, provide result vector, then execute function - // Example call: fcn({in1, in2, in3}, {out1}) - ASSERT(inputs.size() == m_func.n_args); - for(size_t k=0; k -#include "iree_jit.hpp" -#include "IREEFunction.hpp" - -/** - * @brief Class for handling iree functions - */ -class IREEFunctions : public ExpressionSet -{ -public: - std::unique_ptr iree_compiler; - - typedef IREEFunction::BaseFunctionType BaseFunctionType; // expose typedef in class - - int iree_init_status; - - int iree_init(const std::string& device_uri, const std::string& target_backends) { - // Initialise IREE - DEBUG("IREEFunctions: Initialising IREECompiler"); - iree_compiler = std::make_unique(device_uri.c_str()); - - int iree_argc = 2; - std::string target_backends_str = "--iree-hal-target-backends=" + target_backends; - const char* iree_argv[2] = {"iree", target_backends_str.c_str()}; - iree_compiler->init(iree_argc, iree_argv); - DEBUG("IREEFunctions: Initialised IREECompiler"); - return 0; - } - - int iree_init() { - return iree_init("local-sync", "llvm-cpu"); - } - - - /** - * @brief Create a new IREEFunctions object - */ - IREEFunctions( - const BaseFunctionType &rhs_alg, - const BaseFunctionType &jac_times_cjmass, - const int jac_times_cjmass_nnz, - const int jac_bandwidth_lower, - const int jac_bandwidth_upper, - const np_array_int &jac_times_cjmass_rowvals_arg, - const np_array_int &jac_times_cjmass_colptrs_arg, - const int inputs_length, - const BaseFunctionType &jac_action, - const BaseFunctionType &mass_action, - const BaseFunctionType &sens, - const BaseFunctionType &events, - const int n_s, - const int n_e, - const int n_p, - const std::vector& var_fcns, - const std::vector& dvar_dy_fcns, - const std::vector& dvar_dp_fcns, - const SetupOptions& setup_opts - ) : - iree_init_status(iree_init()), - rhs_alg_iree(rhs_alg), - jac_times_cjmass_iree(jac_times_cjmass), - jac_action_iree(jac_action), - mass_action_iree(mass_action), - sens_iree(sens), - events_iree(events), - ExpressionSet( - static_cast(&rhs_alg_iree), - static_cast(&jac_times_cjmass_iree), - jac_times_cjmass_nnz, - jac_bandwidth_lower, - jac_bandwidth_upper, - jac_times_cjmass_rowvals_arg, - jac_times_cjmass_colptrs_arg, - inputs_length, - static_cast(&jac_action_iree), - static_cast(&mass_action_iree), - static_cast(&sens_iree), - static_cast(&events_iree), - n_s, n_e, n_p, - setup_opts) - { - // convert BaseFunctionType list to IREEFunction list - // NOTE: You must allocate ALL std::vector elements before taking references - for (auto& var : var_fcns) - var_fcns_iree.push_back(IREEFunction(*var)); - for (int k = 0; k < var_fcns_iree.size(); k++) - ExpressionSet::var_fcns.push_back(&this->var_fcns_iree[k]); - - for (auto& var : dvar_dy_fcns) - dvar_dy_fcns_iree.push_back(IREEFunction(*var)); - for (int k = 0; k < dvar_dy_fcns_iree.size(); k++) - this->dvar_dy_fcns.push_back(&this->dvar_dy_fcns_iree[k]); - - for (auto& var : dvar_dp_fcns) - dvar_dp_fcns_iree.push_back(IREEFunction(*var)); - for (int k = 0; k < dvar_dp_fcns_iree.size(); k++) - this->dvar_dp_fcns.push_back(&this->dvar_dp_fcns_iree[k]); - - // copy across numpy array values - const int n_row_vals = jac_times_cjmass_rowvals_arg.request().size; - auto p_jac_times_cjmass_rowvals = jac_times_cjmass_rowvals_arg.unchecked<1>(); - jac_times_cjmass_rowvals.resize(n_row_vals); - for (int i = 0; i < n_row_vals; i++) { - jac_times_cjmass_rowvals[i] = p_jac_times_cjmass_rowvals[i]; - } - - const int n_col_ptrs = jac_times_cjmass_colptrs_arg.request().size; - auto p_jac_times_cjmass_colptrs = jac_times_cjmass_colptrs_arg.unchecked<1>(); - jac_times_cjmass_colptrs.resize(n_col_ptrs); - for (int i = 0; i < n_col_ptrs; i++) { - jac_times_cjmass_colptrs[i] = p_jac_times_cjmass_colptrs[i]; - } - - inputs.resize(inputs_length); - } - - IREEFunction rhs_alg_iree; - IREEFunction jac_times_cjmass_iree; - IREEFunction jac_action_iree; - IREEFunction mass_action_iree; - IREEFunction sens_iree; - IREEFunction events_iree; - - std::vector var_fcns_iree; - std::vector dvar_dy_fcns_iree; - std::vector dvar_dp_fcns_iree; - - realtype* get_tmp_state_vector() override { - return tmp_state_vector.data(); - } - realtype* get_tmp_sparse_jacobian_data() override { - return tmp_sparse_jacobian_data.data(); - } - - ~IREEFunctions() { - // cleanup IREE - iree_compiler->cleanup(); - } -}; - -#endif // PYBAMM_IDAKLU_IREE_FUNCTIONS_HPP diff --git a/src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.cpp b/src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.cpp deleted file mode 100644 index d1c5575..0000000 --- a/src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.cpp +++ /dev/null @@ -1,91 +0,0 @@ -#include "ModuleParser.hpp" - -ModuleParser::ModuleParser(const std::string& mlir) : mlir(mlir) -{ - parse(); -} - -void ModuleParser::parse() -{ - // Parse module name - std::regex module_name_regex("module @([^\\s]+)"); // Match until first whitespace - std::smatch module_name_match; - std::regex_search(this->mlir, module_name_match, module_name_regex); - if (module_name_match.size() == 0) { - std::cerr << "Could not find module name in module" << std::endl; - std::cerr << "Module snippet: " << this->mlir.substr(0, 1000) << std::endl; - throw std::runtime_error("Could not find module name in module"); - } - module_name = module_name_match[1].str(); - DEBUG("Module name: " << module_name); - - // Assign function name - function_name = module_name + ".main"; - - // Isolate 'main' function call signature - std::regex main_func("public @main\\((.*?)\\) -> \\((.*?)\\)"); - std::smatch match; - std::regex_search(this->mlir, match, main_func); - if (match.size() == 0) { - std::cerr << "Could not find 'main' function in module" << std::endl; - std::cerr << "Module snippet: " << this->mlir.substr(0, 1000) << std::endl; - throw std::runtime_error("Could not find 'main' function in module"); - } - std::string main_sig_inputs = match[1].str(); - std::string main_sig_outputs = match[2].str(); - DEBUG( - "Main function signature: " << main_sig_inputs << " -> " << main_sig_outputs << '\n' - ); - - // Parse input sizes - input_shape.clear(); - std::regex input_size("tensor<(.*?)>"); - for(std::sregex_iterator i = std::sregex_iterator(main_sig_inputs.begin(), main_sig_inputs.end(), input_size); - i != std::sregex_iterator(); - ++i) - { - std::smatch matchi = *i; - std::string match_str = matchi.str(); - std::string shape_str = match_str.substr(7, match_str.size() - 8); // Remove 'tensor<>' from string - std::vector shape; - std::string dim_str; - for (char c : shape_str) { - if (c == 'x') { - shape.push_back(std::stoi(dim_str)); - dim_str = ""; - } else { - dim_str += c; - } - } - input_shape.push_back(shape); - } - - // Parse output sizes - output_shape.clear(); - std::regex output_size("tensor<(.*?)>"); - for( - std::sregex_iterator i = std::sregex_iterator(main_sig_outputs.begin(), main_sig_outputs.end(), output_size); - i != std::sregex_iterator(); - ++i - ) { - std::smatch matchi = *i; - std::string match_str = matchi.str(); - std::string shape_str = match_str.substr(7, match_str.size() - 8); // Remove 'tensor<>' from string - std::vector shape; - std::string dim_str; - for (char c : shape_str) { - if (c == 'x') { - shape.push_back(std::stoi(dim_str)); - dim_str = ""; - } else { - dim_str += c; - } - } - // If shape is empty, assume scalar (i.e. "tensor" or some singleton variant) - if (shape.size() == 0) { - shape.push_back(1); - } - // Add output to list - output_shape.push_back(shape); - } -} diff --git a/src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.hpp b/src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.hpp deleted file mode 100644 index 2fbfdc0..0000000 --- a/src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.hpp +++ /dev/null @@ -1,55 +0,0 @@ -#ifndef PYBAMM_IDAKLU_IREE_MODULE_PARSER_HPP -#define PYBAMM_IDAKLU_IREE_MODULE_PARSER_HPP - -#include -#include -#include -#include -#include - -#include "../../common.hpp" - -class ModuleParser { -private: - std::string mlir; // cppcheck-suppress unusedStructMember - // codacy fix: member is referenced as this->mlir in parse() - std::string module_name; - std::string function_name; - std::vector> input_shape; - std::vector> output_shape; -public: - /** - * @brief Constructor - * @param mlir: string representation of MLIR code for the module - */ - explicit ModuleParser(const std::string& mlir); - - /** - * @brief Get the module name - * @return module name - */ - const std::string& getModuleName() const { return module_name; } - - /** - * @brief Get the function name - * @return function name - */ - const std::string& getFunctionName() const { return function_name; } - - /** - * @brief Get the input shape - * @return input shape - */ - const std::vector>& getInputShape() const { return input_shape; } - - /** - * @brief Get the output shape - * @return output shape - */ - const std::vector>& getOutputShape() const { return output_shape; } - -private: - void parse(); -}; - -#endif // PYBAMM_IDAKLU_IREE_MODULE_PARSER_HPP diff --git a/src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.cpp b/src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.cpp deleted file mode 100644 index c84c392..0000000 --- a/src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.cpp +++ /dev/null @@ -1,408 +0,0 @@ -#include "iree_jit.hpp" -#include "iree/hal/buffer_view.h" -#include "iree/hal/buffer_view_util.h" -#include "../../common.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -// Used to suppress stderr output (see initIREE below) -#ifdef _WIN32 -#include -#define close _close -#define dup _dup -#define fileno _fileno -#define open _open -#define dup2 _dup2 -#define NULL_DEVICE "NUL" -#else -#define NULL_DEVICE "/dev/null" -#endif - -void IREESession::handle_compiler_error(iree_compiler_error_t *error) { - const char *msg = ireeCompilerErrorGetMessage(error); - fprintf(stderr, "Error from compiler API:\n%s\n", msg); - ireeCompilerErrorDestroy(error); -} - -void IREESession::cleanup_compiler_state(compiler_state_t s) { - if (s.inv) - ireeCompilerInvocationDestroy(s.inv); - if (s.output) - ireeCompilerOutputDestroy(s.output); - if (s.source) - ireeCompilerSourceDestroy(s.source); - if (s.session) - ireeCompilerSessionDestroy(s.session); -} - -IREECompiler::IREECompiler() { - this->device_uri = "local-sync"; -}; - -IREECompiler::~IREECompiler() { - ireeCompilerGlobalShutdown(); -}; - -int IREECompiler::init(int argc, const char **argv) { - return initIREE(argc, argv); // Initialisation and version checking -}; - -int IREECompiler::cleanup() { - return 0; -}; - -IREESession::IREESession() { - s.session = NULL; - s.source = NULL; - s.output = NULL; - s.inv = NULL; -}; - -IREESession::IREESession(const char *device_uri, const std::string& mlir_code) : IREESession() { - this->device_uri=device_uri; - this->mlir_code=mlir_code; - init(); -} - -int IREESession::init() { - if (initCompiler() != 0) // Prepare compiler inputs and outputs - return 1; - if (initCompileToByteCode() != 0) // Compile to bytecode - return 1; - if (initRuntime() != 0) // Initialise runtime environment - return 1; - return 0; -}; - -int IREECompiler::initIREE(int argc, const char **argv) { - - if (device_uri == NULL) { - DEBUG("No device URI provided, using local-sync\n"); - this->device_uri = "local-sync"; - } - - int cl_argc = argc; - const char *iree_compiler_lib = std::getenv("IREE_COMPILER_LIB"); - - // Load the compiler library and initialize it - // NOTE: On second and subsequent calls, the function will return false and display - // a message on stderr, but it is safe to ignore this message. For an improved user - // experience we actively suppress stderr during the call to this function but since - // this also suppresses any other error message, we actively check for the presence - // of the library file prior to the call. - - // Check if the library file exists - if (iree_compiler_lib == NULL) { - fprintf(stderr, "Error: IREE_COMPILER_LIB environment variable not set\n"); - return 1; - } - if (access(iree_compiler_lib, F_OK) == -1) { - fprintf(stderr, "Error: IREE_COMPILER_LIB file not found\n"); - return 1; - } - // Suppress stderr - int saved_stderr = dup(fileno(stderr)); - if (!freopen(NULL_DEVICE, "w", stderr)) - DEBUG("Error: failed redirecting stderr"); - // Load library - bool result = ireeCompilerLoadLibrary(iree_compiler_lib); - // Restore stderr - fflush(stderr); - dup2(saved_stderr, fileno(stderr)); - close(saved_stderr); - // Process result - if (!result) { - // Library may have already been loaded (can be safely ignored), - // or may not be found (critical error), we cannot tell which from the return value. - return 1; - } - // Must be balanced with a call to ireeCompilerGlobalShutdown() - ireeCompilerGlobalInitialize(); - - // To set global options (see `iree-compile --help` for possibilities), use - // |ireeCompilerGetProcessCLArgs| and |ireeCompilerSetupGlobalCL| - ireeCompilerGetProcessCLArgs(&cl_argc, &argv); - ireeCompilerSetupGlobalCL(cl_argc, argv, "iree-jit", false); - - // Check the API version before proceeding any further - uint32_t api_version = (uint32_t)ireeCompilerGetAPIVersion(); - uint16_t api_version_major = (uint16_t)((api_version >> 16) & 0xFFFFUL); - uint16_t api_version_minor = (uint16_t)(api_version & 0xFFFFUL); - DEBUG("Compiler API version: " << api_version_major << "." << api_version_minor); - if (api_version_major > IREE_COMPILER_EXPECTED_API_MAJOR || - api_version_minor < IREE_COMPILER_EXPECTED_API_MINOR) { - fprintf(stderr, - "Error: incompatible API version; built for version %" PRIu16 - ".%" PRIu16 " but loaded version %" PRIu16 ".%" PRIu16 "\n", - IREE_COMPILER_EXPECTED_API_MAJOR, IREE_COMPILER_EXPECTED_API_MINOR, - api_version_major, api_version_minor); - ireeCompilerGlobalShutdown(); - return 1; - } - - // Check for a build tag with release version information - const char *revision = ireeCompilerGetRevision(); // cppcheck-suppress unreadVariable - DEBUG("Compiler revision: '" << revision << "'"); - return 0; -}; - -int IREESession::initCompiler() { - - // A session provides a scope where one or more invocations can be executed - s.session = ireeCompilerSessionCreate(); - - // Read the MLIR from memory - error = ireeCompilerSourceWrapBuffer( - s.session, - "expr_buffer", // name of the buffer (does not need to match MLIR) - mlir_code.c_str(), - mlir_code.length() + 1, - true, - &s.source - ); - if (error) { - fprintf(stderr, "Error wrapping source buffer\n"); - handle_compiler_error(error); - cleanup_compiler_state(s); - return 1; - } - DEBUG("Wrapped buffer as a compiler source"); - - return 0; -}; - -int IREESession::initCompileToByteCode() { - // Use an invocation to compile from the input source to the output stream - iree_compiler_invocation_t *inv = ireeCompilerInvocationCreate(s.session); - ireeCompilerInvocationEnableConsoleDiagnostics(inv); - - if (!ireeCompilerInvocationParseSource(inv, s.source)) { - fprintf(stderr, "Error parsing input source into invocation\n"); - cleanup_compiler_state(s); - return 1; - } - - // Compile, specifying the target dialect phase - ireeCompilerInvocationSetCompileToPhase(inv, "end"); - - // Run the compiler invocation pipeline - if (!ireeCompilerInvocationPipeline(inv, IREE_COMPILER_PIPELINE_STD)) { - fprintf(stderr, "Error running compiler invocation\n"); - cleanup_compiler_state(s); - return 1; - } - DEBUG("Compilation successful"); - - // Create compiler 'output' to a memory buffer - error = ireeCompilerOutputOpenMembuffer(&s.output); - if (error) { - fprintf(stderr, "Error opening output membuffer\n"); - handle_compiler_error(error); - cleanup_compiler_state(s); - return 1; - } - - // Create bytecode in memory - error = ireeCompilerInvocationOutputVMBytecode(inv, s.output); - if (error) { - fprintf(stderr, "Error creating VM bytecode\n"); - handle_compiler_error(error); - cleanup_compiler_state(s); - return 1; - } - - // Once the bytecode has been written, retrieve the memory map - ireeCompilerOutputMapMemory(s.output, &contents, &size); - - return 0; -}; - -int IREESession::initRuntime() { - // Setup the shared runtime instance - iree_runtime_instance_options_t instance_options; - iree_runtime_instance_options_initialize(&instance_options); - iree_runtime_instance_options_use_all_available_drivers(&instance_options); - status = iree_runtime_instance_create( - &instance_options, iree_allocator_system(), &instance); - - // Create the HAL device used to run the workloads - if (iree_status_is_ok(status)) { - status = iree_hal_create_device( - iree_runtime_instance_driver_registry(instance), - iree_make_cstring_view(device_uri), - iree_runtime_instance_host_allocator(instance), &device); - } - - // Set up the session to run the module - if (iree_status_is_ok(status)) { - iree_runtime_session_options_t session_options; - iree_runtime_session_options_initialize(&session_options); - status = iree_runtime_session_create_with_device( - instance, &session_options, device, - iree_runtime_instance_host_allocator(instance), &session); - } - - // Load the compiled user module from a file - if (iree_status_is_ok(status)) { - /*status = iree_runtime_session_append_bytecode_module_from_file(session, module_path);*/ - status = iree_runtime_session_append_bytecode_module_from_memory( - session, - iree_make_const_byte_span(contents, size), - iree_allocator_null()); - } - - if (!iree_status_is_ok(status)) - return 1; - - return 0; -}; - -// Release the session and free all cached resources. -int IREESession::cleanup() { - iree_runtime_session_release(session); - iree_hal_device_release(device); - iree_runtime_instance_release(instance); - - int ret = (int)iree_status_code(status); - if (!iree_status_is_ok(status)) { - iree_status_fprint(stderr, status); - iree_status_ignore(status); - } - cleanup_compiler_state(s); - return ret; -} - -iree_status_t IREESession::iree_runtime_exec( - const std::string& function_name, - const std::vector>& inputs, - const std::vector>& data, - std::vector>& result -) { - - // Initialize the call to the function. - status = iree_runtime_call_initialize_by_name( - session, iree_make_cstring_view(function_name.c_str()), &call); - if (!iree_status_is_ok(status)) { - std::cerr << "Error: iree_runtime_call_initialize_by_name failed" << std::endl; - iree_status_fprint(stderr, status); - return status; - } - - // Append the function inputs with the HAL device allocator in use by the - // session. The buffers will be usable within the session and _may_ be usable - // in other sessions depending on whether they share a compatible device. - iree_hal_allocator_t* device_allocator = - iree_runtime_session_device_allocator(session); - host_allocator = iree_runtime_session_host_allocator(session); - status = iree_ok_status(); - if (iree_status_is_ok(status)) { - - for(int k=0; k arg_shape(input_shape.size()); - for (int i = 0; i < input_shape.size(); i++) { - arg_shape[i] = input_shape[i]; - } - int numel = 1; - for(int i = 0; i < input_shape.size(); i++) { - numel *= input_shape[i]; - } - std::vector arg_data(numel); - for(int i = 0; i < numel; i++) { - arg_data[i] = input_data[i]; - } - - status = iree_hal_buffer_view_allocate_buffer_copy( - device, device_allocator, - // Shape rank and dimensions: - arg_shape.size(), arg_shape.data(), - // Element type: - IREE_HAL_ELEMENT_TYPE_FLOAT_32, - // Encoding type: - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, - (iree_hal_buffer_params_t){ - // Intended usage of the buffer (transfers, dispatches, etc): - .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, - // Access to allow to this memory: - .access = IREE_HAL_MEMORY_ACCESS_ALL, - // Where to allocate (host or device): - .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, - }, - // The actual heap buffer to wrap or clone and its allocator: - iree_make_const_byte_span(&arg_data[0], sizeof(float) * arg_data.size()), - // Buffer view + storage are returned and owned by the caller: - &arg); - } - if (iree_status_is_ok(status)) { - // Add to the call inputs list (which retains the buffer view). - status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg); - if (!iree_status_is_ok(status)) { - std::cerr << "Error: iree_runtime_call_inputs_push_back_buffer_view failed" << std::endl; - iree_status_fprint(stderr, status); - } - } - // Since the call retains the buffer view we can release it here. - iree_hal_buffer_view_release(arg); - } - } - - // Synchronously perform the call. - if (iree_status_is_ok(status)) { - status = iree_runtime_call_invoke(&call, /*flags=*/0); - } - if (!iree_status_is_ok(status)) { - std::cerr << "Error: iree_runtime_call_invoke failed" << std::endl; - iree_status_fprint(stderr, status); - } - - for(int k=0; k -#include -#include -#include - -#include -#include -#include - -#define IREE_COMPILER_EXPECTED_API_MAJOR 1 // At most this major version -#define IREE_COMPILER_EXPECTED_API_MINOR 2 // At least this minor version - -// Forward declaration -class IREESession; - -/* - * @brief IREECompiler class - * @details This class is used to compile MLIR code to IREE bytecode and - * create IREE sessions. - */ -class IREECompiler { -private: - /* - * @brief Device Uniform Resource Identifier (URI) - * @details The device URI is used to specify the device to be used by the - * IREE runtime. E.g. "local-sync" for CPU, "vulkan" for GPU, etc. - */ - const char *device_uri = NULL; - -private: - /* - * @brief Initialize the IREE runtime - */ - int initIREE(int argc, const char **argv); - -public: - /* - * @brief Default constructor - */ - IREECompiler(); - - /* - * @brief Destructor - */ - ~IREECompiler(); - - /* - * @brief Constructor with device URI - * @param device_uri Device URI - */ - explicit IREECompiler(const char *device_uri) - : IREECompiler() { this->device_uri=device_uri; } - - /* - * @brief Initialize the compiler - */ - int init(int argc, const char **argv); - - /* - * @brief Cleanup the compiler - * @details This method cleans up the compiler and all the IREE sessions - * created by the compiler. Returns 0 on success. - */ - int cleanup(); -}; - -/* - * @brief Compiler state - */ -typedef struct compiler_state_t { - iree_compiler_session_t *session; // cppcheck-suppress unusedStructMember - iree_compiler_source_t *source; // cppcheck-suppress unusedStructMember - iree_compiler_output_t *output; // cppcheck-suppress unusedStructMember - iree_compiler_invocation_t *inv; // cppcheck-suppress unusedStructMember -} compiler_state_t; - -/* - * @brief IREE session class - */ -class IREESession { -private: // data members - const char *device_uri = NULL; - compiler_state_t s; - iree_compiler_error_t *error = NULL; - void *contents = NULL; - uint64_t size = 0; - iree_runtime_session_t* session = NULL; - iree_status_t status; - iree_hal_device_t* device = NULL; - iree_runtime_instance_t* instance = NULL; - std::string mlir_code; // cppcheck-suppress unusedStructMember - iree_runtime_call_t call; - iree_allocator_t host_allocator; - -private: // private methods - void handle_compiler_error(iree_compiler_error_t *error); - void cleanup_compiler_state(compiler_state_t s); - int init(); - int initCompiler(); - int initCompileToByteCode(); - int initRuntime(); - -public: // public methods - - /* - * @brief Default constructor - */ - IREESession(); - - /* - * @brief Constructor with device URI and MLIR code - * @param device_uri Device URI - * @param mlir_code MLIR code - */ - explicit IREESession(const char *device_uri, const std::string& mlir_code); - - /* - * @brief Cleanup the IREE session - */ - int cleanup(); - - /* - * @brief Execute the pre-compiled byte-code with the given inputs - * @param function_name Function name to execute - * @param inputs List of input shapes - * @param data List of input data - * @param result List of output data - */ - iree_status_t iree_runtime_exec( - const std::string& function_name, - const std::vector>& inputs, - const std::vector>& data, - std::vector>& result - ); -}; - -#endif // IREE_JIT_HPP