diff --git a/CMakeLists.txt b/CMakeLists.txt index df16e23768..8e47a4bd76 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,7 +14,7 @@ project( # ============================================================================= # CMake common project settings # ============================================================================= -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) @@ -22,8 +22,11 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) # ============================================================================= # Build options for NMODL # ============================================================================= -option(NMODL_ENABLE_PYTHON_BINDINGS "Enable pybind11 based python bindings" ON) +option(NMODL_ENABLE_PYTHON_BINDINGS "Enable pybind11 based python bindings" OFF) option(NMODL_ENABLE_LEGACY_UNITS "Use original faraday, R, etc. 
instead of 2019 nist constants" OFF) +option(NMODL_ENABLE_LLVM "Enable LLVM based code generation" ON) +option(NMODL_ENABLE_JIT_EVENT_LISTENERS "Enable JITEventListener for Perf and Vtune" OFF) + if(NMODL_ENABLE_LEGACY_UNITS) add_definitions(-DUSE_LEGACY_UNITS) endif() @@ -140,6 +143,15 @@ find_python_module(sympy 1.2 REQUIRED) find_python_module(textwrap 0.9 REQUIRED) find_python_module(yaml 3.12 REQUIRED) +# ============================================================================= +# Find LLVM dependencies +# ============================================================================= +if(NMODL_ENABLE_LLVM) + include(LLVMHelper) + include_directories(${LLVM_INCLUDE_DIRS}) + add_definitions(-DNMODL_LLVM_BACKEND) +endif() + # ============================================================================= # Compiler specific flags for external submodules # ============================================================================= @@ -173,6 +185,7 @@ set(MEMORYCHECK_COMMAND_OPTIONS # do not enable tests if nmodl is used as submodule if(NOT NMODL_AS_SUBPROJECT) include(CTest) + add_subdirectory(test/benchmark) add_subdirectory(test/unit) endif() @@ -228,34 +241,40 @@ endif() message(STATUS "") message(STATUS "Configured NMODL ${PROJECT_VERSION} (${GIT_REVISION})") message(STATUS "") -string(TOLOWER "${CMAKE_GENERATOR}" cmake_generator_tolower) -if(cmake_generator_tolower MATCHES "makefile") - message(STATUS "Some things you can do now:") - message(STATUS "--------------------+--------------------------------------------------------") - message(STATUS "Command | Description") - message(STATUS "--------------------+--------------------------------------------------------") - message(STATUS "make | Build the project") - message(STATUS "make test | Run unit tests") - message(STATUS "make install | Will install NMODL to: ${CMAKE_INSTALL_PREFIX}") - message(STATUS "--------------------+--------------------------------------------------------") - message(STATUS " Build 
option | Status") - message(STATUS "--------------------+--------------------------------------------------------") - message(STATUS "CXX COMPILER | ${CMAKE_CXX_COMPILER}") - message(STATUS "COMPILE FLAGS | ${COMPILER_FLAGS}") - message(STATUS "Build Type | ${CMAKE_BUILD_TYPE}") - message(STATUS "Legacy Units | ${NMODL_ENABLE_LEGACY_UNITS}") - message(STATUS "Python Bindings | ${NMODL_ENABLE_PYTHON_BINDINGS}") - message(STATUS "Flex | ${FLEX_EXECUTABLE}") - message(STATUS "Bison | ${BISON_EXECUTABLE}") - message(STATUS "Python | ${PYTHON_EXECUTABLE}") - if(NMODL_CLANG_FORMAT) - message(STATUS "Clang Format | ${ClangFormat_EXECUTABLE}") - endif() - if(NMODL_CMAKE_FORMAT) - message(STATUS "Cmake Format | ${CMakeFormat_EXECUTABLE}") - endif() - message(STATUS "--------------+--------------------------------------------------------------") - message(STATUS " See documentation : https://github.com/BlueBrain/nmodl/") - message(STATUS "--------------+--------------------------------------------------------------") + +message(STATUS "Some things you can do now:") +message(STATUS "--------------------+--------------------------------------------------------") +message(STATUS "Command | Description") +message(STATUS "--------------------+--------------------------------------------------------") +message(STATUS "make | Build the project") +message(STATUS "make test | Run unit tests") +message(STATUS "make install | Will install NMODL to: ${CMAKE_INSTALL_PREFIX}") +message(STATUS "--------------------+--------------------------------------------------------") +message(STATUS " Build option | Status") +message(STATUS "--------------------+--------------------------------------------------------") +message(STATUS "CXX COMPILER | ${CMAKE_CXX_COMPILER}") +message(STATUS "COMPILE FLAGS | ${COMPILER_FLAGS}") +message(STATUS "Build Type | ${CMAKE_BUILD_TYPE}") +message(STATUS "Legacy Units | ${NMODL_ENABLE_LEGACY_UNITS}") +message(STATUS "Python Bindings | 
${NMODL_ENABLE_PYTHON_BINDINGS}") +message(STATUS "Flex | ${FLEX_EXECUTABLE}") +message(STATUS "Bison | ${BISON_EXECUTABLE}") +message(STATUS "Python | ${PYTHON_EXECUTABLE}") +message(STATUS "LLVM Codegen | ${NMODL_ENABLE_LLVM}") +if(NMODL_ENABLE_LLVM) + message(STATUS " VERSION | ${LLVM_PACKAGE_VERSION}") + message(STATUS " INCLUDE | ${LLVM_INCLUDE_DIRS}") + message(STATUS " CMAKE | ${LLVM_CMAKE_DIR}") + message(STATUS " JIT LISTENERS | ${NMODL_ENABLE_JIT_EVENT_LISTENERS}") +endif() +if(NMODL_CLANG_FORMAT) + message(STATUS "Clang Format | ${ClangFormat_EXECUTABLE}") endif() +if(NMODL_CMAKE_FORMAT) + message(STATUS "Cmake Format | ${CMakeFormat_EXECUTABLE}") +endif() +message(STATUS "--------------+--------------------------------------------------------------") +message(STATUS " See documentation : https://github.com/BlueBrain/nmodl/") +message(STATUS "--------------+--------------------------------------------------------------") + message(STATUS "") diff --git a/INSTALL.md b/INSTALL.md index 335651c86c..32e9106669 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -31,7 +31,7 @@ Typically the versions of bison and flex provided by the system are outdated and To get recent version of all dependencies we recommend using [homebrew](https://brew.sh/): ```sh -brew install flex bison cmake python3 +brew install flex bison cmake python3 llvm ``` The necessary Python packages can then easily be added using the pip3 command. @@ -57,7 +57,7 @@ export PATH=/opt/homebrew/opt/flex/bin:/opt/homebrew/opt/bison/bin:$PATH On Ubuntu (>=18.04) flex/bison versions are recent enough and are installed along with the system toolchain: ```sh -apt-get install flex bison gcc python3 python3-pip +apt-get install flex bison gcc python3 python3-pip llvm-dev llvm-runtime llvm clang-format clang ``` The Python dependencies are installed using: @@ -79,6 +79,15 @@ cmake .. 
-DCMAKE_INSTALL_PREFIX=$HOME/nmodl make -j && make install ``` +If `llvm-config` is not in PATH then set LLVM_DIR as: + +```sh +cmake .. -DCMAKE_INSTALL_PREFIX=$HOME/nmodl -DLLVM_DIR=/path/to/llvm/install/lib/cmake/llvm + +# on OSX +cmake .. -DCMAKE_INSTALL_PREFIX=$HOME/nmodl -DLLVM_DIR=`brew --prefix llvm`/lib/cmake/llvm +``` + And set PYTHONPATH as: ```sh diff --git a/azure-pipelines.yml b/azure-pipelines.yml index f3a9d20722..59f5d5bb04 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -47,6 +47,11 @@ jobs: url="https://github.com/ispc/ispc/releases/download/${ispc_version}/ispc-${ispc_version}${ispc_version_suffix}-${url_os}.tar.gz"; mkdir $(pwd)/$CMAKE_PKG/ispc wget --output-document=- $url | tar -xvzf - -C $(pwd)/$CMAKE_PKG/ispc --strip 1; + # install llvm nightly (future v13) + wget https://apt.llvm.org/llvm.sh + chmod +x llvm.sh + sudo ./llvm.sh 13 + env: CMAKE_PKG: 'cmake-3.10.2-Linux-x86_64' displayName: 'Install Dependencies' @@ -56,7 +61,7 @@ jobs: mkdir -p $(Build.Repository.LocalPath)/build cd $(Build.Repository.LocalPath)/build cmake --version - cmake .. -DPYTHON_EXECUTABLE=$(which python3.7) -DCMAKE_INSTALL_PREFIX=$HOME/nmodl -DCMAKE_BUILD_TYPE=Release + cmake .. -DPYTHON_EXECUTABLE=$(which python3.7) -DCMAKE_INSTALL_PREFIX=$HOME/nmodl -DCMAKE_BUILD_TYPE=Release -DNMODL_ENABLE_LLVM=ON -DLLVM_DIR=/usr/lib/llvm-13/share/llvm/cmake/ make -j 2 if [ $? 
-ne 0 ] then @@ -115,24 +120,27 @@ jobs: env: CMAKE_PKG: 'cmake-3.10.2-Linux-x86_64' displayName: 'Build CoreNEURON and Run Integration Tests with ISPC compiler' -- job: 'osx1014' +- job: 'osx1015' pool: - vmImage: 'macOS-10.14' - displayName: 'MacOS (10.14), AppleClang 10.0' + vmImage: 'macOS-10.15' + displayName: 'MacOS (10.15), AppleClang 11.0' steps: - checkout: self submodules: True - script: | - brew install flex cmake python@3 - brew install bison + brew install flex bison cmake python@3 llvm python3 -m pip install -U pip setuptools python3 -m pip install --user 'Jinja2>=2.9.3' 'PyYAML>=3.13' pytest pytest-cov numpy 'sympy>=1.3' displayName: 'Install Dependencies' + - script: | + cd $HOME + git clone --depth 1 https://github.com/pramodk/llvm-nightly.git + displayName: 'Setup LLVM v13' - script: | export PATH=/usr/local/opt/flex/bin:/usr/local/opt/bison/bin:$PATH; mkdir -p $(Build.Repository.LocalPath)/build cd $(Build.Repository.LocalPath)/build - cmake .. -DPYTHON_EXECUTABLE=$(which python3) -DCMAKE_INSTALL_PREFIX=$HOME/nmodl -DCMAKE_BUILD_TYPE=RelWithDebInfo -DNMODL_ENABLE_PYTHON_BINDINGS=OFF + cmake .. -DPYTHON_EXECUTABLE=$(which python3) -DCMAKE_INSTALL_PREFIX=$HOME/nmodl -DCMAKE_BUILD_TYPE=RelWithDebInfo -DNMODL_ENABLE_PYTHON_BINDINGS=OFF -DLLVM_DIR=$HOME/llvm-nightly/0621/osx/lib/cmake/llvm -DNMODL_ENABLE_LLVM=ON make -j 2 if [ $? 
-ne 0 ] then @@ -171,6 +179,7 @@ jobs: displayName: 'Build Neuron and Run Integration Tests' - job: 'manylinux_wheels' timeoutInMinutes: 45 + condition: eq(1,2) pool: vmImage: 'ubuntu-18.04' strategy: @@ -220,6 +229,7 @@ jobs: - template: ci/upload-wheels.yml - job: 'macos_wheels' timeoutInMinutes: 45 + condition: eq(1,2) pool: vmImage: 'macOS-10.15' strategy: diff --git a/ci/bb5-pr.sh b/ci/bb5-pr.sh index 9f65c3783f..abdce2d867 100755 --- a/ci/bb5-pr.sh +++ b/ci/bb5-pr.sh @@ -7,7 +7,7 @@ git show HEAD source /gpfs/bbp.cscs.ch/apps/hpc/jenkins/config/modules.sh module use /gpfs/bbp.cscs.ch/apps/tools/modules/tcl/linux-rhel7-x86_64/ -module load archive/2020-10 cmake bison flex python-dev doxygen +module load unstable cmake bison flex python-dev doxygen module list function bb5_pr_setup_virtualenv() { @@ -41,7 +41,8 @@ function build_with() { -DPYTHON_EXECUTABLE=$(which python3) \ -DNMODL_FORMATTING:BOOL=ON \ -DClangFormat_EXECUTABLE=$clang_format_exe \ - -DLLVM_DIR=/gpfs/bbp.cscs.ch/apps/hpc/jenkins/merge/deploy/externals/latest/linux-rhel7-x86_64/gcc-9.3.0/llvm-11.0.0-kzl4o5/lib/cmake/llvm + -DNMODL_ENABLE_JIT_EVENT_LISTENERS=ON \ + -DLLVM_DIR=/gpfs/bbp.cscs.ch/apps/hpc/llvm-install/0621/lib/cmake/llvm make -j6 popd } @@ -79,7 +80,7 @@ function bb5_pr_build_intel() { } function bb5_pr_build_pgi() { - build_with pgi + build_with nvhpc } function bb5_pr_test_gcc() { @@ -91,7 +92,7 @@ function bb5_pr_test_intel() { } function bb5_pr_test_pgi() { - test_with pgi + test_with nvhpc } function bb5_pr_build_llvm() { diff --git a/cmake/LLVMHelper.cmake b/cmake/LLVMHelper.cmake new file mode 100644 index 0000000000..780ae29cfa --- /dev/null +++ b/cmake/LLVMHelper.cmake @@ -0,0 +1,69 @@ +# ============================================================================= +# LLVM/Clang needs to be linked with either libc++ or libstdc++ +# ============================================================================= + +find_package(LLVM REQUIRED CONFIG) + +# include LLVM libraries 
+set(NMODL_LLVM_COMPONENTS + analysis + codegen + core + executionengine + instcombine + ipo + mc + native + orcjit + target + transformutils + scalaropts + support) + +if(NMODL_ENABLE_JIT_EVENT_LISTENERS) + list(APPEND NMODL_LLVM_COMPONENTS inteljitevents perfjitevents) +endif() + +llvm_map_components_to_libnames(LLVM_LIBS_TO_LINK ${NMODL_LLVM_COMPONENTS}) + +set(CMAKE_REQUIRED_INCLUDES ${LLVM_INCLUDE_DIRS}) +set(CMAKE_REQUIRED_LIBRARIES ${LLVM_LIBS_TO_LINK}) + +if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND NMODL_ENABLE_LLVM) + include(CheckCXXSourceCompiles) + + # simple code to test LLVM library linking + set(CODE_TO_TEST + " + #include + using namespace llvm; + int main(int argc, char* argv[]) { + std::unique_ptr> Builder; + }") + + # first compile without any flags + check_cxx_source_compiles("${CODE_TO_TEST}" LLVM_LIB_LINK_TEST) + + # if standard compilation fails + if(NOT LLVM_LIB_LINK_TEST) + # try libstdc++ first + set(CMAKE_REQUIRED_FLAGS "-stdlib=libstdc++") + check_cxx_source_compiles("${CODE_TO_TEST}" LLVM_LIBSTDCPP_TEST) + # on failure, try libc++ + if(NOT LLVM_LIBSTDCPP_TEST) + set(CMAKE_REQUIRED_FLAGS "-stdlib=libc++") + check_cxx_source_compiles("${CODE_TO_TEST}" LLVM_LIBCPP_TEST) + endif() + # if either library works then add it to CXX flags + if(LLVM_LIBSTDCPP_TEST OR LLVM_LIBCPP_TEST) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CMAKE_REQUIRED_FLAGS}") + message( + STATUS + "Adding ${CMAKE_REQUIRED_FLAGS} to CMAKE_CXX_FLAGS, required to link with LLVM libraries") + else() + message( + STATUS + "WARNING : -stdlib=libstdcx++ or -stdlib=libc++ didn't work to link with LLVM library") + endif() + endif() +endif() diff --git a/setup.py b/setup.py index ec560c6c1e..27539ab4ce 100644 --- a/setup.py +++ b/setup.py @@ -97,7 +97,7 @@ def _config_exe(exe_name): ] -cmake_args = ["-DPYTHON_EXECUTABLE=" + sys.executable] +cmake_args = ["-DPYTHON_EXECUTABLE=" + sys.executable, "-DNMODL_ENABLE_LLVM=OFF", "-DNMODL_ENABLE_PYTHON_BINDINGS=ON"] if "bdist_wheel" in 
sys.argv: cmake_args.append("-DLINK_AGAINST_PYTHON=FALSE") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7b5e67a66a..e4da0b713c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -28,6 +28,9 @@ target_link_libraries( util lexer ${NMODL_WRAPPER_LIBS}) +if(NMODL_ENABLE_LLVM) + target_link_libraries(nmodl llvm_codegen llvm_benchmark benchmark_data ${LLVM_LIBS_TO_LINK}) +endif() # ============================================================================= # Add dependency with nmodl pytnon module (for consumer projects) diff --git a/src/ast/ast_common.hpp b/src/ast/ast_common.hpp index eb854bb5c5..733fc406f7 100644 --- a/src/ast/ast_common.hpp +++ b/src/ast/ast_common.hpp @@ -43,9 +43,12 @@ namespace ast { * * NMODL support different binary operators and this * type is used to store their value in the AST. + * + * \note `+=` and `-=` are not supported by NMODL but they + * are added for code generation nodes. */ typedef enum { - BOP_ADDITION, ///< \+ + BOP_ADDITION = 0, ///< \+ BOP_SUBTRACTION, ///< -- BOP_MULTIPLICATION, ///< \c * BOP_DIVISION, ///< \/ @@ -58,7 +61,9 @@ typedef enum { BOP_LESS_EQUAL, ///< <= BOP_ASSIGN, ///< = BOP_NOT_EQUAL, ///< != - BOP_EXACT_EQUAL ///< == + BOP_EXACT_EQUAL, ///< == + BOP_ADD_ASSIGN, ///< \+= + BOP_SUB_ASSIGN ///< \-= } BinaryOp; /** @@ -68,7 +73,7 @@ typedef enum { * is used to lookup the corresponding symbol for the operator. 
*/ static const std::string BinaryOpNames[] = - {"+", "-", "*", "/", "^", "&&", "||", ">", "<", ">=", "<=", "=", "!=", "=="}; + {"+", "-", "*", "/", "^", "&&", "||", ">", "<", ">=", "<=", "=", "!=", "==", "+=", "-="}; /// enum type for unary operators typedef enum { UOP_NOT, UOP_NEGATION } UnaryOp; @@ -106,6 +111,20 @@ typedef enum { LTMINUSGT, LTLT, MINUSGT } ReactionOp; /// string representation of ast::ReactionOp static const std::string ReactionOpNames[] = {"<->", "<<", "->"}; +/** + * Get corresponding ast::BinaryOp for given string + * @param op Binary operator in string format + * @return ast::BinaryOp for given string + */ +static inline BinaryOp string_to_binaryop(const std::string& op) { + /// check if binary operator supported otherwise error + auto it = std::find(std::begin(BinaryOpNames), std::end(BinaryOpNames), op); + if (it == std::end(BinaryOpNames)) { + throw std::runtime_error("Error in string_to_binaryop, can't find " + op); + } + int pos = std::distance(std::begin(BinaryOpNames), it); + return static_cast(pos); +} /** @} */ // end of ast_prop } // namespace ast diff --git a/src/codegen/CMakeLists.txt b/src/codegen/CMakeLists.txt index 32ad4e1303..2d31e1b1d6 100644 --- a/src/codegen/CMakeLists.txt +++ b/src/codegen/CMakeLists.txt @@ -35,6 +35,11 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/fast_math.ispc configure_file(${CMAKE_CURRENT_SOURCE_DIR}/fast_math.hpp ${CMAKE_BINARY_DIR}/include/nmodl/fast_math.hpp COPYONLY) +# build llvm visitor if enabled +if(NMODL_ENABLE_LLVM) + add_subdirectory(llvm) +endif() + # ============================================================================= # Install include files # ============================================================================= diff --git a/src/codegen/codegen_c_visitor.cpp b/src/codegen/codegen_c_visitor.cpp index 63de87807f..49ec6436de 100644 --- a/src/codegen/codegen_c_visitor.cpp +++ b/src/codegen/codegen_c_visitor.cpp @@ -348,49 +348,6 @@ bool 
CodegenCVisitor::statement_to_skip(const Statement& node) const { } -bool CodegenCVisitor::net_send_buffer_required() const noexcept { - if (net_receive_required() && !info.artificial_cell) { - if (info.net_event_used || info.net_send_used || info.is_watch_used()) { - return true; - } - } - return false; -} - - -bool CodegenCVisitor::net_receive_buffering_required() const noexcept { - return info.point_process && !info.artificial_cell && info.net_receive_node != nullptr; -} - - -bool CodegenCVisitor::nrn_state_required() const noexcept { - if (info.artificial_cell) { - return false; - } - return info.nrn_state_block != nullptr || info.currents.empty(); -} - - -bool CodegenCVisitor::nrn_cur_required() const noexcept { - return info.breakpoint_node != nullptr && !info.currents.empty(); -} - - -bool CodegenCVisitor::net_receive_exist() const noexcept { - return info.net_receive_node != nullptr; -} - - -bool CodegenCVisitor::breakpoint_exist() const noexcept { - return info.breakpoint_node != nullptr; -} - - -bool CodegenCVisitor::net_receive_required() const noexcept { - return net_receive_exist(); -} - - /** * \details When floating point data type is not default (i.e. double) then we * have to copy old array to new type (for range variables). 
@@ -415,7 +372,7 @@ bool CodegenCVisitor::state_variable(const std::string& name) const { int CodegenCVisitor::position_of_float_var(const std::string& name) const { int index = 0; - for (const auto& var: codegen_float_variables) { + for (const auto& var: info.codegen_float_variables) { if (var->get_name() == name) { return index; } @@ -427,7 +384,7 @@ int CodegenCVisitor::position_of_float_var(const std::string& name) const { int CodegenCVisitor::position_of_int_var(const std::string& name) const { int index = 0; - for (const auto& var: codegen_int_variables) { + for (const auto& var: info.codegen_int_variables) { if (var.symbol->get_name() == name) { return index; } @@ -546,11 +503,11 @@ int CodegenCVisitor::float_variables_size() const { float_size++; } /// for g_unused variable - if (breakpoint_exist()) { + if (info.breakpoint_exist()) { float_size++; } /// for tsave variable - if (net_receive_exist()) { + if (info.net_receive_exist()) { float_size++; } return float_size; @@ -810,186 +767,6 @@ void CodegenCVisitor::update_index_semantics() { } -std::vector CodegenCVisitor::get_float_variables() { - // sort with definition order - auto comparator = [](const SymbolType& first, const SymbolType& second) -> bool { - return first->get_definition_order() < second->get_definition_order(); - }; - - auto assigned = info.assigned_vars; - auto states = info.state_vars; - - // each state variable has corresponding Dstate variable - for (auto& state: states) { - auto name = "D" + state->get_name(); - auto symbol = make_symbol(name); - if (state->is_array()) { - symbol->set_as_array(state->get_length()); - } - symbol->set_definition_order(state->get_definition_order()); - assigned.push_back(symbol); - } - std::sort(assigned.begin(), assigned.end(), comparator); - - auto variables = info.range_parameter_vars; - variables.insert(variables.end(), - info.range_assigned_vars.begin(), - info.range_assigned_vars.end()); - variables.insert(variables.end(), 
info.range_state_vars.begin(), info.range_state_vars.end()); - variables.insert(variables.end(), assigned.begin(), assigned.end()); - - if (info.vectorize) { - variables.push_back(make_symbol(naming::VOLTAGE_UNUSED_VARIABLE)); - } - if (breakpoint_exist()) { - std::string name = info.vectorize ? naming::CONDUCTANCE_UNUSED_VARIABLE - : naming::CONDUCTANCE_VARIABLE; - variables.push_back(make_symbol(name)); - } - if (net_receive_exist()) { - variables.push_back(make_symbol(naming::T_SAVE_VARIABLE)); - } - return variables; -} - - -/** - * IndexVariableInfo has following constructor arguments: - * - symbol - * - is_vdata (false) - * - is_index (false - * - is_integer (false) - * - * Which variables are constant qualified? - * - * - node area is read only - * - read ion variables are read only - * - style_ionname is index / offset - */ -std::vector CodegenCVisitor::get_int_variables() { - std::vector variables; - if (info.point_process) { - variables.emplace_back(make_symbol(naming::NODE_AREA_VARIABLE)); - variables.back().is_constant = true; - /// note that this variable is not printed in neuron implementation - if (info.artificial_cell) { - variables.emplace_back(make_symbol(naming::POINT_PROCESS_VARIABLE), true); - } else { - variables.emplace_back(make_symbol(naming::POINT_PROCESS_VARIABLE), false, false, true); - variables.back().is_constant = true; - } - } - - for (const auto& ion: info.ions) { - bool need_style = false; - std::unordered_map ion_vars; // used to keep track of the variables to - // not have doubles between read/write. 
Same - // name variables are allowed - for (const auto& var: ion.reads) { - const std::string name = "ion_" + var; - variables.emplace_back(make_symbol(name)); - variables.back().is_constant = true; - ion_vars[name] = variables.size() - 1; - } - - /// symbol for di_ion_dv var - std::shared_ptr ion_di_dv_var = nullptr; - - for (const auto& var: ion.writes) { - const std::string name = "ion_" + var; - - const auto ion_vars_it = ion_vars.find(name); - if (ion_vars_it != ion_vars.end()) { - variables[ion_vars_it->second].is_constant = false; - } else { - variables.emplace_back(make_symbol("ion_" + var)); - } - if (ion.is_ionic_current(var)) { - ion_di_dv_var = make_symbol("ion_di" + ion.name + "dv"); - } - if (ion.is_intra_cell_conc(var) || ion.is_extra_cell_conc(var)) { - need_style = true; - } - } - - /// insert after read/write variables but before style ion variable - if (ion_di_dv_var != nullptr) { - variables.emplace_back(ion_di_dv_var); - } - - if (need_style) { - variables.emplace_back(make_symbol("style_" + ion.name), false, true); - variables.back().is_constant = true; - } - } - - for (const auto& var: info.pointer_variables) { - auto name = var->get_name(); - if (var->has_any_property(NmodlType::pointer_var)) { - variables.emplace_back(make_symbol(name)); - } else { - variables.emplace_back(make_symbol(name), true); - } - } - - if (info.diam_used) { - variables.emplace_back(make_symbol(naming::DIAM_VARIABLE)); - } - - if (info.area_used) { - variables.emplace_back(make_symbol(naming::AREA_VARIABLE)); - } - - // for non-artificial cell, when net_receive buffering is enabled - // then tqitem is an offset - if (info.net_send_used) { - if (info.artificial_cell) { - variables.emplace_back(make_symbol(naming::TQITEM_VARIABLE), true); - } else { - variables.emplace_back(make_symbol(naming::TQITEM_VARIABLE), false, false, true); - variables.back().is_constant = true; - } - info.tqitem_index = variables.size() - 1; - } - - /** - * \note Variables for watch statements 
: there is one extra variable - * used in coreneuron compared to actual watch statements for compatibility - * with neuron (which uses one extra Datum variable) - */ - if (!info.watch_statements.empty()) { - for (int i = 0; i < info.watch_statements.size() + 1; i++) { - variables.emplace_back(make_symbol("watch{}"_format(i)), false, false, true); - } - } - return variables; -} - - -/** - * \details When we enable fine level parallelism at channel level, we have do updates - * to ion variables in atomic way. As cpus don't have atomic instructions in - * simd loop, we have to use shadow vectors for every ion variables. Here - * we return list of all such variables. - * - * \todo If conductances are specified, we don't need all below variables - */ -std::vector CodegenCVisitor::get_shadow_variables() { - std::vector variables; - for (const auto& ion: info.ions) { - for (const auto& var: ion.writes) { - variables.push_back({make_symbol(shadow_varname("ion_" + var))}); - if (ion.is_ionic_current(var)) { - variables.push_back({make_symbol(shadow_varname("ion_di" + ion.name + "dv"))}); - } - } - } - variables.push_back({make_symbol("ml_rhs")}); - variables.push_back({make_symbol("ml_d")}); - return variables; -} - - /****************************************************************************************/ /* Routines must be overloaded in backend */ /****************************************************************************************/ @@ -1078,7 +855,7 @@ bool CodegenCVisitor::nrn_cur_reduction_loop_required() { bool CodegenCVisitor::shadow_vector_setup_required() { - return (channel_task_dependency_enabled() && !codegen_shadow_variables.empty()); + return (channel_task_dependency_enabled() && !info.codegen_shadow_variables.empty()); } @@ -1933,8 +1710,8 @@ std::string CodegenCVisitor::process_verbatim_text(std::string text) { std::string CodegenCVisitor::register_mechanism_arguments() const { - auto nrn_cur = nrn_cur_required() ? 
method_name(naming::NRN_CUR_METHOD) : "NULL"; - auto nrn_state = nrn_state_required() ? method_name(naming::NRN_STATE_METHOD) : "NULL"; + auto nrn_cur = info.nrn_cur_required() ? method_name(naming::NRN_CUR_METHOD) : "NULL"; + auto nrn_state = info.nrn_state_required() ? method_name(naming::NRN_STATE_METHOD) : "NULL"; auto nrn_alloc = method_name(naming::NRN_ALLOC_METHOD); auto nrn_init = method_name(naming::NRN_INIT_METHOD); return "mechanism, {}, {}, NULL, {}, {}, first_pointer_var_index()" @@ -2052,7 +1829,7 @@ void CodegenCVisitor::print_num_variable_getter() { void CodegenCVisitor::print_net_receive_arg_size_getter() { - if (!net_receive_exist()) { + if (!info.net_receive_exist()) { return; } printer->add_newline(2); @@ -2245,17 +2022,18 @@ std::string CodegenCVisitor::get_variable_name(const std::string& name, bool use // clang-format on // float variable - auto f = std::find_if(codegen_float_variables.begin(), - codegen_float_variables.end(), + auto f = std::find_if(info.codegen_float_variables.begin(), + info.codegen_float_variables.end(), symbol_comparator); - if (f != codegen_float_variables.end()) { + if (f != info.codegen_float_variables.end()) { return float_variable_name(*f, use_instance); } // integer variable - auto i = - std::find_if(codegen_int_variables.begin(), codegen_int_variables.end(), index_comparator); - if (i != codegen_int_variables.end()) { + auto i = std::find_if(info.codegen_int_variables.begin(), + info.codegen_int_variables.end(), + index_comparator); + if (i != info.codegen_int_variables.end()) { return int_variable_name(*i, varname, use_instance); } @@ -2268,10 +2046,10 @@ std::string CodegenCVisitor::get_variable_name(const std::string& name, bool use } // shadow variable - auto s = std::find_if(codegen_shadow_variables.begin(), - codegen_shadow_variables.end(), + auto s = std::find_if(info.codegen_shadow_variables.begin(), + info.codegen_shadow_variables.end(), symbol_comparator); - if (s != codegen_shadow_variables.end()) { + 
if (s != info.codegen_shadow_variables.end()) { return ion_shadow_variable_name(*s); } @@ -2700,7 +2478,7 @@ void CodegenCVisitor::print_mechanism_register() { if (info.artificial_cell) { printer->add_line("add_nrn_artcell(mech_type, {});"_format(info.tqitem_index)); } - if (net_receive_buffering_required()) { + if (info.net_receive_buffering_required()) { printer->add_line("hoc_register_net_receive_buffering({}, mech_type);"_format( method_name("net_buf_receive"))); } @@ -2801,13 +2579,13 @@ void CodegenCVisitor::print_mechanism_range_var_structure() { printer->add_newline(2); printer->add_line("/** all mechanism instance variables */"); printer->start_block("struct {} "_format(instance_struct())); - for (auto& var: codegen_float_variables) { + for (auto& var: info.codegen_float_variables) { auto name = var->get_name(); auto type = get_range_var_float_type(var); auto qualifier = is_constant_variable(name) ? k_const() : ""; printer->add_line("{}{}* {}{};"_format(qualifier, type, ptr_type_qualifier(), name)); } - for (auto& var: codegen_int_variables) { + for (auto& var: info.codegen_int_variables) { auto name = var.symbol->get_name(); if (var.is_index || var.is_integer) { auto qualifier = var.is_constant ? 
k_const() : ""; @@ -2820,7 +2598,7 @@ void CodegenCVisitor::print_mechanism_range_var_structure() { } } if (channel_task_dependency_enabled()) { - for (auto& var: codegen_shadow_variables) { + for (auto& var: info.codegen_shadow_variables) { auto name = var->get_name(); printer->add_line("{}* {}{};"_format(float_type, ptr_type_qualifier(), name)); } @@ -3029,7 +2807,7 @@ void CodegenCVisitor::print_shadow_vector_setup() { printer->start_block("static inline void setup_shadow_vectors({}) "_format(args)); if (channel_task_dependency_enabled()) { printer->add_line("int nodecount = ml->nodecount;"); - for (auto& var: codegen_shadow_variables) { + for (auto& var: info.codegen_shadow_variables) { auto name = var->get_name(); auto type = default_float_data_type(); auto allocation = "({0}*) mem_alloc(nodecount, sizeof({0}))"_format(type); @@ -3042,7 +2820,7 @@ void CodegenCVisitor::print_shadow_vector_setup() { args = "{}* inst"_format(instance_struct()); printer->start_block("static inline void free_shadow_vectors({}) "_format(args)); if (channel_task_dependency_enabled()) { - for (auto& var: codegen_shadow_variables) { + for (auto& var: info.codegen_shadow_variables) { auto name = var->get_name(); printer->add_line("mem_free(inst->{});"_format(name)); } @@ -3109,7 +2887,7 @@ void CodegenCVisitor::print_instance_variable_setup() { printer->add_line("/** initialize mechanism instance variables */"); printer->start_block("static inline void setup_instance(NrnThread* nt, Memb_list* ml) "); printer->add_line("{0}* inst = ({0}*) mem_alloc(1, sizeof({0}));"_format(instance_struct())); - if (channel_task_dependency_enabled() && !codegen_shadow_variables.empty()) { + if (channel_task_dependency_enabled() && !info.codegen_shadow_variables.empty()) { printer->add_line("setup_shadow_vectors(inst, ml);"); } @@ -3127,7 +2905,7 @@ void CodegenCVisitor::print_instance_variable_setup() { int id = 0; std::vector variables_to_free; - for (auto& var: codegen_float_variables) { + for (auto& 
var: info.codegen_float_variables) { auto name = var->get_name(); auto range_var_type = get_range_var_float_type(var); if (float_type == range_var_type) { @@ -3142,7 +2920,7 @@ void CodegenCVisitor::print_instance_variable_setup() { id += var->get_length(); } - for (auto& var: codegen_int_variables) { + for (auto& var: info.codegen_int_variables) { auto name = var.symbol->get_name(); std::string variable = name; std::string type = ""; @@ -3681,7 +3459,7 @@ void CodegenCVisitor::print_net_receive_loop_end() { void CodegenCVisitor::print_net_receive_buffering(bool need_mech_inst) { - if (!net_receive_required() || info.artificial_cell) { + if (!info.net_receive_required() || info.artificial_cell) { return; } printer->add_newline(2); @@ -3730,7 +3508,7 @@ void CodegenCVisitor::print_net_send_buffering_grow() { } void CodegenCVisitor::print_net_send_buffering() { - if (!net_send_buffer_required()) { + if (!info.net_send_buffer_required()) { return; } @@ -3796,7 +3574,7 @@ void CodegenCVisitor::visit_for_netcon(const ast::ForNetcon& node) { } void CodegenCVisitor::print_net_receive_kernel() { - if (!net_receive_required()) { + if (!info.net_receive_required()) { return; } codegen = true; @@ -3859,7 +3637,7 @@ void CodegenCVisitor::print_net_receive_kernel() { void CodegenCVisitor::print_net_receive() { - if (!net_receive_required()) { + if (!info.net_receive_required()) { return; } codegen = true; @@ -4007,7 +3785,7 @@ void CodegenCVisitor::visit_solution_expression(const SolutionExpression& node) void CodegenCVisitor::print_nrn_state() { - if (!nrn_state_required()) { + if (!info.nrn_state_required()) { return; } codegen = true; @@ -4217,7 +3995,7 @@ void CodegenCVisitor::print_fast_imem_calculation() { } void CodegenCVisitor::print_nrn_cur() { - if (!nrn_cur_required()) { + if (!info.nrn_cur_required()) { return; } @@ -4365,10 +4143,6 @@ void CodegenCVisitor::setup(const Program& node) { logger->warn("CodegenCVisitor : MOD file uses non-thread safe constructs of 
NMODL"); } - codegen_float_variables = get_float_variables(); - codegen_int_variables = get_int_variables(); - codegen_shadow_variables = get_shadow_variables(); - update_index_semantics(); rename_function_arguments(); } diff --git a/src/codegen/codegen_c_visitor.hpp b/src/codegen/codegen_c_visitor.hpp index 87dad2d3ef..64f4477eeb 100644 --- a/src/codegen/codegen_c_visitor.hpp +++ b/src/codegen/codegen_c_visitor.hpp @@ -46,40 +46,6 @@ namespace codegen { * @{ */ -/** - * \enum BlockType - * \brief Helper to represent various block types - * - * Note: do not assign integers to these enums - * - */ -enum BlockType { - /// initial block - Initial, - - /// destructor block - Destructor, - - /// breakpoint block - Equation, - - /// ode_* routines block (not used) - Ode, - - /// derivative block - State, - - /// watch block - Watch, - - /// net_receive block - NetReceive, - - /// fake ending block type for loops on the enums. Keep it at the end - BlockTypeEnd -}; - - /** * \enum MemberType * \brief Helper to represent various variables types @@ -99,57 +65,6 @@ enum class MemberType { thread }; - -/** - * \class IndexVariableInfo - * \brief Helper to represent information about index/int variables - * - */ -struct IndexVariableInfo { - /// symbol for the variable - const std::shared_ptr symbol; - - /// if variable reside in vdata field of NrnThread - /// typically true for bbcore pointer - bool is_vdata = false; - - /// if this is pure index (e.g. style_ion) variables is directly - /// index and shouldn't be printed with data/vdata - bool is_index = false; - - /// if this is an integer (e.g. 
tqitem, point_process) variable which - /// is printed as array accesses - bool is_integer = false; - - /// if the variable is qualified as constant (this is property of IndexVariable) - bool is_constant = false; - - IndexVariableInfo(std::shared_ptr symbol, - bool is_vdata = false, - bool is_index = false, - bool is_integer = false) - : symbol(std::move(symbol)) - , is_vdata(is_vdata) - , is_index(is_index) - , is_integer(is_integer) {} -}; - - -/** - * \class ShadowUseStatement - * \brief Represents ion write statement during code generation - * - * Ion update statement needs use of shadow vectors for certain backends - * as atomics operations are not supported on cpu backend. - * - * \todo If shadow_lhs is empty then we assume shadow statement not required - */ -struct ShadowUseStatement { - std::string lhs; - std::string op; - std::string rhs; -}; - /** @} */ // end of codegen_details @@ -213,11 +128,6 @@ class CodegenCVisitor: public visitor::ConstAstVisitor { */ symtab::SymbolTable* program_symtab = nullptr; - /** - * All float variables for the model - */ - std::vector codegen_float_variables; - /** * All int variables for the model */ @@ -406,26 +316,6 @@ class CodegenCVisitor: public visitor::ConstAstVisitor { } - /** - * Constructs a shadow variable name - * \param name The name of the variable - * \return The name of the variable prefixed with \c shadow_ - */ - std::string shadow_varname(const std::string& name) const { - return "shadow_" + name; - } - - - /** - * Creates a temporary symbol - * \param name The name of the symbol - * \return A symbol based on the given name - */ - SymbolType make_symbol(const std::string& name) const { - return std::make_shared(name, ModToken()); - } - - /** * Checks if the given variable name belongs to a state variable * \param name The variable name @@ -434,36 +324,6 @@ class CodegenCVisitor: public visitor::ConstAstVisitor { bool state_variable(const std::string& name) const; - /** - * Check if net receive/send 
buffering kernels required - */ - bool net_receive_buffering_required() const noexcept; - - - /** - * Check if nrn_state function is required - */ - bool nrn_state_required() const noexcept; - - - /** - * Check if nrn_cur function is required - */ - bool nrn_cur_required() const noexcept; - - - /** - * Check if net_receive function is required - */ - bool net_receive_required() const noexcept; - - - /** - * Check if net_send_buffer is required - */ - bool net_send_buffer_required() const noexcept; - - /** * Check if setup_range_variable function is required * \return @@ -471,18 +331,6 @@ class CodegenCVisitor: public visitor::ConstAstVisitor { bool range_variable_setup_required() const noexcept; - /** - * Check if net_receive node exist - */ - bool net_receive_exist() const noexcept; - - - /** - * Check if breakpoint node exist - */ - bool breakpoint_exist() const noexcept; - - /** * Check if given method is defined in this model * \param name The name of the method to check @@ -648,27 +496,6 @@ class CodegenCVisitor: public visitor::ConstAstVisitor { void update_index_semantics(); - /** - * Determine all \c float variables required during code generation - * \return A \c vector of \c float variables - */ - std::vector get_float_variables(); - - - /** - * Determine all \c int variables required during code generation - * \return A \c vector of \c int variables - */ - std::vector get_int_variables(); - - - /** - * Determine all ion write variables that require shadow vectors during code generation - * \return A \c vector of ion variables - */ - std::vector get_shadow_variables(); - - /** * Print the items in a vector as a list * diff --git a/src/codegen/codegen_helper_visitor.cpp b/src/codegen/codegen_helper_visitor.cpp index 38e5c3c1e0..9c4944a23e 100644 --- a/src/codegen/codegen_helper_visitor.cpp +++ b/src/codegen/codegen_helper_visitor.cpp @@ -22,6 +22,7 @@ using namespace ast; using symtab::syminfo::NmodlType; using symtab::syminfo::Status; + /** * How symbols 
are stored in NEURON? See notes written in markdown file. * @@ -273,6 +274,7 @@ void CodegenHelperVisitor::find_non_range_variables() { // clang-format on } + /** * Find range variables i.e. ones that are belong to per instance allocation * @@ -664,6 +666,9 @@ void CodegenHelperVisitor::visit_program(const ast::Program& node) { find_range_variables(); find_non_range_variables(); find_table_variables(); + info.get_int_variables(); + info.get_shadow_variables(); + info.get_float_variables(); } diff --git a/src/codegen/codegen_helper_visitor.hpp b/src/codegen/codegen_helper_visitor.hpp index 4f32d1cef8..a6fd10a16b 100644 --- a/src/codegen/codegen_helper_visitor.hpp +++ b/src/codegen/codegen_helper_visitor.hpp @@ -75,6 +75,16 @@ class CodegenHelperVisitor: public visitor::ConstAstVisitor { void find_non_range_variables(); void sort_with_mod2c_symbol_order(std::vector& symbols) const; + /** + * Check if breakpoint node exist + */ + bool breakpoint_exist() const noexcept; + + /** + * Check if net_receive node exist + */ + bool net_receive_exist() const noexcept; + public: CodegenHelperVisitor() = default; diff --git a/src/codegen/codegen_info.cpp b/src/codegen/codegen_info.cpp index 8f6bd448f8..26696fbc18 100644 --- a/src/codegen/codegen_info.cpp +++ b/src/codegen/codegen_info.cpp @@ -8,6 +8,7 @@ #include "codegen/codegen_info.hpp" #include "ast/all.hpp" +#include "utils/logger.hpp" #include "visitors/var_usage_visitor.hpp" #include "visitors/visitor_utils.hpp" @@ -15,8 +16,20 @@ namespace nmodl { namespace codegen { +using namespace fmt::literals; +using symtab::syminfo::NmodlType; using visitor::VarUsageVisitor; +SymbolType make_symbol(const std::string& name) { + return std::make_shared(name, ModToken()); +} + + +std::string shadow_varname(const std::string& name) { + return "shadow_" + name; +} + + /// if any ion has write variable bool CodegenInfo::ion_has_write_variable() const { for (const auto& ion: ions) { @@ -131,5 +144,263 @@ bool 
CodegenInfo::is_voltage_used_by_watch_statements() const { return false; } +bool CodegenInfo::state_variable(const std::string& name) const { + // clang-format off + auto result = std::find_if(state_vars.begin(), + state_vars.end(), + [&name](const SymbolType& sym) { + return name == sym->get_name(); + } + ); + // clang-format on + return result != state_vars.end(); +} + +std::pair CodegenInfo::read_ion_variable_name( + const std::string& name) const { + return {name, "ion_" + name}; +} + + +std::pair CodegenInfo::write_ion_variable_name( + const std::string& name) const { + return {"ion_" + name, name}; +} + + +/** + * \details Current variable used in breakpoint block could be local variable. + * In this case, neuron has already renamed the variable name by prepending + * "_l". In our implementation, the variable could have been renamed by + * one of the pass. And hence, we search all local variables and check if + * the variable is renamed. Note that we have to look into the symbol table + * of statement block and not breakpoint. 
+ */ +std::string CodegenInfo::breakpoint_current(std::string current) const { + auto& breakpoint = breakpoint_node; + if (breakpoint == nullptr) { + return current; + } + const auto& symtab = breakpoint->get_statement_block()->get_symbol_table(); + const auto& variables = symtab->get_variables_with_properties(NmodlType::local_var); + for (const auto& var: variables) { + std::string renamed_name = var->get_name(); + std::string original_name = var->get_original_name(); + if (current == original_name) { + current = renamed_name; + break; + } + } + return current; +} + + +bool CodegenInfo::is_an_instance_variable(const std::string& varname) const { + /// check if symbol of given name exist + auto check_symbol = [](const std::string& name, const std::vector& symbols) { + for (auto& symbol: symbols) { + if (symbol->get_name() == name) { + return true; + } + } + return false; + }; + + /// check if variable exist into all possible types + if (check_symbol(varname, assigned_vars) || check_symbol(varname, state_vars) || + check_symbol(varname, range_parameter_vars) || check_symbol(varname, range_assigned_vars) || + check_symbol(varname, range_state_vars)) { + return true; + } + return false; +} + + +/** + * IndexVariableInfo has following constructor arguments: + * - symbol + * - is_vdata (false) + * - is_index (false + * - is_integer (false) + * + * Which variables are constant qualified? 
+ * + * - node area is read only + * - read ion variables are read only + * - style_ionname is index / offset + */ +void CodegenInfo::get_int_variables() { + if (point_process) { + codegen_int_variables.emplace_back(make_symbol(naming::NODE_AREA_VARIABLE)); + codegen_int_variables.back().is_constant = true; + /// note that this variable is not printed in neuron implementation + if (artificial_cell) { + codegen_int_variables.emplace_back(make_symbol(naming::POINT_PROCESS_VARIABLE), true); + } else { + codegen_int_variables.emplace_back(make_symbol(naming::POINT_PROCESS_VARIABLE), + false, + false, + true); + codegen_int_variables.back().is_constant = true; + } + } + + for (const auto& ion: ions) { + bool need_style = false; + std::unordered_map ion_vars; // used to keep track of the variables to + // not have doubles between read/write. Same + // name variables are allowed + for (const auto& var: ion.reads) { + const std::string name = "ion_" + var; + codegen_int_variables.emplace_back(make_symbol(name)); + codegen_int_variables.back().is_constant = true; + ion_vars[name] = codegen_int_variables.size() - 1; + } + + /// symbol for di_ion_dv var + std::shared_ptr ion_di_dv_var = nullptr; + + for (const auto& var: ion.writes) { + const std::string name = "ion_" + var; + + const auto ion_vars_it = ion_vars.find(name); + if (ion_vars_it != ion_vars.end()) { + codegen_int_variables[ion_vars_it->second].is_constant = false; + } else { + codegen_int_variables.emplace_back(make_symbol("ion_" + var)); + } + if (ion.is_ionic_current(var)) { + ion_di_dv_var = make_symbol("ion_di" + ion.name + "dv"); + } + if (ion.is_intra_cell_conc(var) || ion.is_extra_cell_conc(var)) { + need_style = true; + } + } + + /// insert after read/write variables but before style ion variable + if (ion_di_dv_var != nullptr) { + codegen_int_variables.emplace_back(ion_di_dv_var); + } + + if (need_style) { + codegen_int_variables.emplace_back(make_symbol("style_" + ion.name), false, true); + 
codegen_int_variables.back().is_constant = true; + } + } + + for (const auto& var: pointer_variables) { + auto name = var->get_name(); + if (var->has_any_property(NmodlType::pointer_var)) { + codegen_int_variables.emplace_back(make_symbol(name)); + } else { + codegen_int_variables.emplace_back(make_symbol(name), true); + } + } + + if (diam_used) { + codegen_int_variables.emplace_back(make_symbol(naming::DIAM_VARIABLE)); + } + + if (area_used) { + codegen_int_variables.emplace_back(make_symbol(naming::AREA_VARIABLE)); + } + + // for non-artificial cell, when net_receive buffering is enabled + // then tqitem is an offset + if (net_send_used) { + if (artificial_cell) { + codegen_int_variables.emplace_back(make_symbol(naming::TQITEM_VARIABLE), true); + } else { + codegen_int_variables.emplace_back(make_symbol(naming::TQITEM_VARIABLE), + false, + false, + true); + codegen_int_variables.back().is_constant = true; + } + tqitem_index = codegen_int_variables.size() - 1; + } + + /** + * \note Variables for watch statements : there is one extra variable + * used in coreneuron compared to actual watch statements for compatibility + * with neuron (which uses one extra Datum variable) + */ + if (!watch_statements.empty()) { + for (int i = 0; i < watch_statements.size() + 1; i++) { + codegen_int_variables.emplace_back(make_symbol("watch{}"_format(i)), + false, + false, + true); + } + } +} + + +/** + * \details When we enable fine level parallelism at channel level, we have do updates + * to ion variables in atomic way. As cpus don't have atomic instructions in + * simd loop, we have to use shadow vectors for every ion variables. Here + * we return list of all such variables. 
+ * + * \todo If conductances are specified, we don't need all below variables + */ +void CodegenInfo::get_shadow_variables() { + for (const auto& ion: ions) { + for (const auto& var: ion.writes) { + codegen_shadow_variables.push_back({make_symbol(shadow_varname("ion_" + var))}); + if (ion.is_ionic_current(var)) { + codegen_shadow_variables.push_back( + {make_symbol(shadow_varname("ion_di" + ion.name + "dv"))}); + } + } + } + codegen_shadow_variables.push_back({make_symbol("ml_rhs")}); + codegen_shadow_variables.push_back({make_symbol("ml_d")}); +} + + +void CodegenInfo::get_float_variables() { + // sort with definition order + auto comparator = [](const SymbolType& first, const SymbolType& second) -> bool { + return first->get_definition_order() < second->get_definition_order(); + }; + + auto assigned = assigned_vars; + auto states = state_vars; + + // each state variable has corresponding Dstate variable + for (auto& state: states) { + auto name = "D" + state->get_name(); + auto symbol = make_symbol(name); + if (state->is_array()) { + symbol->set_as_array(state->get_length()); + } + symbol->set_definition_order(state->get_definition_order()); + assigned.push_back(symbol); + } + std::sort(assigned.begin(), assigned.end(), comparator); + + codegen_float_variables = range_parameter_vars; + codegen_float_variables.insert(codegen_float_variables.end(), + range_assigned_vars.begin(), + range_assigned_vars.end()); + codegen_float_variables.insert(codegen_float_variables.end(), + range_state_vars.begin(), + range_state_vars.end()); + codegen_float_variables.insert(codegen_float_variables.end(), assigned.begin(), assigned.end()); + + if (vectorize) { + codegen_float_variables.push_back(make_symbol(naming::VOLTAGE_UNUSED_VARIABLE)); + } + if (breakpoint_exist()) { + std::string name = vectorize ? 
naming::CONDUCTANCE_UNUSED_VARIABLE + : naming::CONDUCTANCE_VARIABLE; + codegen_float_variables.push_back(make_symbol(name)); + } + if (net_receive_exist()) { + codegen_float_variables.push_back(make_symbol(naming::T_SAVE_VARIABLE)); + } +} + } // namespace codegen } // namespace nmodl diff --git a/src/codegen/codegen_info.hpp b/src/codegen/codegen_info.hpp index 2df99d7c1c..b0a41583b5 100644 --- a/src/codegen/codegen_info.hpp +++ b/src/codegen/codegen_info.hpp @@ -15,11 +15,62 @@ #include #include "ast/ast.hpp" +#include "codegen/codegen_naming.hpp" #include "symtab/symbol_table.hpp" namespace nmodl { namespace codegen { +using SymbolType = std::shared_ptr; + +/** + * Creates a temporary symbol + * \param name The name of the symbol + * \return A symbol based on the given name + */ +SymbolType make_symbol(const std::string& name); + +/** + * Constructs a shadow variable name + * \param name The name of the variable + * \return The name of the variable prefixed with \c shadow_ + */ +std::string shadow_varname(const std::string& name); + +/** + * \class IndexVariableInfo + * \brief Helper to represent information about index/int variables + * + */ +struct IndexVariableInfo { + /// symbol for the variable + const std::shared_ptr symbol; + + /// if variable reside in vdata field of NrnThread + /// typically true for bbcore pointer + bool is_vdata = false; + + /// if this is pure index (e.g. style_ion) variables is directly + /// index and shouldn't be printed with data/vdata + bool is_index = false; + + /// if this is an integer (e.g. 
tqitem, point_process) variable which + /// is printed as array accesses + bool is_integer = false; + + /// if the variable is qualified as constant (this is property of IndexVariable) + bool is_constant = false; + + IndexVariableInfo(std::shared_ptr symbol, + bool is_vdata = false, + bool is_index = false, + bool is_integer = false) + : symbol(std::move(symbol)) + , is_vdata(is_vdata) + , is_index(is_index) + , is_integer(is_integer) {} +}; + /** * @addtogroup codegen_details * @{ @@ -126,6 +177,59 @@ struct IndexSemantics { , size(size) {} }; +/** + * \enum BlockType + * \brief Helper to represent various block types + * + * Note: do not assign integers to these enums + * + */ +enum BlockType { + /// initial block + Initial, + + /// destructor block + Destructor, + + /// breakpoint block + Equation, + + /// ode_* routines block (not used) + Ode, + + /// derivative block + State, + + /// watch block + Watch, + + /// net_receive block + NetReceive, + + /// fake ending block type for loops on the enums. Keep it at the end + BlockTypeEnd +}; + +/** + * \class ShadowUseStatement + * \brief Represents ion write statement during code generation + * + * Ion update statement needs use of shadow vectors for certain backends + * as atomics operations are not supported on cpu backend. + * + * \todo Currently `nrn_wrote_conc` is also added to shadow update statements + * list as it's corresponding to ion update statement in INITIAL block. This + * needs to be factored out. + * \todo This can be represented as AST node (like ast::CodegenAtomicStatement) + * but currently C backend use this same implementation. So we are using this + * same structure and then converting to ast::CodegenAtomicStatement for LLVM + * visitor. 
+ */ +struct ShadowUseStatement { + std::string lhs; + std::string op; + std::string rhs; +}; /** * \class CodegenInfo @@ -320,6 +424,15 @@ struct CodegenInfo { /// new one used in print_ion_types std::vector use_ion_variables; + /// all int variables for the model + std::vector codegen_int_variables; + + /// all ion variables that could be possibly written + std::vector codegen_shadow_variables; + + /// all float variables for the model + std::vector codegen_float_variables; + /// this is the order in which they appear in derivative block /// this is required while printing them in initlist function std::vector prime_variables_by_order; @@ -398,8 +511,124 @@ struct CodegenInfo { /// true if WatchStatement uses voltage v variable bool is_voltage_used_by_watch_statements() const; + /** + * Check if net_send_buffer is required + */ + bool net_send_buffer_required() const noexcept { + if (net_receive_required() && !artificial_cell) { + if (net_event_used || net_send_used || is_watch_used()) { + return true; + } + } + return false; + } + + /** + * Check if net receive/send buffering kernels required + */ + bool net_receive_buffering_required() const noexcept { + return point_process && !artificial_cell && net_receive_node != nullptr; + } + + /** + * Check if nrn_state function is required + */ + bool nrn_state_required() const noexcept { + if (artificial_cell) { + return false; + } + return nrn_state_block != nullptr || currents.empty(); + } + + /** + * Check if nrn_cur function is required + */ + bool nrn_cur_required() const noexcept { + return breakpoint_node != nullptr && !currents.empty(); + } + + /** + * Check if net_receive node exist + */ + bool net_receive_exist() const noexcept { + return net_receive_node != nullptr; + } + + /** + * Check if breakpoint node exist + */ + bool breakpoint_exist() const noexcept { + return breakpoint_node != nullptr; + } + + + /** + * Check if net_receive function is required + */ + bool net_receive_required() const noexcept { + 
return net_receive_exist(); + } + + /** + * Checks if the given variable name belongs to a state variable + * \param name The variable name + * \return \c true if the variable is a state variable + */ + bool state_variable(const std::string& name) const; + + /** + * Return ion variable name and corresponding ion read variable name + * \param name The ion variable name + * \return The ion read variable name + */ + std::pair read_ion_variable_name(const std::string& name) const; + + /** + * Return ion variable name and corresponding ion write variable name + * \param name The ion variable name + * \return The ion write variable name + */ + std::pair write_ion_variable_name(const std::string& name) const; + + /** + * Determine the variable name for the "current" used in breakpoint block taking into account + * intermediate code transformations. + * \param current The variable name for the current used in the model + * \return The name for the current to be printed in C + */ + std::string breakpoint_current(std::string current) const; + + /** + * Check if variable with given name is an instance variable + * + * Instance varaibles are local to each mechanism instance and + * needs to be accessed with an array index. Such variables are + * assigned, range, parameter+range etc. 
+ * @param varname Name of the variable + * @return True if variable is per mechanism instance + */ + bool is_an_instance_variable(const std::string& varname) const; + /// if we need a call back to wrote_conc in neuron/coreneuron bool require_wrote_conc = false; + + /** + * Determine all \c int variables required during code generation + * \return A \c vector of \c int variables + */ + void get_int_variables(); + + /** + * Determine all ion write variables that require shadow vectors during code generation + * \return A \c vector of ion variables + */ + void get_shadow_variables(); + + /** + * Determine all \c float variables required during code generation + * \return A \c vector of \c float variables + */ + void get_float_variables(); }; /** @} */ // end of codegen_backends diff --git a/src/codegen/codegen_ispc_visitor.cpp b/src/codegen/codegen_ispc_visitor.cpp index b2822f1078..808aad1690 100644 --- a/src/codegen/codegen_ispc_visitor.cpp +++ b/src/codegen/codegen_ispc_visitor.cpp @@ -437,7 +437,7 @@ void CodegenIspcVisitor::print_ion_variable() { /****************************************************************************************/ void CodegenIspcVisitor::print_net_receive_buffering_wrapper() { - if (!net_receive_required() || info.artificial_cell) { + if (!info.net_receive_required() || info.artificial_cell) { return; } printer->add_newline(2); @@ -515,19 +515,19 @@ void CodegenIspcVisitor::print_backend_compute_routine_decl() { "extern \"C\" void {}({});"_format(compute_function, get_parameter_str(params))); } - if (nrn_cur_required() && !emit_fallback[BlockType::Equation]) { + if (info.nrn_cur_required() && !emit_fallback[BlockType::Equation]) { compute_function = compute_method_name(BlockType::Equation); printer->add_line( "extern \"C\" void {}({});"_format(compute_function, get_parameter_str(params))); } - if (nrn_state_required() && !emit_fallback[BlockType::State]) { + if (info.nrn_state_required() && !emit_fallback[BlockType::State]) { 
compute_function = compute_method_name(BlockType::State); printer->add_line( "extern \"C\" void {}({});"_format(compute_function, get_parameter_str(params))); } - if (net_receive_required()) { + if (info.net_receive_required()) { auto net_recv_params = ParamVector(); net_recv_params.emplace_back("", "{}*"_format(instance_struct()), "", "inst"); net_recv_params.emplace_back("", "NrnThread*", "", "nt"); @@ -547,7 +547,7 @@ bool CodegenIspcVisitor::check_incompatibilities() { }; // instance vars - if (check_incompatible_var_name(codegen_float_variables, + if (check_incompatible_var_name(info.codegen_float_variables, get_name_from_symbol_type_vector)) { return true; } @@ -613,11 +613,11 @@ bool CodegenIspcVisitor::check_incompatibilities() { visitor::calls_function(*info.net_receive_node, "net_send"))); emit_fallback[BlockType::Equation] = emit_fallback[BlockType::Equation] || - (nrn_cur_required() && info.breakpoint_node && + (info.nrn_cur_required() && info.breakpoint_node && has_incompatible_nodes(*info.breakpoint_node)); emit_fallback[BlockType::State] = emit_fallback[BlockType::State] || - (nrn_state_required() && info.nrn_state_block && + (info.nrn_state_required() && info.nrn_state_block && has_incompatible_nodes(*info.nrn_state_block)); @@ -674,7 +674,7 @@ void CodegenIspcVisitor::print_block_wrappers_initial_equation_state() { print_wrapper_routine(naming::NRN_INIT_METHOD, BlockType::Initial); } - if (nrn_cur_required()) { + if (info.nrn_cur_required()) { if (emit_fallback[BlockType::Equation]) { logger->warn("Falling back to C backend for emitting breakpoint block"); fallback_codegen.print_nrn_cur(); @@ -683,7 +683,7 @@ void CodegenIspcVisitor::print_block_wrappers_initial_equation_state() { } } - if (nrn_state_required()) { + if (info.nrn_state_required()) { if (emit_fallback[BlockType::State]) { logger->warn("Falling back to C backend for emitting state block"); fallback_codegen.print_nrn_state(); diff --git a/src/codegen/codegen_naming.hpp 
b/src/codegen/codegen_naming.hpp index 6d8875a000..e1cbfaf6f0 100644 --- a/src/codegen/codegen_naming.hpp +++ b/src/codegen/codegen_naming.hpp @@ -80,6 +80,12 @@ static constexpr char VOLTAGE_UNUSED_VARIABLE[] = "v_unused"; /// variable t indicating last execution time of net receive block static constexpr char T_SAVE_VARIABLE[] = "tsave"; +/// global variable celsius +static constexpr char CELSIUS_VARIABLE[] = "celsius"; + +/// global variable second_order +static constexpr char SECOND_ORDER_VARIABLE[] = "secondorder"; + /// shadow rhs variable in neuron thread structure static constexpr char NTHREAD_RHS_SHADOW[] = "_shadow_rhs"; diff --git a/src/codegen/llvm/CMakeLists.txt b/src/codegen/llvm/CMakeLists.txt new file mode 100644 index 0000000000..b927475f15 --- /dev/null +++ b/src/codegen/llvm/CMakeLists.txt @@ -0,0 +1,50 @@ +# ============================================================================= +# Codegen sources +# ============================================================================= +set(LLVM_CODEGEN_SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/codegen_llvm_visitor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/codegen_llvm_visitor.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/codegen_llvm_helper_visitor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/codegen_llvm_helper_visitor.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/llvm_debug_builder.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/llvm_debug_builder.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/llvm_ir_builder.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/llvm_ir_builder.hpp) + +# ============================================================================= +# LLVM codegen library and executable +# ============================================================================= + +include_directories(${LLVM_INCLUDE_DIRS}) +add_library(runner_obj OBJECT ${LLVM_CODEGEN_SOURCE_FILES}) +add_dependencies(runner_obj lexer_obj) +set_property(TARGET runner_obj PROPERTY POSITION_INDEPENDENT_CODE ON) + +add_library(llvm_codegen STATIC $) +add_dependencies(llvm_codegen lexer util visitor) 
+ +if(NOT NMODL_AS_SUBPROJECT) + add_executable(nmodl_llvm_runner main.cpp) + + target_link_libraries( + nmodl_llvm_runner + llvm_benchmark + llvm_codegen + codegen + visitor + symtab + lexer + util + test_util + printer + ${NMODL_WRAPPER_LIBS} + ${LLVM_LIBS_TO_LINK}) +endif() + +# ============================================================================= +# Install executable +# ============================================================================= + +if(NOT NMODL_AS_SUBPROJECT) + install(TARGETS nmodl_llvm_runner DESTINATION ${NMODL_INSTALL_DIR_SUFFIX}bin) +endif() diff --git a/src/codegen/llvm/codegen_llvm_helper_visitor.cpp b/src/codegen/llvm/codegen_llvm_helper_visitor.cpp new file mode 100644 index 0000000000..943bf06969 --- /dev/null +++ b/src/codegen/llvm/codegen_llvm_helper_visitor.cpp @@ -0,0 +1,851 @@ + +/************************************************************************* + * Copyright (C) 2018-2019 Blue Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. 
+ *************************************************************************/ + +#include "codegen_llvm_helper_visitor.hpp" + +#include "ast/all.hpp" +#include "codegen/codegen_helper_visitor.hpp" +#include "symtab/symbol_table.hpp" +#include "utils/logger.hpp" +#include "visitors/rename_visitor.hpp" +#include "visitors/visitor_utils.hpp" + +namespace nmodl { +namespace codegen { + +using namespace fmt::literals; + +using symtab::syminfo::Status; + +/// initialize static member variables +const ast::AstNodeType CodegenLLVMHelperVisitor::INTEGER_TYPE = ast::AstNodeType::INTEGER; +const ast::AstNodeType CodegenLLVMHelperVisitor::FLOAT_TYPE = ast::AstNodeType::DOUBLE; +const std::string CodegenLLVMHelperVisitor::NODECOUNT_VAR = "node_count"; +const std::string CodegenLLVMHelperVisitor::VOLTAGE_VAR = "voltage"; +const std::string CodegenLLVMHelperVisitor::NODE_INDEX_VAR = "node_index"; + +static constexpr const char epilogue_variable_prefix[] = "epilogue_"; + +/// Create asr::Varname node with given a given variable name +static ast::VarName* create_varname(const std::string& varname) { + return new ast::VarName(new ast::Name(new ast::String(varname)), nullptr, nullptr); +} + +/** + * Create initialization expression + * @param code Usually "id = 0" as a string + * @return Expression representing code + * \todo : we can not use `create_statement_as_expression` function because + * NMODL parser is using `ast::Double` type to represent all variables + * including Integer. See #542. + */ +static std::shared_ptr int_initialization_expression( + const std::string& induction_var, + int value = 0) { + // create id = 0 + const auto& id = create_varname(induction_var); + const auto& zero = new ast::Integer(value, nullptr); + return std::make_shared(id, ast::BinaryOperator(ast::BOP_ASSIGN), zero); +} + +/** + * \brief Create variable definition statement + * + * `LOCAL` variables in NMODL don't have type. These variables need + * to be defined with float type. 
Same for index, loop iteration and + * local variables. This helper function function is used to create + * type specific local variable. + * + * @param names Name of the variables to be defined + * @param type Type of the variables + * @return Statement defining variables + */ +static std::shared_ptr create_local_variable_statement( + const std::vector& names, + ast::AstNodeType type) { + /// create variables for the given name + ast::CodegenVarVector variables; + for (const auto& name: names) { + auto varname = new ast::Name(new ast::String(name)); + variables.emplace_back(new ast::CodegenVar(0, varname)); + } + auto var_type = new ast::CodegenVarType(type); + /// construct statement and return it + return std::make_shared(var_type, variables); +} + +/** + * \brief Create expression for a given NMODL code statement + * @param code NMODL code statement + * @return Expression representing given NMODL code + */ +static std::shared_ptr create_statement_as_expression(const std::string& code) { + const auto& statement = visitor::create_statement(code); + auto expr_statement = std::dynamic_pointer_cast(statement); + auto expr = expr_statement->get_expression()->clone(); + return std::make_shared(expr); +} + +/** + * \brief Create expression for given NMODL code expression + * @param code NMODL code expression + * @return Expression representing NMODL code + */ +std::shared_ptr create_expression(const std::string& code) { + /// as provided code is only expression and not a full statement, create + /// a temporary assignment statement + const auto& wrapped_expr = create_statement_as_expression("some_var = " + code); + /// now extract RHS (representing original code) and return it as expression + auto expr = std::dynamic_pointer_cast(wrapped_expr)->get_expression(); + auto rhs = std::dynamic_pointer_cast(expr)->get_rhs(); + return std::make_shared(rhs->clone()); +} + +CodegenFunctionVector CodegenLLVMHelperVisitor::get_codegen_functions(const ast::Program& node) { + 
const_cast(node).accept(*this); + return codegen_functions; +} + +/** + * \brief Add code generation function for FUNCTION or PROCEDURE block + * @param node AST node representing FUNCTION or PROCEDURE + * + * When we have a PROCEDURE or FUNCTION like + * + * \code{.mod} + * FUNCTION sum(x,y) { + * LOCAL res + * res = x + y + * sum = res + * } + * \endcode + * + * this gets typically converted to C/C++ code as: + * + * \code{.cpp} + * double sum(double x, double y) { + * double res; + * double ret_sum; + * res = x + y; + * ret_sum = res; + * return ret_sum; + * \endcode + * + * We perform following transformations so that code generation backends + * will have minimum logic: + * - Add type for the function arguments + * - Define variables and return variable + * - Add return type (int for PROCEDURE and double for FUNCTION) + */ +void CodegenLLVMHelperVisitor::create_function_for_node(ast::Block& node) { + /// name of the function from the node + std::string function_name = node.get_node_name(); + auto name = new ast::Name(new ast::String(function_name)); + + /// return variable name has "ret_" prefix + std::string return_var_name = "ret_{}"_format(function_name); + auto return_var = new ast::Name(new ast::String(return_var_name)); + + /// return type based on node type + ast::CodegenVarType* ret_var_type = nullptr; + if (node.get_node_type() == ast::AstNodeType::FUNCTION_BLOCK) { + ret_var_type = new ast::CodegenVarType(FLOAT_TYPE); + } else { + ret_var_type = new ast::CodegenVarType(INTEGER_TYPE); + } + + /// function body and it's statement, copy original block + auto block = node.get_statement_block()->clone(); + const auto& statements = block->get_statements(); + + /// convert local statement to codegenvar statement + convert_local_statement(*block); + + if (node.get_node_type() == ast::AstNodeType::PROCEDURE_BLOCK) { + block->insert_statement(statements.begin(), + std::make_shared( + int_initialization_expression(return_var_name))); + } + /// insert return 
variable at the start of the block + ast::CodegenVarVector codegen_vars; + codegen_vars.emplace_back(new ast::CodegenVar(0, return_var->clone())); + auto statement = std::make_shared(ret_var_type, codegen_vars); + block->insert_statement(statements.begin(), statement); + + /// add return statement + auto return_statement = new ast::CodegenReturnStatement(return_var); + block->emplace_back_statement(return_statement); + + /// prepare function arguments based original node arguments + ast::CodegenVarWithTypeVector arguments; + for (const auto& param: node.get_parameters()) { + /// create new type and name for creating new ast node + auto type = new ast::CodegenVarType(FLOAT_TYPE); + auto var = param->get_name()->clone(); + arguments.emplace_back(new ast::CodegenVarWithType(type, /*is_pointer=*/0, var)); + } + + /// return type of the function is same as return variable type + ast::CodegenVarType* fun_ret_type = ret_var_type->clone(); + + /// we have all information for code generation function, create a new node + /// which will be inserted later into AST + auto function = std::make_shared(fun_ret_type, name, arguments, block); + if (node.get_token()) { + function->set_token(*node.get_token()->clone()); + } + codegen_functions.push_back(function); +} + +/** + * \note : Order of variables is not important but we assume all pointers + * are added first and then scalar variables like t, dt, second_order etc. + * This order is assumed when we allocate data for integration testing + * and benchmarking purpose. See CodegenDataHelper::create_data(). 
+ */ +std::shared_ptr CodegenLLVMHelperVisitor::create_instance_struct() { + ast::CodegenVarWithTypeVector codegen_vars; + + auto add_var_with_type = + [&](const std::string& name, const ast::AstNodeType type, int is_pointer) { + auto var_name = new ast::Name(new ast::String(name)); + auto var_type = new ast::CodegenVarType(type); + auto codegen_var = new ast::CodegenVarWithType(var_type, is_pointer, var_name); + codegen_vars.emplace_back(codegen_var); + }; + + /// float variables are standard pointers to float vectors + for (const auto& float_var: info.codegen_float_variables) { + add_var_with_type(float_var->get_name(), FLOAT_TYPE, /*is_pointer=*/1); + } + + /// int variables are pointers to indexes for other vectors + for (const auto& int_var: info.codegen_int_variables) { + add_var_with_type(int_var.symbol->get_name(), FLOAT_TYPE, /*is_pointer=*/1); + } + + // for integer variables, there should be index + for (const auto& int_var: info.codegen_int_variables) { + std::string var_name = int_var.symbol->get_name() + "_index"; + add_var_with_type(var_name, INTEGER_TYPE, /*is_pointer=*/1); + } + + // add voltage and node index + add_var_with_type(VOLTAGE_VAR, FLOAT_TYPE, /*is_pointer=*/1); + add_var_with_type(NODE_INDEX_VAR, INTEGER_TYPE, /*is_pointer=*/1); + + // add dt, t, celsius + add_var_with_type(naming::NTHREAD_T_VARIABLE, FLOAT_TYPE, /*is_pointer=*/0); + add_var_with_type(naming::NTHREAD_DT_VARIABLE, FLOAT_TYPE, /*is_pointer=*/0); + add_var_with_type(naming::CELSIUS_VARIABLE, FLOAT_TYPE, /*is_pointer=*/0); + add_var_with_type(naming::SECOND_ORDER_VARIABLE, INTEGER_TYPE, /*is_pointer=*/0); + add_var_with_type(NODECOUNT_VAR, INTEGER_TYPE, /*is_pointer=*/0); + + return std::make_shared(codegen_vars); +} + +static void append_statements_from_block(ast::StatementVector& statements, + const std::shared_ptr& block) { + const auto& block_statements = block->get_statements(); + for (const auto& statement: block_statements) { + const auto& expression_statement = 
std::dynamic_pointer_cast( + statement); + if (!expression_statement || !expression_statement->get_expression()->is_solve_block()) + statements.push_back(statement); + } +} + +static std::shared_ptr create_atomic_statement( + std::string& ion_varname, + std::string& index_varname, + std::string& op_str, + std::string& rhs_str) { + // create lhs expression + auto varname = new ast::Name(new ast::String(ion_varname)); + auto index = new ast::Name(new ast::String(index_varname)); + auto lhs = std::make_shared(new ast::IndexedName(varname, index), + /*at=*/nullptr, + /*index=*/nullptr); + + auto op = ast::BinaryOperator(ast::string_to_binaryop(op_str)); + auto rhs = create_expression(rhs_str); + return std::make_shared(lhs, op, rhs); +} + +/** + * For a given block type, add read ion statements + * + * Depending upon the block type, we have to update read ion variables + * during code generation. Depending on block/procedure being printed, + * this method adds necessary read ion variable statements and also + * corresponding index calculation statements. Note that index statements + * are added separately at the beginning for just readability purpose. + * + * @param type The type of code block being generated + * @param int_variables Index variables to be created + * @param double_variables Floating point variables to be created + * @param index_statements Statements for loading indexes (typically for ions) + * @param body_statements main compute/update statements + * + * \todo After looking into mod2c and neuron implementation, it seems like + * Ode block type is not used. Need to look into implementation details. + * + * \todo Ion copy optimization is not implemented yet. This is currently + * implemented in C backend using `ion_read_statements_optimized()`. 
+ */ +void CodegenLLVMHelperVisitor::ion_read_statements(BlockType type, + std::vector& int_variables, + std::vector& double_variables, + ast::StatementVector& index_statements, + ast::StatementVector& body_statements) { + /// create read ion and corresponding index statements + auto create_read_statements = [&](std::pair variable_names) { + // variable in current mechanism instance + std::string& varname = variable_names.first; + // ion variable to be read + std::string& ion_varname = variable_names.second; + // index for reading ion variable + std::string index_varname = "{}_id"_format(varname); + // first load the index + std::string index_statement = "{} = {}_index[id]"_format(index_varname, ion_varname); + // now assign the value + std::string read_statement = "{} = {}[{}]"_format(varname, ion_varname, index_varname); + // push index definition, index statement and actual read statement + int_variables.push_back(index_varname); + index_statements.push_back(visitor::create_statement(index_statement)); + body_statements.push_back(visitor::create_statement(read_statement)); + }; + + /// iterate over all ions and create statements for given block type + for (const auto& ion: info.ions) { + const std::string& name = ion.name; + for (const auto& var: ion.reads) { + if (type == BlockType::Ode && ion.is_ionic_conc(var) && info.state_variable(var)) { + continue; + } + auto variable_names = info.read_ion_variable_name(var); + create_read_statements(variable_names); + } + for (const auto& var: ion.writes) { + if (type == BlockType::Ode && ion.is_ionic_conc(var) && info.state_variable(var)) { + continue; + } + if (ion.is_ionic_conc(var)) { + auto variable_names = info.read_ion_variable_name(var); + create_read_statements(variable_names); + } + } + } +} + +/** + * For a given block type, add write ion statements + * + * Depending upon the block type, we have to update write ion variables + * during code generation. 
Depending on block/procedure being printed, + * this method adds necessary write ion variable statements and also + * corresponding index calculation statements. Note that index statements + * are added separately at the beginning for just readability purpose. + * + * @param type The type of code block being generated + * @param int_variables Index variables to be created + * @param double_variables Floating point variables to be created + * @param index_statements Statements for loading indexes (typically for ions) + * @param body_statements main compute/update statements + * + * \todo If intra or extra cellular ionic concentration is written + * then it requires call to `nrn_wrote_conc`. In C backend this is + * implemented in `ion_write_statements()` itself but this is not + * handled yet. + */ +void CodegenLLVMHelperVisitor::ion_write_statements(BlockType type, + std::vector& int_variables, + std::vector& double_variables, + ast::StatementVector& index_statements, + ast::StatementVector& body_statements) { + /// create write ion and corresponding index statements + auto create_write_statements = [&](std::string ion_varname, std::string op, std::string rhs) { + // index for writing ion variable + std::string index_varname = "{}_id"_format(ion_varname); + // load index + std::string index_statement = "{} = {}_index[id]"_format(index_varname, ion_varname); + // push index definition, index statement and actual write statement + int_variables.push_back(index_varname); + index_statements.push_back(visitor::create_statement(index_statement)); + // pass ion variable to write and its index + body_statements.push_back(create_atomic_statement(ion_varname, index_varname, op, rhs)); + }; + + /// iterate over all ions and create write ion statements for given block type + for (const auto& ion: info.ions) { + std::string concentration; + std::string name = ion.name; + for (const auto& var: ion.writes) { + auto variable_names = info.write_ion_variable_name(var); + /// ionic 
currents are accumulated + if (ion.is_ionic_current(var)) { + if (type == BlockType::Equation) { + std::string current = info.breakpoint_current(var); + std::string lhs = variable_names.first; + std::string op = "+="; + std::string rhs = current; + // for synapse type + if (info.point_process) { + auto area = codegen::naming::NODE_AREA_VARIABLE; + rhs += "*(1.e2/{})"_format(area); + } + create_write_statements(lhs, op, rhs); + } + } else { + if (!ion.is_rev_potential(var)) { + concentration = var; + } + std::string lhs = variable_names.first; + std::string op = "="; + std::string rhs = variable_names.second; + create_write_statements(lhs, op, rhs); + } + } + + /// still need to handle, need to define easy to use API + if (type == BlockType::Initial && !concentration.empty()) { + int index = 0; + if (ion.is_intra_cell_conc(concentration)) { + index = 1; + } else if (ion.is_extra_cell_conc(concentration)) { + index = 2; + } else { + /// \todo Unhandled case also in neuron implementation + throw std::logic_error("codegen error for {} ion"_format(ion.name)); + } + std::string ion_type_name = "{}_type"_format(ion.name); + std::string lhs = "int {}"_format(ion_type_name); + std::string op = "="; + std::string rhs = ion_type_name; + create_write_statements(lhs, op, rhs); + logger->error("conc_write_statement() call is required but it's not supported"); + } + } +} + +/** + * Convert variables in given node to instance variables + * + * For code generation, variables of type range, assigned, state or parameter+range + * needs to be converted to instance variable i.e. they need to be accessed with + * loop index variable. For example, `h` variables needs to be converted to `h[id]`. 
+ * + * @param node Ast node under which variables to be converted to instance type + */ +void CodegenLLVMHelperVisitor::convert_to_instance_variable(ast::Node& node, + std::string& index_var) { + /// collect all variables in the node of type ast::VarName + auto variables = collect_nodes(node, {ast::AstNodeType::VAR_NAME}); + for (const auto& v: variables) { + auto variable = std::dynamic_pointer_cast(v); + auto variable_name = variable->get_node_name(); + + /// all instance variables defined in the mod file should be converted to + /// indexed variables based on the loop iteration variable + if (info.is_an_instance_variable(variable_name)) { + auto name = variable->get_name()->clone(); + auto index = new ast::Name(new ast::String(index_var)); + auto indexed_name = std::make_shared(name, index); + variable->set_name(indexed_name); + } + + /// instance_var_helper check of instance variables from mod file as well + /// as extra variables like ion index variables added for code generation + if (instance_var_helper.is_an_instance_variable(variable_name)) { + auto name = new ast::Name(new ast::String(MECH_INSTANCE_VAR)); + auto var = std::make_shared(name, variable->clone()); + variable->set_name(var); + } + } +} + +/** + * \brief Visit StatementBlock and convert Local statement for code generation + * @param node AST node representing Statement block + * + * Statement blocks can have LOCAL statement and if it exist it's typically + * first statement in the vector. We have to remove LOCAL statement and convert + * it to CodegenVarListStatement that will represent all variables as double. 
+ */ +void CodegenLLVMHelperVisitor::convert_local_statement(ast::StatementBlock& node) { + /// collect all local statement block + const auto& statements = collect_nodes(node, {ast::AstNodeType::LOCAL_LIST_STATEMENT}); + + /// iterate over all statements and replace each with codegen variable + for (const auto& statement: statements) { + const auto& local_statement = std::dynamic_pointer_cast(statement); + + /// create codegen variables from local variables + /// clone variable to make new independent statement + ast::CodegenVarVector variables; + for (const auto& var: local_statement->get_variables()) { + variables.emplace_back(new ast::CodegenVar(0, var->get_name()->clone())); + } + + /// remove local list statement now + std::unordered_set to_delete({local_statement.get()}); + /// local list statement is enclosed in statement block + const auto& parent_node = dynamic_cast(local_statement->get_parent()); + parent_node->erase_statement(to_delete); + + /// create new codegen variable statement and insert at the beginning of the block + auto type = new ast::CodegenVarType(FLOAT_TYPE); + auto new_statement = std::make_shared(type, variables); + const auto& statements = parent_node->get_statements(); + parent_node->insert_statement(statements.begin(), new_statement); + } +} + +/** + * \brief Visit StatementBlock and rename all LOCAL variables + * @param node AST node representing Statement block + * + * Statement block in remainder loop will have same LOCAL variables from + * main loop. In order to avoid conflict during lookup, rename each local + * variable by appending unique number. The number used as suffix is just + * a counter used for Statement block. 
+ */ +void CodegenLLVMHelperVisitor::rename_local_variables(ast::StatementBlock& node) { + /// local block counter just to append unique number + static int local_block_counter = 1; + + /// collect all local statement block + const auto& statements = collect_nodes(node, {ast::AstNodeType::LOCAL_LIST_STATEMENT}); + + /// iterate over each statement and rename all variables + for (const auto& statement: statements) { + const auto& local_statement = std::dynamic_pointer_cast(statement); + + /// rename local variable in entire statement block + for (auto& var: local_statement->get_variables()) { + std::string old_name = var->get_node_name(); + std::string new_name = "{}_{}"_format(old_name, local_block_counter); + visitor::RenameVisitor(old_name, new_name).visit_statement_block(node); + } + } + + /// make it unique for next statement block + local_block_counter++; +} + + +void CodegenLLVMHelperVisitor::visit_procedure_block(ast::ProcedureBlock& node) { + node.visit_children(*this); + create_function_for_node(node); +} + +void CodegenLLVMHelperVisitor::visit_function_block(ast::FunctionBlock& node) { + node.visit_children(*this); + create_function_for_node(node); +} + +/** + * Create loop increment expression + * \todo : llvm.vscale.i32 is currently hardcoded. This can be done in a more elegant way. 
+ */ +static std::shared_ptr loop_increment_expression(const std::string& induction_var, + int vector_width, + bool scalable) { + const auto& id = create_varname(induction_var); + if (scalable) { + // For scalable vectorized code generation, the increment is + // id = id + vscale * vector_width + const auto& call_for_factor = + new ast::FunctionCall(new ast::Name(new ast::String("llvm.vscale.i32")), {}); + const auto& min_vector_width = new ast::Integer(vector_width, /*macro=*/nullptr); + const auto& actual_width = new ast::BinaryExpression( + call_for_factor, ast::BinaryOperator(ast::BOP_MULTIPLICATION), min_vector_width); + const auto& inc_expr = + new ast::BinaryExpression(id, ast::BinaryOperator(ast::BOP_ADDITION), actual_width); + return std::make_shared(id->clone(), + ast::BinaryOperator(ast::BOP_ASSIGN), + inc_expr); + } + + // Otherwise, the increment is + // id = id + vector_width + const auto& increment = new ast::Integer(vector_width, /*macro=*/nullptr); + const auto& inc_expr = + new ast::BinaryExpression(id, ast::BinaryOperator(ast::BOP_ADDITION), increment); + return std::make_shared(id->clone(), + ast::BinaryOperator(ast::BOP_ASSIGN), + inc_expr); +} + +/** + * Create loop count comparison expression + * + * Serial loop: `id < node_count` + * Fixed vector width loop : `id < node_count - (vector_width - 1)` + * Scalable vector width loop : `id < node_count - (vscale * vector_width - 1)` + * + * \todo : llvm.vscale.i32 is currently hardcoded. This can be done in a more elegant way. 
+ */ +static std::shared_ptr loop_count_expression(const std::string& induction_var, + const std::string& node_count, + int vector_width, + bool scalable) { + const auto& id = create_varname(induction_var); + const auto& mech_node_count = create_varname(node_count); + + // For non-vectorised loop, the condition is id < mech->node_count + if (vector_width == 1) { + return std::make_shared(id->clone(), + ast::BinaryOperator(ast::BOP_LESS), + mech_node_count); + } + + // For fixed vector width, the condition is id < mech->node_count - vector_width + 1 + if (!scalable) { + const auto& remainder = new ast::Integer(vector_width - 1, /*macro=*/nullptr); + const auto& count = new ast::BinaryExpression(mech_node_count, + ast::BinaryOperator(ast::BOP_SUBTRACTION), + remainder); + return std::make_shared(id->clone(), + ast::BinaryOperator(ast::BOP_LESS), + count); + } + + // For scalable vector width, the condition is id < mech->node_count - vscale * vector_width + 1 + const auto& call_for_factor = + new ast::FunctionCall(new ast::Name(new ast::String("llvm.vscale.i32")), {}); + const auto& min_vector_width = new ast::Integer(vector_width, /*macro=*/nullptr); + const auto& actual_width = new ast::BinaryExpression( + call_for_factor, ast::BinaryOperator(ast::BOP_MULTIPLICATION), min_vector_width); + const auto& one = new ast::Integer(1, /*macro=*/nullptr); + const auto& remainder = + new ast::BinaryExpression(actual_width, ast::BinaryOperator(ast::BOP_SUBTRACTION), one); + const auto& count = new ast::BinaryExpression(mech_node_count, + ast::BinaryOperator(ast::BOP_SUBTRACTION), + remainder); + return std::make_shared(id->clone(), + ast::BinaryOperator(ast::BOP_LESS), + count); +} + +/** + * \brief Convert ast::NrnStateBlock to corresponding code generation function nrn_state + * @param node AST node representing ast::NrnStateBlock + * + * Solver passes converts DERIVATIVE block from MOD into ast::NrnStateBlock node + * that represent `nrn_state` function in the generated CPP 
code. To help this + * code generation, we perform various transformation on ast::NrnStateBlock and + * create new code generation function. + */ +void CodegenLLVMHelperVisitor::visit_nrn_state_block(ast::NrnStateBlock& node) { + /// statements for new function to be generated + ast::StatementVector function_statements; + + /// create variable definition for loop index and insert at the beginning + std::string loop_index_var = "id"; + std::vector induction_variables{"id"}; + function_statements.push_back( + create_local_variable_statement(induction_variables, INTEGER_TYPE)); + + /// create vectors of local variables that would be used in compute part + std::vector int_variables{"node_id"}; + std::vector double_variables{"v"}; + + /// create now main compute part : for loop over channel instances + + /// loop body : initialization + solve blocks + ast::StatementVector loop_def_statements; + ast::StatementVector loop_index_statements; + ast::StatementVector loop_body_statements; + { + /// access node index and corresponding voltage + loop_index_statements.push_back( + visitor::create_statement("node_id = node_index[{}]"_format(INDUCTION_VAR))); + loop_body_statements.push_back( + visitor::create_statement("v = {}[node_id]"_format(VOLTAGE_VAR))); + + /// read ion variables + ion_read_statements(BlockType::State, + int_variables, + double_variables, + loop_index_statements, + loop_body_statements); + + /// main compute node : extract solution expressions from the derivative block + const auto& solutions = collect_nodes(node, {ast::AstNodeType::SOLUTION_EXPRESSION}); + for (const auto& statement: solutions) { + const auto& solution = std::dynamic_pointer_cast(statement); + const auto& block = std::dynamic_pointer_cast( + solution->get_node_to_solve()); + append_statements_from_block(loop_body_statements, block); + } + + /// add breakpoint block if no current + if (info.currents.empty() && info.breakpoint_node != nullptr) { + auto block = 
info.breakpoint_node->get_statement_block(); + append_statements_from_block(loop_body_statements, block); + } + + /// write ion statements + ion_write_statements(BlockType::State, + int_variables, + double_variables, + loop_index_statements, + loop_body_statements); + + // \todo handle process_shadow_update_statement and wrote_conc_call yet + } + + ast::StatementVector loop_body; + loop_body.insert(loop_body.end(), loop_def_statements.begin(), loop_def_statements.end()); + loop_body.insert(loop_body.end(), loop_index_statements.begin(), loop_index_statements.end()); + loop_body.insert(loop_body.end(), loop_body_statements.begin(), loop_body_statements.end()); + + /// now construct a new code block which will become the body of the loop + auto loop_block = std::make_shared(loop_body); + + /// declare main FOR loop local variables + function_statements.push_back(create_local_variable_statement(int_variables, INTEGER_TYPE)); + function_statements.push_back(create_local_variable_statement(double_variables, FLOAT_TYPE)); + + /// main loop possibly vectorized on vector_width + { + /// loop constructs : initialization, condition and increment + const auto& initialization = int_initialization_expression(INDUCTION_VAR); + const auto& condition = + loop_count_expression(INDUCTION_VAR, NODECOUNT_VAR, vector_width, scalable); + const auto& increment = loop_increment_expression(INDUCTION_VAR, vector_width, scalable); + + /// clone it + auto local_loop_block = std::shared_ptr(loop_block->clone()); + + /// convert local statement to codegenvar statement + convert_local_statement(*local_loop_block); + + auto for_loop_statement_main = std::make_shared(initialization, + condition, + increment, + local_loop_block); + + /// convert all variables inside loop body to instance variables + convert_to_instance_variable(*for_loop_statement_main, loop_index_var); + + /// loop itself becomes one of the statement in the function + function_statements.push_back(for_loop_statement_main); + } + + 
/// vectors containing renamed FOR loop local variables + std::vector renamed_int_variables; + std::vector renamed_double_variables; + + /// remainder loop possibly vectorized on vector_width + if (vector_width > 1) { + /// loop constructs : initialization, condition and increment + const auto& condition = loop_count_expression(INDUCTION_VAR, + NODECOUNT_VAR, + /*vector_width=*/1, + /*scalable=*/false); + const auto& increment = + loop_increment_expression(INDUCTION_VAR, /*vector_width=*/1, /*scalable=*/false); + + /// rename local variables to avoid conflict with main loop + rename_local_variables(*loop_block); + + /// convert local statement to codegenvar statement + convert_local_statement(*loop_block); + + auto for_loop_statement_remainder = + std::make_shared(nullptr, condition, increment, loop_block); + + const auto& loop_statements = for_loop_statement_remainder->get_statement_block(); + // \todo: Change RenameVisitor to take a vector of names to which it would append a single + // prefix. 
+ for (const auto& name: int_variables) { + std::string new_name = epilogue_variable_prefix + name; + renamed_int_variables.push_back(new_name); + visitor::RenameVisitor v(name, new_name); + loop_statements->accept(v); + } + for (const auto& name: double_variables) { + std::string new_name = epilogue_variable_prefix + name; + renamed_double_variables.push_back(new_name); + visitor::RenameVisitor v(name, epilogue_variable_prefix + name); + loop_statements->accept(v); + } + + /// declare remainder FOR loop local variables + function_statements.push_back( + create_local_variable_statement(renamed_int_variables, INTEGER_TYPE)); + function_statements.push_back( + create_local_variable_statement(renamed_double_variables, FLOAT_TYPE)); + + /// convert all variables inside loop body to instance variables + convert_to_instance_variable(*for_loop_statement_remainder, loop_index_var); + + /// loop itself becomes one of the statement in the function + function_statements.push_back(for_loop_statement_remainder); + } + + /// new block for the function + auto function_block = new ast::StatementBlock(function_statements); + + /// name of the function and it's return type + std::string function_name = "nrn_state_" + stringutils::tolower(info.mod_suffix); + auto name = new ast::Name(new ast::String(function_name)); + auto return_type = new ast::CodegenVarType(ast::AstNodeType::VOID); + + /// \todo : currently there are no arguments + ast::CodegenVarWithTypeVector code_arguments; + + auto instance_var_type = new ast::CodegenVarType(ast::AstNodeType::INSTANCE_STRUCT); + auto instance_var_name = new ast::Name(new ast::String(MECH_INSTANCE_VAR)); + auto instance_var = new ast::CodegenVarWithType(instance_var_type, 1, instance_var_name); + code_arguments.emplace_back(instance_var); + + /// finally, create new function + auto function = + std::make_shared(return_type, name, code_arguments, function_block); + codegen_functions.push_back(function); + + std::cout << nmodl::to_nmodl(function) 
<< std::endl; +} + +void CodegenLLVMHelperVisitor::remove_inlined_nodes(ast::Program& node) { + auto program_symtab = node.get_model_symbol_table(); + const auto& func_proc_nodes = + collect_nodes(node, {ast::AstNodeType::FUNCTION_BLOCK, ast::AstNodeType::PROCEDURE_BLOCK}); + std::unordered_set nodes_to_erase; + for (const auto& ast_node: func_proc_nodes) { + if (program_symtab->lookup(ast_node->get_node_name()) + .get() + ->has_all_status(Status::inlined)) { + nodes_to_erase.insert(static_cast(ast_node.get())); + } + } + node.erase_node(nodes_to_erase); +} + +void CodegenLLVMHelperVisitor::visit_program(ast::Program& node) { + /// run codegen helper visitor to collect information + CodegenHelperVisitor v; + info = v.analyze(node); + + instance_var_helper.instance = create_instance_struct(); + node.emplace_back_node(instance_var_helper.instance); + + logger->info("Running CodegenLLVMHelperVisitor"); + remove_inlined_nodes(node); + node.visit_children(*this); + for (auto& fun: codegen_functions) { + node.emplace_back_node(fun); + } +} + + +} // namespace codegen +} // namespace nmodl diff --git a/src/codegen/llvm/codegen_llvm_helper_visitor.hpp b/src/codegen/llvm/codegen_llvm_helper_visitor.hpp new file mode 100644 index 0000000000..bcca84ae38 --- /dev/null +++ b/src/codegen/llvm/codegen_llvm_helper_visitor.hpp @@ -0,0 +1,185 @@ +/************************************************************************* + * Copyright (C) 2018-2019 Blue Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. 
+ *************************************************************************/ + +#pragma once + +/** + * \file + * \brief \copybrief nmodl::codegen::CodegenLLVMHelperVisitor + */ + +#include + +#include "ast/instance_struct.hpp" +#include "codegen/codegen_info.hpp" +#include "symtab/symbol_table.hpp" +#include "visitors/ast_visitor.hpp" + +namespace nmodl { +namespace codegen { + +using namespace fmt::literals; +typedef std::vector> CodegenFunctionVector; + +/** + * @addtogroup llvm_codegen_details + * @{ + */ + +/** + * \class InstanceVarHelper + * \brief Helper to query instance variables information + * + * For LLVM IR generation we need to know the variable, it's type and + * location in the instance structure. This helper provides convenient + * functions to query this information. + */ +struct InstanceVarHelper { + /// pointer to instance node in the AST + std::shared_ptr instance; + + /// find variable with given name and return the iterator + ast::CodegenVarWithTypeVector::const_iterator find_variable( + const ast::CodegenVarWithTypeVector& vars, + const std::string& name) { + return find_if(vars.begin(), + vars.end(), + [&](const std::shared_ptr& v) { + return v->get_node_name() == name; + }); + } + + /// check if given variable is instance variable + bool is_an_instance_variable(const std::string& name) { + const auto& vars = instance->get_codegen_vars(); + return find_variable(vars, name) != vars.end(); + } + + /// return codegen variable with a given name + const std::shared_ptr& get_variable(const std::string& name) { + const auto& vars = instance->get_codegen_vars(); + auto it = find_variable(vars, name); + if (it == vars.end()) { + throw std::runtime_error("Can not find variable with name {}"_format(name)); + } + return *it; + } + + /// return position of the variable in the instance structure + int get_variable_index(const std::string& name) { + const auto& vars = instance->get_codegen_vars(); + auto it = find_variable(vars, name); + if (it == 
vars.end()) { + throw std::runtime_error("Can not find codegen variable with name {}"_format(name)); + } + return (it - vars.begin()); + } +}; + + +/** + * \class CodegenLLVMHelperVisitor + * \brief Helper visitor for AST information to help code generation backends + * + * Code generation backends convert NMODL AST to C++ code. But during this + * C++ code generation, various transformations happens and final code generated + * is quite different / large than actual kernel represented in MOD file ro + * NMODL AST. + * + * Currently, these transformations are embedded into code generation backends + * like ast::CodegenCVisitor. If we have to generate code for new simulator, there + * will be duplication of these transformations. Also, for completely new + * backends like NEURON simulator or SIMD library, we will have code duplication. + * + * In order to avoid this, we perform maximum transformations in this visitor. + * Currently we focus on transformations that will help LLVM backend but later + * these will be common across all backends. + */ +class CodegenLLVMHelperVisitor: public visitor::AstVisitor { + /// explicit vectorisation width + int vector_width; + + /// target scalable ISAs + bool scalable; + + /// newly generated code generation specific functions + CodegenFunctionVector codegen_functions; + + /// ast information for code generation + codegen::CodegenInfo info; + + /// mechanism data helper + InstanceVarHelper instance_var_helper; + + /// name of the mechanism instance parameter + const std::string MECH_INSTANCE_VAR = "mech"; + + /// name of induction variable used in the kernel. 
+ const std::string INDUCTION_VAR = "id"; + + /// create new function for FUNCTION or PROCEDURE block + void create_function_for_node(ast::Block& node); + + /// create new InstanceStruct + std::shared_ptr create_instance_struct(); + + public: + /// default integer and float node type + static const ast::AstNodeType INTEGER_TYPE; + static const ast::AstNodeType FLOAT_TYPE; + + // node count, voltage and node index variables + static const std::string NODECOUNT_VAR; + static const std::string VOLTAGE_VAR; + static const std::string NODE_INDEX_VAR; + + CodegenLLVMHelperVisitor(int vector_width, + bool scalable = false) + : vector_width(vector_width) + , scalable(scalable) {} + + const InstanceVarHelper& get_instance_var_helper() { + return instance_var_helper; + } + + std::string get_kernel_id() { + return INDUCTION_VAR; + } + + /// run visitor and return code generation functions + CodegenFunctionVector get_codegen_functions(const ast::Program& node); + + void ion_read_statements(BlockType type, + std::vector& int_variables, + std::vector& double_variables, + ast::StatementVector& index_statements, + ast::StatementVector& body_statements); + + void ion_write_statements(BlockType type, + std::vector& int_variables, + std::vector& double_variables, + ast::StatementVector& index_statements, + ast::StatementVector& body_statements); + + void convert_to_instance_variable(ast::Node& node, std::string& index_var); + + void convert_local_statement(ast::StatementBlock& node); + void rename_local_variables(ast::StatementBlock& node); + + /// Remove Function and Procedure blocks from the node since they are already inlined + void remove_inlined_nodes(ast::Program& node); + + void visit_procedure_block(ast::ProcedureBlock& node) override; + void visit_function_block(ast::FunctionBlock& node) override; + void visit_nrn_state_block(ast::NrnStateBlock& node) override; + void visit_program(ast::Program& node) override; +}; + +/** @} */ // end of llvm_codegen_details + +} // 
namespace codegen +} // namespace nmodl diff --git a/src/codegen/llvm/codegen_llvm_visitor.cpp b/src/codegen/llvm/codegen_llvm_visitor.cpp new file mode 100644 index 0000000000..131661cf35 --- /dev/null +++ b/src/codegen/llvm/codegen_llvm_visitor.cpp @@ -0,0 +1,971 @@ +/************************************************************************* + * Copyright (C) 2018-2020 Blue Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. + *************************************************************************/ + +#include "codegen/llvm/codegen_llvm_visitor.hpp" + +#include "ast/all.hpp" +#include "visitors/rename_visitor.hpp" +#include "visitors/visitor_utils.hpp" + +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/AssemblyAnnotationWriter.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Host.h" +#include "llvm/Support/ToolOutputFile.h" + +#if LLVM_VERSION_MAJOR >= 13 +#include "llvm/CodeGen/ReplaceWithVeclib.h" +#endif + +namespace nmodl { +namespace codegen { + + +static constexpr const char instance_struct_type_name[] = "__instance_var__type"; + + +/****************************************************************************************/ +/* Helper routines */ +/****************************************************************************************/ + +/// A utility to check for supported Statement AST nodes. +static bool is_supported_statement(const ast::Statement& statement) { + return statement.is_codegen_atomic_statement() || statement.is_codegen_for_statement() || + statement.is_if_statement() || statement.is_codegen_return_statement() || + statement.is_codegen_var_list_statement() || statement.is_expression_statement() || + statement.is_while_statement(); +} + +/// A utility to check that the kernel body can be vectorised. 
+static bool can_vectorize(const ast::CodegenForStatement& statement, symtab::SymbolTable* sym_tab) { + // Check that function calls are made to external methods only. + const auto& function_calls = collect_nodes(statement, {ast::AstNodeType::FUNCTION_CALL}); + for (const auto& call: function_calls) { + const auto& name = call->get_node_name(); + auto symbol = sym_tab->lookup(name); + if (symbol && !symbol->has_any_property(symtab::syminfo::NmodlType::extern_method)) + return false; + } + + // Check for simple supported control flow in the kernel (single if/else statement). + const std::vector supported_control_flow = {ast::AstNodeType::IF_STATEMENT}; + const auto& supported = collect_nodes(statement, supported_control_flow); + + // Check for unsupported control flow statements. + const std::vector unsupported_nodes = {ast::AstNodeType::ELSE_IF_STATEMENT}; + const auto& unsupported = collect_nodes(statement, unsupported_nodes); + + return unsupported.empty() && supported.size() <= 1; +} + +#if LLVM_VERSION_MAJOR >= 13 +void CodegenLLVMVisitor::add_vectorizable_functions_from_vec_lib(llvm::TargetLibraryInfoImpl& tli, + llvm::Triple& triple) { + // Since LLVM does not support SLEEF as a vector library yet, process it separately. 
+ if (vector_library == "SLEEF") { + // Populate function definitions of only exp and pow (for now) +#define FIXED(w) llvm::ElementCount::getFixed(w) +#define DISPATCH(func, vec_func, width) {func, vec_func, width}, + const llvm::VecDesc aarch64_functions[] = { + // clang-format off + DISPATCH("llvm.exp.f32", "_ZGVnN4v_expf", FIXED(4)) + DISPATCH("llvm.exp.f64", "_ZGVnN2v_exp", FIXED(2)) + DISPATCH("llvm.pow.f32", "_ZGVnN4vv_powf", FIXED(4)) + DISPATCH("llvm.pow.f64", "_ZGVnN2vv_pow", FIXED(2)) + // clang-format on + }; + const llvm::VecDesc x86_functions[] = { + // clang-format off + DISPATCH("llvm.exp.f64", "_ZGVbN2v_exp", FIXED(2)) + DISPATCH("llvm.exp.f64", "_ZGVdN4v_exp", FIXED(4)) + DISPATCH("llvm.exp.f64", "_ZGVeN8v_exp", FIXED(8)) + DISPATCH("llvm.pow.f64", "_ZGVbN2vv_pow", FIXED(2)) + DISPATCH("llvm.pow.f64", "_ZGVdN4vv_pow", FIXED(4)) + DISPATCH("llvm.pow.f64", "_ZGVeN8vv_pow", FIXED(8)) + // clang-format on + }; +#undef DISPATCH + + if (triple.isAArch64()) { + tli.addVectorizableFunctions(aarch64_functions); + } + if (triple.isX86() && triple.isArch64Bit()) { + tli.addVectorizableFunctions(x86_functions); + } + + } else { + // A map to query vector library by its string value. + using VecLib = llvm::TargetLibraryInfoImpl::VectorLibrary; + static const std::map llvm_supported_vector_libraries = { + {"Accelerate", VecLib::Accelerate}, + {"libmvec", VecLib::LIBMVEC_X86}, + {"libsystem_m", VecLib ::DarwinLibSystemM}, + {"MASSV", VecLib::MASSV}, + {"none", VecLib::NoLibrary}, + {"SVML", VecLib::SVML}}; + const auto& library = llvm_supported_vector_libraries.find(vector_library); + if (library == llvm_supported_vector_libraries.end()) + throw std::runtime_error("Error: unknown vector library - " + vector_library + "\n"); + + // Add vectorizable functions to the target library info. 
+ switch (library->second) { + case VecLib::LIBMVEC_X86: + if (!triple.isX86() || !triple.isArch64Bit()) + break; + default: + tli.addVectorizableFunctionsFromVecLib(library->second); + break; + } + } +} +#endif + +llvm::Value* CodegenLLVMVisitor::accept_and_get(const std::shared_ptr& node) { + node->accept(*this); + return ir_builder.pop_last_value(); +} + +void CodegenLLVMVisitor::create_external_function_call(const std::string& name, + const ast::ExpressionVector& arguments) { + if (name == "printf") { + create_printf_call(arguments); + return; + } + + ValueVector argument_values; + TypeVector argument_types; + for (const auto& arg: arguments) { + llvm::Value* value = accept_and_get(arg); + llvm::Type* type = value->getType(); + argument_types.push_back(type); + argument_values.push_back(value); + } + ir_builder.create_intrinsic(name, argument_values, argument_types); +} + +void CodegenLLVMVisitor::create_function_call(llvm::Function* func, + const std::string& name, + const ast::ExpressionVector& arguments) { + // Check that function is called with the expected number of arguments. + if (!func->isVarArg() && arguments.size() != func->arg_size()) { + throw std::runtime_error("Error: Incorrect number of arguments passed"); + } + + // Pack function call arguments to vector and create a call instruction. + ValueVector argument_values; + argument_values.reserve(arguments.size()); + create_function_call_arguments(arguments, argument_values); + ir_builder.create_function_call(func, argument_values); +} + +void CodegenLLVMVisitor::create_function_call_arguments(const ast::ExpressionVector& arguments, + ValueVector& arg_values) { + for (const auto& arg: arguments) { + if (arg->is_string()) { + // If the argument is a string, create a global i8* variable with it. 
+ auto string_arg = std::dynamic_pointer_cast(arg); + arg_values.push_back(ir_builder.create_global_string(*string_arg)); + } else { + llvm::Value* value = accept_and_get(arg); + arg_values.push_back(value); + } + } +} + +void CodegenLLVMVisitor::create_function_declaration(const ast::CodegenFunction& node) { + const auto& name = node.get_node_name(); + const auto& arguments = node.get_arguments(); + + // Procedure or function parameters are doubles by default. + TypeVector arg_types; + for (size_t i = 0; i < arguments.size(); ++i) + arg_types.push_back(get_codegen_var_type(*arguments[i]->get_type())); + + llvm::Type* return_type = get_codegen_var_type(*node.get_return_type()); + + // Create a function that is automatically inserted into module's symbol table. + auto func = + llvm::Function::Create(llvm::FunctionType::get(return_type, arg_types, /*isVarArg=*/false), + llvm::Function::ExternalLinkage, + name, + *module); + + // Add function debug information, with location information if it exists. + if (add_debug_information) { + if (node.get_token()) { + Location loc{node.get_token()->start_line(), node.get_token()->start_column()}; + debug_builder.add_function_debug_info(func, &loc); + } else { + debug_builder.add_function_debug_info(func); + } + } +} + +void CodegenLLVMVisitor::create_printf_call(const ast::ExpressionVector& arguments) { + // First, create printf declaration or insert it if it does not exit. + std::string name = "printf"; + llvm::Function* printf = module->getFunction(name); + if (!printf) { + llvm::FunctionType* printf_type = llvm::FunctionType::get(ir_builder.get_i32_type(), + ir_builder.get_i8_ptr_type(), + /*isVarArg=*/true); + + printf = + llvm::Function::Create(printf_type, llvm::Function::ExternalLinkage, name, *module); + } + + // Create a call instruction. 
+ ValueVector argument_values; + argument_values.reserve(arguments.size()); + create_function_call_arguments(arguments, argument_values); + ir_builder.create_function_call(printf, argument_values, /*use_result=*/false); +} + +void CodegenLLVMVisitor::create_vectorized_control_flow_block(const ast::IfStatement& node) { + // Get the true mask from the condition statement. + llvm::Value* true_mask = accept_and_get(node.get_condition()); + + // Process the true block. + ir_builder.set_mask(true_mask); + node.get_statement_block()->accept(*this); + + // Note: by default, we do not support kernels with complicated control flow. This is checked + // prior to visiting 'CodegenForStatement`. + const auto& elses = node.get_elses(); + if (elses) { + // If `else` statement exists, invert the mask and proceed with code generation. + ir_builder.invert_mask(); + elses->get_statement_block()->accept(*this); + } + + // Clear the mask value. + ir_builder.clear_mask(); +} + +void CodegenLLVMVisitor::find_kernel_names(std::vector& container) { + auto& functions = module->getFunctionList(); + for (auto& func: functions) { + const std::string name = func.getName().str(); + if (is_kernel_function(name)) { + container.push_back(name); + } + } +} + +llvm::Type* CodegenLLVMVisitor::get_codegen_var_type(const ast::CodegenVarType& node) { + switch (node.get_type()) { + case ast::AstNodeType::BOOLEAN: + return ir_builder.get_boolean_type(); + case ast::AstNodeType::DOUBLE: + return ir_builder.get_fp_type(); + case ast::AstNodeType::INSTANCE_STRUCT: + return get_instance_struct_type(); + case ast::AstNodeType::INTEGER: + return ir_builder.get_i32_type(); + case ast::AstNodeType::VOID: + return ir_builder.get_void_type(); + default: + throw std::runtime_error("Error: expecting a type in CodegenVarType node\n"); + } +} + +llvm::Value* CodegenLLVMVisitor::get_index(const ast::IndexedName& node) { + // In NMODL, the index is either an integer expression or a named constant, such as "id". 
+ llvm::Value* index_value = node.get_length()->is_name() + ? ir_builder.create_load(node.get_length()->get_node_name()) + : accept_and_get(node.get_length()); + return ir_builder.create_index(index_value); +} + +llvm::Type* CodegenLLVMVisitor::get_instance_struct_type() { + TypeVector member_types; + for (const auto& variable: instance_var_helper.instance->get_codegen_vars()) { + // Get the type information of the codegen variable. + const auto& is_pointer = variable->get_is_pointer(); + const auto& nmodl_type = variable->get_type()->get_type(); + + // Create the corresponding LLVM type. + switch (nmodl_type) { + case ast::AstNodeType::DOUBLE: + member_types.push_back(is_pointer ? ir_builder.get_fp_ptr_type() + : ir_builder.get_fp_type()); + break; + case ast::AstNodeType::INTEGER: + member_types.push_back(is_pointer ? ir_builder.get_i32_ptr_type() + : ir_builder.get_i32_type()); + break; + default: + throw std::runtime_error("Error: unsupported type found in instance struct\n"); + } + } + + return ir_builder.get_struct_ptr_type(mod_filename + instance_struct_type_name, member_types); +} + +int CodegenLLVMVisitor::get_num_elements(const ast::IndexedName& node) { + // First, verify if the length is an integer value. + const auto& integer = std::dynamic_pointer_cast(node.get_length()); + if (!integer) + throw std::runtime_error("Error: only integer length is supported\n"); + + // Check if the length value is a constant. + if (!integer->get_macro()) + return integer->get_value(); + + // Otherwise, the length is taken from the macro. + const auto& macro = sym_tab->lookup(integer->get_macro()->get_node_name()); + return static_cast(*macro->get_value()); +} + +/** + * Currently, functions are identified as compute kernels if they satisfy the following: + * 1. They have a void return type + * 2. They have a single argument + * 3. 
The argument is a struct type pointer + * This is not robust, and hence it would be better to find what functions are kernels on the NMODL + * AST side (e.g. via a flag, or via names list). + * + * \todo identify kernels on NMODL AST side. + */ +bool CodegenLLVMVisitor::is_kernel_function(const std::string& function_name) { + llvm::Function* function = module->getFunction(function_name); + if (!function) + throw std::runtime_error("Error: function " + function_name + " does not exist\n"); + + // By convention, only kernel functions have a return type of void and single argument. The + // number of arguments check is needed to avoid LLVM void intrinsics to be considered as + // kernels. + if (!function->getReturnType()->isVoidTy() || !llvm::hasSingleElement(function->args())) + return false; + + // Kernel's argument is a pointer to the instance struct type. + llvm::Type* arg_type = function->getArg(0)->getType(); + if (auto pointer_type = llvm::dyn_cast(arg_type)) { + if (pointer_type->getElementType()->isStructTy()) + return true; + } + return false; +} + +llvm::Value* CodegenLLVMVisitor::read_from_or_write_to_instance(const ast::CodegenInstanceVar& node, + llvm::Value* maybe_value_to_store) { + const auto& instance_name = node.get_instance_var()->get_node_name(); + const auto& member_node = node.get_member_var(); + const auto& member_name = member_node->get_node_name(); + + if (!instance_var_helper.is_an_instance_variable(member_name)) + throw std::runtime_error("Error: " + member_name + + " is not a member of the instance variable\n"); + + // Load the instance struct by its name. + llvm::Value* instance_ptr = ir_builder.create_load(instance_name); + + // Get the pointer to the specified member. + int member_index = instance_var_helper.get_variable_index(member_name); + llvm::Value* member_ptr = ir_builder.get_struct_member_ptr(instance_ptr, member_index); + + // Check if the member is scalar. Load the value or store to it straight away. 
Otherwise, we + // need some extra handling. + auto codegen_var_with_type = instance_var_helper.get_variable(member_name); + if (!codegen_var_with_type->get_is_pointer()) { + if (maybe_value_to_store) { + ir_builder.create_store(member_ptr, maybe_value_to_store); + return nullptr; + } else { + return ir_builder.create_load(member_ptr); + } + } + + // Check that the member is an indexed name indeed, and that it is indexed by a named constant + // (e.g. "id"). + const auto& member_var_name = std::dynamic_pointer_cast(member_node); + if (!member_var_name->get_name()->is_indexed_name()) + throw std::runtime_error("Error: " + member_name + " is not an IndexedName\n"); + + const auto& member_indexed_name = std::dynamic_pointer_cast( + member_var_name->get_name()); + if (!member_indexed_name->get_length()->is_name()) + throw std::runtime_error("Error: " + member_name + " must be indexed with a variable!"); + + // Get the index to the member and the id used to index it. + llvm::Value* i64_index = get_index(*member_indexed_name); + const std::string id = member_indexed_name->get_length()->get_node_name(); + + // Load the member of the instance struct. + llvm::Value* instance_member = ir_builder.create_load(member_ptr); + + // Create a pointer to the specified element of the struct member. 
+ return ir_builder.load_to_or_store_from_array(id, + i64_index, + instance_member, + maybe_value_to_store); +} + +llvm::Value* CodegenLLVMVisitor::read_variable(const ast::VarName& node) { + const auto& identifier = node.get_name(); + + if (identifier->is_name()) { + return ir_builder.create_load(node.get_node_name(), + /*masked=*/ir_builder.generates_predicated_ir()); + } + + if (identifier->is_indexed_name()) { + const auto& indexed_name = std::dynamic_pointer_cast(identifier); + llvm::Value* index = get_index(*indexed_name); + return ir_builder.create_load_from_array(node.get_node_name(), index); + } + + if (identifier->is_codegen_instance_var()) { + const auto& instance_var = std::dynamic_pointer_cast(identifier); + return read_from_or_write_to_instance(*instance_var); + } + + throw std::runtime_error("Error: the type of '" + node.get_node_name() + + "' is not supported\n"); +} + +void CodegenLLVMVisitor::run_ir_opt_passes() { + // Run some common optimisation passes that are commonly suggested. + opt_pm.add(llvm::createInstructionCombiningPass()); + opt_pm.add(llvm::createReassociatePass()); + opt_pm.add(llvm::createGVNPass()); + opt_pm.add(llvm::createCFGSimplificationPass()); + + // Initialize pass manager. + opt_pm.doInitialization(); + + // Iterate over all functions and run the optimisation passes. 
+ auto& functions = module->getFunctionList(); + for (auto& function: functions) { + llvm::verifyFunction(function); + opt_pm.run(function); + } + opt_pm.doFinalization(); +} + +void CodegenLLVMVisitor::write_to_variable(const ast::VarName& node, llvm::Value* value) { + const auto& identifier = node.get_name(); + if (!identifier->is_name() && !identifier->is_indexed_name() && + !identifier->is_codegen_instance_var()) { + throw std::runtime_error("Error: the type of '" + node.get_node_name() + + "' is not supported\n"); + } + + if (identifier->is_name()) { + ir_builder.create_store(node.get_node_name(), value); + } + + if (identifier->is_indexed_name()) { + const auto& indexed_name = std::dynamic_pointer_cast(identifier); + llvm::Value* index = get_index(*indexed_name); + ir_builder.create_store_to_array(node.get_node_name(), index, value); + } + + if (identifier->is_codegen_instance_var()) { + const auto& instance_var = std::dynamic_pointer_cast(identifier); + read_from_or_write_to_instance(*instance_var, value); + } +} + +void CodegenLLVMVisitor::wrap_kernel_functions() { + // First, identify all kernels. + std::vector kernel_names; + find_kernel_names(kernel_names); + + for (const auto& kernel_name: kernel_names) { + // Get the kernel function. + auto kernel = module->getFunction(kernel_name); + + // Create a wrapper void function that takes a void pointer as a single argument. + llvm::Type* i32_type = ir_builder.get_i32_type(); + llvm::Type* void_ptr_type = ir_builder.get_i8_ptr_type(); + llvm::Function* wrapper_func = llvm::Function::Create( + llvm::FunctionType::get(i32_type, {void_ptr_type}, /*isVarArg=*/false), + llvm::Function::ExternalLinkage, + "__" + kernel_name + "_wrapper", + *module); + + // Optionally, add debug information for the wrapper function. 
+ if (add_debug_information) { + debug_builder.add_function_debug_info(wrapper_func); + } + + ir_builder.create_block_and_set_insertion_point(wrapper_func); + + // Proceed with bitcasting the void pointer to the struct pointer type, calling the kernel + // and adding a terminator. + llvm::Value* bitcasted = ir_builder.create_bitcast(wrapper_func->getArg(0), + kernel->getArg(0)->getType()); + ValueVector args; + args.push_back(bitcasted); + ir_builder.create_function_call(kernel, args, /*use_result=*/false); + + // Create a 0 return value and a return instruction. + ir_builder.create_i32_constant(0); + ir_builder.create_return(ir_builder.pop_last_value()); + } +} + + +/****************************************************************************************/ +/* Overloaded visitor routines */ +/****************************************************************************************/ + + +void CodegenLLVMVisitor::visit_binary_expression(const ast::BinaryExpression& node) { + const auto& op = node.get_op().get_value(); + + // Process rhs first, since lhs is handled differently for assignment and binary + // operators. + llvm::Value* rhs = accept_and_get(node.get_rhs()); + if (op == ast::BinaryOp::BOP_ASSIGN) { + auto var = dynamic_cast(node.get_lhs().get()); + if (!var) + throw std::runtime_error("Error: only 'VarName' assignment is supported\n"); + + write_to_variable(*var, rhs); + return; + } + + llvm::Value* lhs = accept_and_get(node.get_lhs()); + ir_builder.create_binary_op(lhs, rhs, op); +} + +void CodegenLLVMVisitor::visit_statement_block(const ast::StatementBlock& node) { + const auto& statements = node.get_statements(); + for (const auto& statement: statements) { + if (is_supported_statement(*statement)) + statement->accept(*this); + } +} + +void CodegenLLVMVisitor::visit_boolean(const ast::Boolean& node) { + ir_builder.create_boolean_constant(node.get_value()); +} + +/** + * Currently, this functions is very similar to visiting the binary operator. 
However, the + * difference here is that the writes to the LHS variable must be atomic. These has a particular + * use case in synapse kernels. For simplicity, we choose not to support atomic writes at this + * stage and emit a warning. + * + * \todo support this properly. + */ +void CodegenLLVMVisitor::visit_codegen_atomic_statement(const ast::CodegenAtomicStatement& node) { + if (vector_width > 1) + logger->warn("Atomic operations are not supported"); + + // Support only assignment for now. + llvm::Value* rhs = accept_and_get(node.get_rhs()); + if (node.get_atomic_op().get_value() != ast::BinaryOp::BOP_ASSIGN) + throw std::runtime_error( + "Error: only assignment is supported for CodegenAtomicStatement\n"); + const auto& var = dynamic_cast(node.get_lhs().get()); + if (!var) + throw std::runtime_error("Error: only 'VarName' assignment is supported\n"); + + // Process the assignment as if it was non-atomic. + if (vector_width > 1) + logger->warn("Treating write as non-atomic"); + write_to_variable(*var, rhs); +} + +// Generating FOR loop in LLVM IR creates the following structure: +// +// +---------------------------+ +// | | +// | | +// | br %cond | +// +---------------------------+ +// | +// V +// +-----------------------------+ +// | | +// | %cond = ... |<------+ +// | cond_br %cond, %body, %exit | | +// +-----------------------------+ | +// | | | +// | V | +// | +------------------------+ | +// | | | | +// | | br %inc | | +// | +------------------------+ | +// | | | +// | V | +// | +------------------------+ | +// | | | | +// | | br %cond | | +// | +------------------------+ | +// | | | +// | +---------------+ +// V +// +---------------------------+ +// | | +// +---------------------------+ +void CodegenLLVMVisitor::visit_codegen_for_statement(const ast::CodegenForStatement& node) { + // Condition and increment blocks must be scalar. + ir_builder.generate_scalar_ir(); + + // Get the current and the next blocks within the function. 
+ llvm::BasicBlock* curr_block = ir_builder.get_current_block(); + llvm::BasicBlock* next = curr_block->getNextNode(); + llvm::Function* func = curr_block->getParent(); + + // Create the basic blocks for FOR loop. + llvm::BasicBlock* for_cond = + llvm::BasicBlock::Create(*context, /*Name=*/"for.cond", func, next); + llvm::BasicBlock* for_body = + llvm::BasicBlock::Create(*context, /*Name=*/"for.body", func, next); + llvm::BasicBlock* for_inc = llvm::BasicBlock::Create(*context, /*Name=*/"for.inc", func, next); + llvm::BasicBlock* exit = llvm::BasicBlock::Create(*context, /*Name=*/"for.exit", func, next); + + // First, initialize the loop in the same basic block. If processing the remainder of the loop, + // no initialization happens. + const auto& main_loop_initialization = node.get_initialization(); + if (main_loop_initialization) + main_loop_initialization->accept(*this); + + // Branch to condition basic block and insert condition code there. + ir_builder.create_br_and_set_insertion_point(for_cond); + + // Extract the condition to decide whether to branch to the loop body or loop exit. + llvm::Value* cond = accept_and_get(node.get_condition()); + llvm::BranchInst* loop_br = ir_builder.create_cond_br(cond, for_body, exit); + ir_builder.set_loop_metadata(loop_br); + ir_builder.set_insertion_point(for_body); + + // If not processing remainder of the loop, start vectorization. + if (vector_width > 1 && main_loop_initialization) + ir_builder.generate_vector_ir(); + + // Generate code for the loop body and create the basic block for the increment. + const auto& statement_block = node.get_statement_block(); + statement_block->accept(*this); + ir_builder.generate_scalar_ir(); + ir_builder.create_br_and_set_insertion_point(for_inc); + + // Process the increment. + node.get_increment()->accept(*this); + + // Create a branch to condition block, then generate exit code out of the loop. 
+ ir_builder.create_br(for_cond); + ir_builder.set_insertion_point(exit); +} + + +void CodegenLLVMVisitor::visit_codegen_function(const ast::CodegenFunction& node) { + const auto& name = node.get_node_name(); + const auto& arguments = node.get_arguments(); + + // Create the entry basic block of the function/procedure and point the local named values table + // to the symbol table. + llvm::Function* func = module->getFunction(name); + ir_builder.create_block_and_set_insertion_point(func); + ir_builder.set_function(func); + + // When processing a function, it returns a value named in NMODL. Therefore, we + // first run RenameVisitor to rename it into ret_. This will aid in avoiding + // symbolic conflicts. + std::string return_var_name = "ret_" + name; + const auto& block = node.get_statement_block(); + visitor::RenameVisitor v(name, return_var_name); + block->accept(v); + + // Allocate parameters on the stack and add them to the symbol table. + ir_builder.allocate_function_arguments(func, arguments); + + // Process function or procedure body. If the function is a compute kernel, enable + // vectorization. If so, the return statement is handled in a separate visitor. + if (vector_width > 1 && is_kernel_function(name)) { + ir_builder.generate_vector_ir(); + block->accept(*this); + ir_builder.generate_scalar_ir(); + } else { + block->accept(*this); + } + + // If function is a compute kernel, add a void terminator explicitly, since there is no + // `CodegenReturnVar` node. Also, set the necessary attributes. + if (is_kernel_function(name)) { + ir_builder.set_kernel_attributes(); + ir_builder.create_return(); + } + + // Clear local values stack and remove the pointer to the local symbol table. 
+ ir_builder.clear_function(); +} + +void CodegenLLVMVisitor::visit_codegen_return_statement(const ast::CodegenReturnStatement& node) { + if (!node.get_statement()->is_name()) + throw std::runtime_error("Error: CodegenReturnStatement must contain a name node\n"); + + std::string ret = "ret_" + ir_builder.get_current_function_name(); + llvm::Value* ret_value = ir_builder.create_load(ret); + ir_builder.create_return(ret_value); +} + +void CodegenLLVMVisitor::visit_codegen_var_list_statement( + const ast::CodegenVarListStatement& node) { + llvm::Type* scalar_type = get_codegen_var_type(*node.get_var_type()); + for (const auto& variable: node.get_variables()) { + const auto& identifier = variable->get_name(); + std::string name = variable->get_node_name(); + + // Local variable can be a scalar (Node AST class) or an array (IndexedName AST class). For + // each case, create memory allocations. + if (identifier->is_indexed_name()) { + const auto& indexed_name = std::dynamic_pointer_cast(identifier); + int length = get_num_elements(*indexed_name); + ir_builder.create_array_alloca(name, scalar_type, length); + } else if (identifier->is_name()) { + ir_builder.create_scalar_or_vector_alloca(name, scalar_type); + } else { + throw std::runtime_error("Error: unsupported local variable type\n"); + } + } +} + +void CodegenLLVMVisitor::visit_double(const ast::Double& node) { + ir_builder.create_fp_constant(node.get_value()); +} + +void CodegenLLVMVisitor::visit_function_block(const ast::FunctionBlock& node) { + // do nothing. \todo: remove old function blocks from ast. +} + +void CodegenLLVMVisitor::visit_function_call(const ast::FunctionCall& node) { + const auto& name = node.get_node_name(); + llvm::Function* func = module->getFunction(name); + if (func) { + create_function_call(func, name, node.get_arguments()); + } else { + // If generating scalable vectorized IR, process the call to `vscale` separately. 
+ if (name == "llvm.vscale.i32" && scalable) { + ir_builder.create_vscale_call(*module); + return; + } + + auto symbol = sym_tab->lookup(name); + if (symbol && symbol->has_any_property(symtab::syminfo::NmodlType::extern_method)) { + create_external_function_call(name, node.get_arguments()); + } else { + throw std::runtime_error("Error: unknown function name: " + name + "\n"); + } + } +} + +void CodegenLLVMVisitor::visit_if_statement(const ast::IfStatement& node) { + // If vectorizing the compute kernel with control flow, process it separately. + if (vector_width > 1 && ir_builder.vectorizing()) { + create_vectorized_control_flow_block(node); + return; + } + + // Get the current and the next blocks within the function. + llvm::BasicBlock* curr_block = ir_builder.get_current_block(); + llvm::BasicBlock* next = curr_block->getNextNode(); + llvm::Function* func = curr_block->getParent(); + + // Add a true block and a merge block where the control flow merges. + llvm::BasicBlock* true_block = llvm::BasicBlock::Create(*context, /*Name=*/"", func, next); + llvm::BasicBlock* merge_block = llvm::BasicBlock::Create(*context, /*Name=*/"", func, next); + + // Add condition to the current block. + llvm::Value* cond = accept_and_get(node.get_condition()); + + // Process the true block. + ir_builder.set_insertion_point(true_block); + node.get_statement_block()->accept(*this); + ir_builder.create_br(merge_block); + + // Save the merge block and proceed with codegen for `else if` statements. + llvm::BasicBlock* exit = merge_block; + for (const auto& else_if: node.get_elseifs()) { + // Link the current block to the true and else blocks. + llvm::BasicBlock* else_block = + llvm::BasicBlock::Create(*context, /*Name=*/"", func, merge_block); + ir_builder.set_insertion_point(curr_block); + ir_builder.create_cond_br(cond, true_block, else_block); + + // Process else block. 
+ ir_builder.set_insertion_point(else_block); + cond = accept_and_get(else_if->get_condition()); + + // Reassign true and merge blocks respectively. Note that the new merge block has to be + // connected to the old merge block (tmp). + true_block = llvm::BasicBlock::Create(*context, /*Name=*/"", func, merge_block); + llvm::BasicBlock* tmp = merge_block; + merge_block = llvm::BasicBlock::Create(*context, /*Name=*/"", func, merge_block); + ir_builder.set_insertion_point(merge_block); + ir_builder.create_br(tmp); + + // Process true block. + ir_builder.set_insertion_point(true_block); + else_if->get_statement_block()->accept(*this); + ir_builder.create_br(merge_block); + curr_block = else_block; + } + + // Finally, generate code for `else` statement if it exists. + const auto& elses = node.get_elses(); + llvm::BasicBlock* else_block; + if (elses) { + else_block = llvm::BasicBlock::Create(*context, /*Name=*/"", func, merge_block); + ir_builder.set_insertion_point(else_block); + elses->get_statement_block()->accept(*this); + ir_builder.create_br(merge_block); + } else { + else_block = merge_block; + } + ir_builder.set_insertion_point(curr_block); + ir_builder.create_cond_br(cond, true_block, else_block); + ir_builder.set_insertion_point(exit); +} + +void CodegenLLVMVisitor::visit_integer(const ast::Integer& node) { + ir_builder.create_i32_constant(node.get_value()); +} + +void CodegenLLVMVisitor::visit_program(const ast::Program& node) { + // Before generating LLVM: + // - convert function and procedure blocks into CodegenFunctions + // - gather information about AST. For now, information about functions + // and procedures is used only. + CodegenLLVMHelperVisitor v{vector_width, scalable}; + const auto& functions = v.get_codegen_functions(node); + instance_var_helper = v.get_instance_var_helper(); + sym_tab = node.get_symbol_table(); + std::string kernel_id = v.get_kernel_id(); + + // Initialize the builder for this NMODL program. 
+ ir_builder.initialize(*sym_tab, kernel_id); + + // Create compile unit if adding debug information to the module. + if (add_debug_information) { + debug_builder.create_compile_unit(*module, module->getModuleIdentifier(), output_dir); + } + + // For every function, generate its declaration. Thus, we can look up + // `llvm::Function` in the symbol table in the module. + for (const auto& func: functions) { + create_function_declaration(*func); + } + + // Proceed with code generation. Right now, we do not do + // node.visit_children(*this); + // The reason is that the node may contain AST nodes for which the visitor functions have been + // defined. In our implementation we assume that the code generation is happening within the + // function scope. To avoid generating code outside of functions, visit only them for now. + // \todo: Handle what is mentioned here. + for (const auto& func: functions) { + visit_codegen_function(*func); + } + + // Finalize the debug information. + if (add_debug_information) { + debug_builder.finalize(); + } + + // Verify the generated LLVM IR module. + std::string error; + llvm::raw_string_ostream ostream(error); + if (verifyModule(*module, &ostream)) { + throw std::runtime_error("Error: incorrect IR has been generated!\n" + ostream.str()); + } + + if (opt_passes) { + logger->info("Running LLVM optimisation passes"); + run_ir_opt_passes(); + } + + // Optionally, replace LLVM math intrinsics with vector library calls. + if (vector_width > 1) { +#if LLVM_VERSION_MAJOR < 13 + logger->warn( + "This version of LLVM does not support replacement of LLVM intrinsics with vector " + "library calls"); +#else + // First, get the target library information and add vectorizable functions for the + // specified vector library. 
+        llvm::Triple triple(llvm::sys::getDefaultTargetTriple());
+        llvm::TargetLibraryInfoImpl target_lib_info = llvm::TargetLibraryInfoImpl(triple);
+        add_vectorizable_functions_from_vec_lib(target_lib_info, triple);
+
+        // Run passes that replace math intrinsics.
+        codegen_pm.add(new llvm::TargetLibraryInfoWrapperPass(target_lib_info));
+        codegen_pm.add(new llvm::ReplaceWithVeclibLegacy);
+        codegen_pm.doInitialization();
+        for (auto& function: module->getFunctionList()) {
+            if (!function.isDeclaration())
+                codegen_pm.run(function);
+        }
+        codegen_pm.doFinalization();
+#endif
+    }
+
+    // If the output directory is specified, save the IR to .ll file.
+    // \todo: Consider saving the generated LLVM IR to bytecode (.bc) file instead.
+    if (output_dir != ".") {
+        std::error_code error_code;
+        std::unique_ptr<llvm::ToolOutputFile> out = std::make_unique<llvm::ToolOutputFile>(
+            output_dir + "/" + mod_filename + ".ll", error_code, llvm::sys::fs::OF_Text);
+        if (error_code)
+            throw std::runtime_error("Error: " + error_code.message());
+
+        std::unique_ptr<llvm::AssemblyAnnotationWriter> annotator;
+        module->print(out->os(), annotator.get());
+        out->keep();
+    }
+
+    logger->debug("Dumping generated IR...\n" + dump_module());
+}
+
+void CodegenLLVMVisitor::visit_procedure_block(const ast::ProcedureBlock& node) {
+    // do nothing. \todo: remove old procedures from ast.
+}
+
+void CodegenLLVMVisitor::visit_unary_expression(const ast::UnaryExpression& node) {
+    ast::UnaryOp op = node.get_op().get_value();
+    llvm::Value* value = accept_and_get(node.get_expression());
+    ir_builder.create_unary_op(value, op);
+}
+
+void CodegenLLVMVisitor::visit_var_name(const ast::VarName& node) {
+    llvm::Value* value = read_variable(node);
+    ir_builder.maybe_replicate_value(value);
+}
+
+void CodegenLLVMVisitor::visit_while_statement(const ast::WhileStatement& node) {
+    // Get the current and the next blocks within the function.
+ llvm::BasicBlock* curr_block = ir_builder.get_current_block(); + llvm::BasicBlock* next = curr_block->getNextNode(); + llvm::Function* func = curr_block->getParent(); + + // Add a header and the body blocks. + llvm::BasicBlock* header = llvm::BasicBlock::Create(*context, /*Name=*/"", func, next); + llvm::BasicBlock* body = llvm::BasicBlock::Create(*context, /*Name=*/"", func, next); + llvm::BasicBlock* exit = llvm::BasicBlock::Create(*context, /*Name=*/"", func, next); + + ir_builder.create_br_and_set_insertion_point(header); + + + // Generate code for condition and create branch to the body block. + llvm::Value* condition = accept_and_get(node.get_condition()); + ir_builder.create_cond_br(condition, body, exit); + + ir_builder.set_insertion_point(body); + node.get_statement_block()->accept(*this); + ir_builder.create_br(header); + + ir_builder.set_insertion_point(exit); +} + +} // namespace codegen +} // namespace nmodl diff --git a/src/codegen/llvm/codegen_llvm_visitor.hpp b/src/codegen/llvm/codegen_llvm_visitor.hpp new file mode 100644 index 0000000000..c5a593cdeb --- /dev/null +++ b/src/codegen/llvm/codegen_llvm_visitor.hpp @@ -0,0 +1,248 @@ +/************************************************************************* + * Copyright (C) 2018-2020 Blue Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. 
+ *************************************************************************/
+
+#pragma once
+
+/**
+ * \dir
+ * \brief LLVM based code generation backend implementation for CoreNEURON
+ *
+ * \file
+ * \brief \copybrief nmodl::codegen::CodegenLLVMVisitor
+ */
+
+#include <memory>
+#include <string>
+
+#include "codegen/llvm/codegen_llvm_helper_visitor.hpp"
+#include "codegen/llvm/llvm_debug_builder.hpp"
+#include "codegen/llvm/llvm_ir_builder.hpp"
+#include "symtab/symbol_table.hpp"
+#include "utils/logger.hpp"
+#include "visitors/ast_visitor.hpp"
+
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/DIBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/Transforms/InstCombine/InstCombine.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Scalar/GVN.h"
+
+namespace nmodl {
+namespace codegen {
+
+/**
+ * @defgroup llvm LLVM Based Code Generation Implementation
+ * @brief Implementations of LLVM based code generation
+ *
+ * @defgroup llvm_backends LLVM Codegen Backend
+ * @ingroup llvm
+ * @brief Code generation backends for NMODL AST to LLVM IR
+ * @{
+ */
+
+
+/**
+ * \class CodegenLLVMVisitor
+ * \brief %Visitor for transforming NMODL AST to LLVM IR
+ */
+class CodegenLLVMVisitor: public visitor::ConstAstVisitor {
+    /// Name of mod file (without .mod suffix).
+    std::string mod_filename;
+
+    /// Output directory for code generation.
+    std::string output_dir;
+
+  private:
+    /// Underlying LLVM context.
+    std::unique_ptr<llvm::LLVMContext> context = std::make_unique<llvm::LLVMContext>();
+
+    /// Underlying LLVM module.
+    std::unique_ptr<llvm::Module> module = std::make_unique<llvm::Module>(mod_filename, *context);
+
+    /// LLVM IR builder.
+    IRBuilder ir_builder;
+
+    /// Debug information builder.
+    DebugBuilder debug_builder;
+
+    /// Add debug information to the module.
+    bool add_debug_information;
+
+    /// Pointer to AST symbol table.
+    symtab::SymbolTable* sym_tab;
+
+    /// Instance variable helper.
+    InstanceVarHelper instance_var_helper;
+
+    /// Run optimisation passes if true.
+    bool opt_passes;
+
+    /// Pass manager for optimisation passes that are run on IR and are not related to target.
+    llvm::legacy::FunctionPassManager opt_pm;
+
+    /// Pass manager for optimisation passes that are used for target code generation.
+    llvm::legacy::FunctionPassManager codegen_pm;
+
+    /// Vector library used for math functions.
+    std::string vector_library;
+
+    /// Explicit vectorisation width.
+    int vector_width;
+
+    /// Generate scalable vectorized IR.
+    bool scalable;
+
+  public:
+    CodegenLLVMVisitor(const std::string& mod_filename,
+                       const std::string& output_dir,
+                       bool opt_passes,
+                       bool use_single_precision = false,
+                       int vector_width = 1,
+                       std::string vec_lib = "none",
+                       bool add_debug_information = false,
+                       std::vector<std::string> fast_math_flags = {},
+                       bool scalable = false)
+        : mod_filename(mod_filename)
+        , output_dir(output_dir)
+        , opt_passes(opt_passes)
+        , vector_width(vector_width)
+        , vector_library(vec_lib)
+        , scalable(scalable)
+        , add_debug_information(add_debug_information)
+        , ir_builder(*context, use_single_precision, vector_width, fast_math_flags, scalable)
+        , debug_builder(*module)
+        , codegen_pm(module.get())
+        , opt_pm(module.get()) {}
+
+    /// Dumps the generated LLVM IR module to string.
+    std::string dump_module() const {
+        std::string str;
+        llvm::raw_string_ostream os(str);
+        os << *module;
+        os.flush();
+        return str;
+    }
+
+    /// Fills the container with the names of kernel functions from the MOD file.
+    void find_kernel_names(std::vector<std::string>& container);
+
+    /// Returns underlying module.
+    std::unique_ptr<llvm::Module> get_module() {
+        return std::move(module);
+    }
+
+    /// Returns shared_ptr to generated ast::InstanceStruct.
+    std::shared_ptr<ast::InstanceStruct> get_instance_struct_ptr() {
+        return instance_var_helper.instance;
+    }
+
+    /// Returns InstanceVarHelper for the given MOD file.
+ InstanceVarHelper get_instance_var_helper() { + return instance_var_helper; + } + + /// Returns vector width + int get_vector_width() const { + return vector_width; + } + + // Visitors. + void visit_binary_expression(const ast::BinaryExpression& node) override; + void visit_boolean(const ast::Boolean& node) override; + void visit_codegen_atomic_statement(const ast::CodegenAtomicStatement& node) override; + void visit_codegen_for_statement(const ast::CodegenForStatement& node) override; + void visit_codegen_function(const ast::CodegenFunction& node) override; + void visit_codegen_return_statement(const ast::CodegenReturnStatement& node) override; + void visit_codegen_var_list_statement(const ast::CodegenVarListStatement& node) override; + void visit_double(const ast::Double& node) override; + void visit_function_block(const ast::FunctionBlock& node) override; + void visit_function_call(const ast::FunctionCall& node) override; + void visit_if_statement(const ast::IfStatement& node) override; + void visit_integer(const ast::Integer& node) override; + void visit_procedure_block(const ast::ProcedureBlock& node) override; + void visit_program(const ast::Program& node) override; + void visit_statement_block(const ast::StatementBlock& node) override; + void visit_unary_expression(const ast::UnaryExpression& node) override; + void visit_var_name(const ast::VarName& node) override; + void visit_while_statement(const ast::WhileStatement& node) override; + + /// Wraps all kernel function calls into wrapper functions that use `void*` to pass the data to + /// the kernel. + void wrap_kernel_functions(); + + private: +#if LLVM_VERSION_MAJOR >= 13 + /// Populates target library info with the vector library definitions. + void add_vectorizable_functions_from_vec_lib(llvm::TargetLibraryInfoImpl& tli, + llvm::Triple& triple); +#endif + + /// Accepts the given AST node and returns the processed value. 
+    llvm::Value* accept_and_get(const std::shared_ptr<ast::Expression>& node);
+
+    /// Creates a call to an external function (e.g pow, exp, etc.)
+    void create_external_function_call(const std::string& name,
+                                       const ast::ExpressionVector& arguments);
+
+    /// Creates a call to NMODL function or procedure in the same MOD file.
+    void create_function_call(llvm::Function* func,
+                              const std::string& name,
+                              const ast::ExpressionVector& arguments);
+
+    /// Fills values vector with processed NMODL function call arguments.
+    void create_function_call_arguments(const ast::ExpressionVector& arguments,
+                                        ValueVector& arg_values);
+
+    /// Creates the function declaration for the given AST node.
+    void create_function_declaration(const ast::CodegenFunction& node);
+
+    /// Creates a call to `printf` function.
+    void create_printf_call(const ast::ExpressionVector& arguments);
+
+    /// Creates a vectorized version of the LLVM IR for the simple control flow statement.
+    void create_vectorized_control_flow_block(const ast::IfStatement& node);
+
+    /// Returns LLVM type for the given CodegenVarType AST node.
+    llvm::Type* get_codegen_var_type(const ast::CodegenVarType& node);
+
+    /// Returns the index value from the IndexedName AST node.
+    llvm::Value* get_index(const ast::IndexedName& node);
+
+    /// Returns an instance struct type.
+    llvm::Type* get_instance_struct_type();
+
+    /// Returns the number of elements in the array specified by the IndexedName AST node.
+    int get_num_elements(const ast::IndexedName& node);
+
+    /// Returns whether the function is an NMODL compute kernel.
+    bool is_kernel_function(const std::string& function_name);
+
+    /// If the value to store is specified, writes it to the instance. Otherwise, returns the
+    /// instance variable.
+    llvm::Value* read_from_or_write_to_instance(const ast::CodegenInstanceVar& node,
+                                                llvm::Value* maybe_value_to_store = nullptr);
+
+    /// Reads the given variable and returns the processed value.
+    llvm::Value* read_variable(const ast::VarName& node);
+
+
+    /// Run multiple LLVM optimisation passes on generated IR.
+    /// TODO: this can be moved to a dedicated file or deprecated.
+    void run_ir_opt_passes();
+
+    /// Writes the value to the given variable.
+    void write_to_variable(const ast::VarName& node, llvm::Value* value);
+};
+
+/** \} */  // end of llvm_backends
+
+}  // namespace codegen
+}  // namespace nmodl
diff --git a/src/codegen/llvm/llvm_debug_builder.cpp b/src/codegen/llvm/llvm_debug_builder.cpp
new file mode 100644
index 0000000000..5682a6e904
--- /dev/null
+++ b/src/codegen/llvm/llvm_debug_builder.cpp
@@ -0,0 +1,63 @@
+/*************************************************************************
+ * Copyright (C) 2018-2020 Blue Brain Project
+ *
+ * This file is part of NMODL distributed under the terms of the GNU
+ * Lesser General Public License. See top-level LICENSE file for details.
+ *************************************************************************/
+
+#include "codegen/llvm/llvm_debug_builder.hpp"
+
+namespace nmodl {
+namespace codegen {
+
+
+static constexpr const char debug_version_key[] = "Debug Version";
+
+
+void DebugBuilder::add_function_debug_info(llvm::Function* function, Location* loc) {
+    // Create the function debug type (subroutine type). We are not interested in parameters and
+    // types, and therefore passing llvm::None as argument suffices for now.
+    llvm::DISubroutineType* subroutine_type = di_builder.createSubroutineType(
+        di_builder.getOrCreateTypeArray(llvm::None));
+    llvm::DISubprogram::DISPFlags sp_flags = llvm::DISubprogram::SPFlagDefinition |
+                                             llvm::DISubprogram::SPFlagOptimized;
+    // If there is no location associated with the function, just use 0.
+    int line = loc ?
loc->line : 0; + llvm::DISubprogram* program = di_builder.createFunction(compile_unit, + function->getName(), + function->getName(), + file, + line, + subroutine_type, + line, + llvm::DINode::FlagZero, + sp_flags); + function->setSubprogram(program); + di_builder.finalizeSubprogram(program); +} + +void DebugBuilder::create_compile_unit(llvm::Module& module, + const std::string& debug_filename, + const std::string& debug_output_dir) { + // Create the debug file and compile unit for the module. + file = di_builder.createFile(debug_filename, debug_output_dir); + compile_unit = di_builder.createCompileUnit(llvm::dwarf::DW_LANG_C, + file, + /*Producer=*/"NMODL-LLVM", + /*isOptimized=*/false, + /*Flags=*/"", + /*RV=*/0); + + // Add a flag to the module to specify that it has debug information. + if (!module.getModuleFlag(debug_version_key)) { + module.addModuleFlag(llvm::Module::Warning, + debug_version_key, + llvm::DEBUG_METADATA_VERSION); + } +} + +void DebugBuilder::finalize() { + di_builder.finalize(); +} +} // namespace codegen +} // namespace nmodl diff --git a/src/codegen/llvm/llvm_debug_builder.hpp b/src/codegen/llvm/llvm_debug_builder.hpp new file mode 100644 index 0000000000..9322cd461a --- /dev/null +++ b/src/codegen/llvm/llvm_debug_builder.hpp @@ -0,0 +1,70 @@ +/************************************************************************* + * Copyright (C) 2018-2020 Blue Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. + *************************************************************************/ + +#pragma once + +#include + +#include "llvm/IR/DIBuilder.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" + +namespace nmodl { +namespace codegen { + +/// A struct to store AST location information. +/// \todo Currently, not all AST nodes have location information. 
Moreover, +/// some may not have it as they were artificially introduced (e.g. +/// CodegenForStatement). This simple wrapper suffices for now, but in future +/// we may want to handle this properly. +struct Location { + /// Line in the file. + int line; + + /// Column in the file. + int column; +}; + + +/** + * \class DebugBuilder + * \brief A helper class to create debug information for LLVM IR module. + * \todo Only function debug information is supported. + */ +class DebugBuilder { + private: + /// Debug information builder. + llvm::DIBuilder di_builder; + + /// LLVM context. + llvm::LLVMContext& context; + + /// Debug compile unit for the module. + llvm::DICompileUnit* compile_unit = nullptr; + + /// Debug file pointer. + llvm::DIFile* file = nullptr; + + public: + DebugBuilder(llvm::Module& module) + : di_builder(module) + , context(module.getContext()) {} + + /// Adds function debug information with an optional location. + void add_function_debug_info(llvm::Function* function, Location* loc = nullptr); + + /// Creates the compile unit for and sets debug flags for the module. + void create_compile_unit(llvm::Module& module, + const std::string& debug_filename, + const std::string& debug_output_dir); + + /// Finalizes the debug information. + void finalize(); +}; +} // namespace codegen +} // namespace nmodl diff --git a/src/codegen/llvm/llvm_ir_builder.cpp b/src/codegen/llvm/llvm_ir_builder.cpp new file mode 100644 index 0000000000..e28b3bfd0d --- /dev/null +++ b/src/codegen/llvm/llvm_ir_builder.cpp @@ -0,0 +1,601 @@ +/************************************************************************* + * Copyright (C) 2018-2020 Blue Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. 
+ *************************************************************************/ + +#include "codegen/llvm/llvm_ir_builder.hpp" +#include "ast/all.hpp" + +#include "llvm/ADT/StringSwitch.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/ValueSymbolTable.h" + +namespace nmodl { +namespace codegen { + + +/****************************************************************************************/ +/* LLVM type utilities */ +/****************************************************************************************/ + +llvm::Type* IRBuilder::get_boolean_type() { + return llvm::Type::getInt1Ty(builder.getContext()); +} + +llvm::Type* IRBuilder::get_i8_ptr_type() { + return llvm::Type::getInt8PtrTy(builder.getContext()); +} + +llvm::Type* IRBuilder::get_i32_type() { + return llvm::Type::getInt32Ty(builder.getContext()); +} + +llvm::Type* IRBuilder::get_i32_ptr_type() { + return llvm::Type::getInt32PtrTy(builder.getContext()); +} + +llvm::Type* IRBuilder::get_i64_type() { + return llvm::Type::getInt64Ty(builder.getContext()); +} + +llvm::Type* IRBuilder::get_fp_type() { + if (fp_precision == single_precision) + return llvm::Type::getFloatTy(builder.getContext()); + return llvm::Type::getDoubleTy(builder.getContext()); +} + +llvm::Type* IRBuilder::get_fp_ptr_type() { + if (fp_precision == single_precision) + return llvm::Type::getFloatPtrTy(builder.getContext()); + return llvm::Type::getDoublePtrTy(builder.getContext()); +} + +llvm::Type* IRBuilder::get_void_type() { + return llvm::Type::getVoidTy(builder.getContext()); +} + +llvm::Type* IRBuilder::get_struct_ptr_type(const std::string& struct_type_name, + TypeVector& member_types) { + llvm::StructType* llvm_struct_type = llvm::StructType::create(builder.getContext(), + struct_type_name); + llvm_struct_type->setBody(member_types); + return llvm::PointerType::get(llvm_struct_type, /*AddressSpace=*/0); +} + +llvm::Type* IRBuilder::get_vector_type(llvm::Type* element_type, unsigned min_num_elements) { + if (scalable) + return 
llvm::ScalableVectorType::get(element_type, min_num_elements);
+    return llvm::FixedVectorType::get(element_type, min_num_elements);
+}
+
+/****************************************************************************************/
+/*                               LLVM value utilities                                   */
+/****************************************************************************************/
+
+llvm::Value* IRBuilder::lookup_value(const std::string& value_name) {
+    auto value = current_function->getValueSymbolTable()->lookup(value_name);
+    if (!value)
+        throw std::runtime_error("Error: variable " + value_name + " is not in the scope\n");
+    return value;
+}
+
+llvm::Value* IRBuilder::pop_last_value() {
+    // Check if the stack is empty.
+    if (value_stack.empty())
+        throw std::runtime_error("Error: popping a value from the empty stack\n");
+
+    // Return the last added value and delete it from the stack.
+    llvm::Value* last = value_stack.back();
+    value_stack.pop_back();
+    return last;
+}
+
+/****************************************************************************************/
+/*                            LLVM constants utilities                                  */
+/****************************************************************************************/
+
+void IRBuilder::create_boolean_constant(int value) {
+    if (vector_width > 1 && vectorize) {
+        value_stack.push_back(
+            get_vector_constant<llvm::ConstantInt, int>(get_boolean_type(), value));
+    } else {
+        value_stack.push_back(
+            get_scalar_constant<llvm::ConstantInt, int>(get_boolean_type(), value));
+    }
+}
+
+void IRBuilder::create_fp_constant(const std::string& value) {
+    if (vector_width > 1 && vectorize) {
+        value_stack.push_back(
+            get_vector_constant<llvm::ConstantFP, const std::string&>(get_fp_type(), value));
+    } else {
+        value_stack.push_back(
+            get_scalar_constant<llvm::ConstantFP, const std::string&>(get_fp_type(), value));
+    }
+}
+
+llvm::Value* IRBuilder::create_global_string(const ast::String& node) {
+    return builder.CreateGlobalStringPtr(node.get_value());
+}
+
+void IRBuilder::create_i32_constant(int value) {
+    if (vector_width > 1 && vectorize) {
+        value_stack.push_back(get_vector_constant<llvm::ConstantInt, int>(get_i32_type(), value));
+    } else {
+        value_stack.push_back(get_scalar_constant<llvm::ConstantInt, int>(get_i32_type(), value));
+    }
+}
+
+template <typename C, typename V>
+llvm::Value* IRBuilder::get_scalar_constant(llvm::Type* type, V value) {
+    return C::get(type, value);
+}
+
+template <typename C, typename V>
+llvm::Value* IRBuilder::get_vector_constant(llvm::Type* type, V value) {
+    // Handle scalable vector constant differently.
+    if (scalable) {
+        llvm::Type* vector_type = llvm::ScalableVectorType::get(type, vector_width);
+
+        // First, create a scalable vector with 0th element set to the value.
+        llvm::Value* to_insert = get_scalar_constant<C, V>(type, value);
+        llvm::Value* zero = get_scalar_constant<llvm::ConstantInt, int>(get_i32_type(), 0);
+        llvm::Value* lhs =
+            builder.CreateInsertElement(llvm::UndefValue::get(vector_type), to_insert, zero);
+
+        // Then, use `shufflevector` with zeroinitializer to select only 0th element.
+        llvm::Value* select = llvm::Constant::getNullValue(vector_type);
+        return builder.CreateShuffleVector(lhs, llvm::UndefValue::get(vector_type), select);
+    }
+
+    // Otherwise, create a fixed vector constant.
+ ConstantVector constants; + for (unsigned i = 0; i < vector_width; ++i) { + const auto& element = C::get(type, value); + constants.push_back(element); + } + return llvm::ConstantVector::get(constants); +} + +/****************************************************************************************/ +/* LLVM function utilities */ +/****************************************************************************************/ + +void IRBuilder::allocate_function_arguments(llvm::Function* function, + const ast::CodegenVarWithTypeVector& nmodl_arguments) { + unsigned i = 0; + for (auto& arg: function->args()) { + std::string arg_name = nmodl_arguments[i++].get()->get_node_name(); + llvm::Type* arg_type = arg.getType(); + llvm::Value* alloca = create_alloca(arg_name, arg_type); + arg.setName(arg_name); + builder.CreateStore(&arg, alloca); + } +} + +std::string IRBuilder::get_current_function_name() { + return current_function->getName().str(); +} + +void IRBuilder::create_function_call(llvm::Function* callee, + ValueVector& arguments, + bool use_result) { + llvm::Value* call_instruction = builder.CreateCall(callee, arguments); + if (use_result) + value_stack.push_back(call_instruction); +} + +void IRBuilder::create_intrinsic(const std::string& name, + ValueVector& argument_values, + TypeVector& argument_types) { + // Process 'pow' call separately. + if (name == "pow") { + llvm::Value* pow_intrinsic = builder.CreateIntrinsic(llvm::Intrinsic::pow, + {argument_types.front()}, + argument_values); + value_stack.push_back(pow_intrinsic); + return; + } + + // Create other intrinsics. 
+    unsigned intrinsic_id = llvm::StringSwitch<llvm::Intrinsic::ID>(name)
+                                .Case("ceil", llvm::Intrinsic::ceil)
+                                .Case("cos", llvm::Intrinsic::cos)
+                                .Case("exp", llvm::Intrinsic::exp)
+                                .Case("fabs", llvm::Intrinsic::fabs)
+                                .Case("floor", llvm::Intrinsic::floor)
+                                .Case("log", llvm::Intrinsic::log)
+                                .Case("log10", llvm::Intrinsic::log10)
+                                .Case("sin", llvm::Intrinsic::sin)
+                                .Case("sqrt", llvm::Intrinsic::sqrt)
+                                .Default(llvm::Intrinsic::not_intrinsic);
+    if (intrinsic_id) {
+        llvm::Value* intrinsic =
+            builder.CreateIntrinsic(intrinsic_id, argument_types, argument_values);
+        value_stack.push_back(intrinsic);
+    } else {
+        throw std::runtime_error("Error: calls to " + name + " are not valid or not supported\n");
+    }
+}
+
+void IRBuilder::create_vscale_call(llvm::Module& module) {
+    llvm::Function* vscale_function =
+        llvm::Intrinsic::getDeclaration(&module, llvm::Intrinsic::vscale, get_i32_type());
+    llvm::Value* vscale = builder.CreateCall(vscale_function);
+    value_stack.push_back(vscale);
+}
+
+void IRBuilder::set_kernel_attributes() {
+    // By convention, the compute kernel does not free memory and does not throw exceptions.
+    current_function->setDoesNotFreeMemory();
+    current_function->setDoesNotThrow();
+
+    // We also want to specify that the pointers that instance struct holds, do not alias. In order
+    // to do that, we add a `noalias` attribute to the argument. As per Clang's specification:
+    // > The `noalias` attribute indicates that the only memory accesses inside function are loads
+    // > and stores from objects pointed to by its pointer-typed arguments, with arbitrary
+    // > offsets.
+    current_function->addParamAttr(0, llvm::Attribute::NoAlias);
+
+    // Finally, specify that the struct pointer does not capture and is read-only.
+ current_function->addParamAttr(0, llvm::Attribute::NoCapture); + current_function->addParamAttr(0, llvm::Attribute::ReadOnly); +} + +/****************************************************************************************/ +/* LLVM metadata utilities */ +/****************************************************************************************/ + +void IRBuilder::set_loop_metadata(llvm::BranchInst* branch) { + llvm::LLVMContext& context = builder.getContext(); + MetadataVector loop_metadata; + + // Add nullptr to reserve the first place for loop's metadata self-reference. + loop_metadata.push_back(nullptr); + + // If `vector_width` is 1, explicitly disable vectorization for benchmarking purposes. + if (vector_width == 1) { + llvm::MDString* name = llvm::MDString::get(context, "llvm.loop.vectorize.enable"); + llvm::Value* false_value = llvm::ConstantInt::get(get_boolean_type(), 0); + llvm::ValueAsMetadata* value = llvm::ValueAsMetadata::get(false_value); + loop_metadata.push_back(llvm::MDNode::get(context, {name, value})); + } + + // No metadata to add. + if (loop_metadata.size() <= 1) + return; + + // Add loop's metadata self-reference and attach it to the branch. + llvm::MDNode* metadata = llvm::MDNode::get(context, loop_metadata); + metadata->replaceOperandWith(0, metadata); + branch->setMetadata(llvm::LLVMContext::MD_loop, metadata); +} + +/****************************************************************************************/ +/* LLVM instruction utilities */ +/****************************************************************************************/ + +llvm::Value* IRBuilder::create_alloca(const std::string& name, llvm::Type* type) { + // If insertion point for `alloca` instructions is not set, then create the instruction in the + // entry block and set it to be the insertion point. + if (!alloca_ip) { + // Get the entry block and insert the `alloca` instruction there. 
+        llvm::BasicBlock* current_block = builder.GetInsertBlock();
+        llvm::BasicBlock& entry_block = current_block->getParent()->getEntryBlock();
+        builder.SetInsertPoint(&entry_block);
+        llvm::Value* alloca = builder.CreateAlloca(type, /*ArraySize=*/nullptr, name);
+
+        // Set the `alloca` instruction insertion point and restore the insertion point for the next
+        // set of instructions.
+        alloca_ip = llvm::cast<llvm::AllocaInst>(alloca);
+        builder.SetInsertPoint(current_block);
+        return alloca;
+    }
+
+    // Create `alloca` instruction.
+    llvm::BasicBlock* alloca_block = alloca_ip->getParent();
+    const auto& data_layout = alloca_block->getModule()->getDataLayout();
+    auto* alloca = new llvm::AllocaInst(type,
+                                        data_layout.getAllocaAddrSpace(),
+                                        /*ArraySize=*/nullptr,
+                                        data_layout.getPrefTypeAlign(type),
+                                        name);
+
+    // Insert `alloca` at the specified insertion point and reset it for the next instructions.
+    alloca_block->getInstList().insertAfter(alloca_ip->getIterator(), alloca);
+    alloca_ip = alloca;
+    return alloca;
+}
+
+void IRBuilder::create_array_alloca(const std::string& name,
+                                    llvm::Type* element_type,
+                                    int num_elements) {
+    llvm::Type* array_type = llvm::ArrayType::get(element_type, num_elements);
+    create_alloca(name, array_type);
+}
+
+void IRBuilder::create_binary_op(llvm::Value* lhs, llvm::Value* rhs, ast::BinaryOp op) {
+    // Check that both lhs and rhs have the same types.
+    if (lhs->getType() != rhs->getType())
+        throw std::runtime_error(
+            "Error: lhs and rhs of the binary operator have different types\n");
+
+    llvm::Value* result;
+    switch (op) {
+#define DISPATCH(binary_op, fp_instruction, integer_instruction) \
+    case binary_op: \
+        if (lhs->getType()->isIntOrIntVectorTy()) \
+            result = integer_instruction(lhs, rhs); \
+        else \
+            result = fp_instruction(lhs, rhs); \
+        break;
+
+        // Arithmetic instructions.
+ DISPATCH(ast::BinaryOp::BOP_ADDITION, builder.CreateFAdd, builder.CreateAdd); + DISPATCH(ast::BinaryOp::BOP_DIVISION, builder.CreateFDiv, builder.CreateSDiv); + DISPATCH(ast::BinaryOp::BOP_MULTIPLICATION, builder.CreateFMul, builder.CreateMul); + DISPATCH(ast::BinaryOp::BOP_SUBTRACTION, builder.CreateFSub, builder.CreateSub); + + // Comparison instructions. + DISPATCH(ast::BinaryOp::BOP_EXACT_EQUAL, builder.CreateFCmpOEQ, builder.CreateICmpEQ); + DISPATCH(ast::BinaryOp::BOP_GREATER, builder.CreateFCmpOGT, builder.CreateICmpSGT); + DISPATCH(ast::BinaryOp::BOP_GREATER_EQUAL, builder.CreateFCmpOGE, builder.CreateICmpSGE); + DISPATCH(ast::BinaryOp::BOP_LESS, builder.CreateFCmpOLT, builder.CreateICmpSLT); + DISPATCH(ast::BinaryOp::BOP_LESS_EQUAL, builder.CreateFCmpOLE, builder.CreateICmpSLE); + DISPATCH(ast::BinaryOp::BOP_NOT_EQUAL, builder.CreateFCmpONE, builder.CreateICmpNE); + +#undef DISPATCH + + // Separately replace ^ with the `pow` intrinsic. + case ast::BinaryOp::BOP_POWER: + result = builder.CreateIntrinsic(llvm::Intrinsic::pow, {lhs->getType()}, {lhs, rhs}); + break; + + // Logical instructions. + case ast::BinaryOp::BOP_AND: + result = builder.CreateAnd(lhs, rhs); + break; + case ast::BinaryOp::BOP_OR: + result = builder.CreateOr(lhs, rhs); + break; + + default: + throw std::runtime_error("Error: unsupported binary operator\n"); + } + value_stack.push_back(result); +} + +llvm::Value* IRBuilder::create_bitcast(llvm::Value* value, llvm::Type* dst_type) { + return builder.CreateBitCast(value, dst_type); +} + +llvm::Value* IRBuilder::create_inbounds_gep(const std::string& var_name, llvm::Value* index) { + llvm::Value* variable_ptr = lookup_value(var_name); + + // Since we index through the pointer, we need an extra 0 index in the indices list for GEP. 
+ ValueVector indices{llvm::ConstantInt::get(get_i64_type(), 0), index}; + return builder.CreateInBoundsGEP(variable_ptr, indices); +} + +llvm::Value* IRBuilder::create_inbounds_gep(llvm::Value* variable, llvm::Value* index) { + return builder.CreateInBoundsGEP(variable, {index}); +} + +llvm::Value* IRBuilder::create_index(llvm::Value* value) { + // Check if index is a double. While it is possible to use casting from double to integer + // values, we choose not to support these cases. + llvm::Type* value_type = value->getType(); + if (!value_type->isIntOrIntVectorTy()) + throw std::runtime_error("Error: only integer indexing is supported\n"); + + // Conventionally, in LLVM array indices are 64 bit. + llvm::Type* i64_type = get_i64_type(); + if (auto index_type = llvm::dyn_cast(value_type)) { + if (index_type->getBitWidth() == i64_type->getIntegerBitWidth()) + return value; + return builder.CreateSExtOrTrunc(value, i64_type); + } + + const auto& vector_type = llvm::cast(value_type); + const auto& element_type = llvm::cast(vector_type->getElementType()); + if (element_type->getBitWidth() == i64_type->getIntegerBitWidth()) + return value; + return builder.CreateSExtOrTrunc(value, get_vector_type(i64_type, vector_width)); +} + +llvm::Value* IRBuilder::create_load(const std::string& name, bool masked) { + llvm::Value* ptr = lookup_value(name); + + // Check if the generated IR is vectorized and masked. + if (masked) { + return builder.CreateMaskedLoad(ptr, llvm::Align(), mask); + } + llvm::Type* loaded_type = ptr->getType()->getPointerElementType(); + llvm::Value* loaded = builder.CreateLoad(loaded_type, ptr); + value_stack.push_back(loaded); + return loaded; +} + +llvm::Value* IRBuilder::create_load(llvm::Value* ptr, bool masked) { + // Check if the generated IR is vectorized and masked. 
+ if (masked) { + return builder.CreateMaskedLoad(ptr, llvm::Align(), mask); + } + llvm::Type* loaded_type = ptr->getType()->getPointerElementType(); + llvm::Value* loaded = builder.CreateLoad(loaded_type, ptr); + value_stack.push_back(loaded); + return loaded; +} + +llvm::Value* IRBuilder::create_load_from_array(const std::string& name, llvm::Value* index) { + llvm::Value* element_ptr = create_inbounds_gep(name, index); + return create_load(element_ptr); +} + +void IRBuilder::create_store(const std::string& name, llvm::Value* value, bool masked) { + llvm::Value* ptr = lookup_value(name); + + // Check if the generated IR is vectorized and masked. + if (masked) { + builder.CreateMaskedStore(value, ptr, llvm::Align(), mask); + return; + } + builder.CreateStore(value, ptr); +} + +void IRBuilder::create_store(llvm::Value* ptr, llvm::Value* value, bool masked) { + // Check if the generated IR is vectorized and masked. + if (masked) { + builder.CreateMaskedStore(value, ptr, llvm::Align(), mask); + return; + } + builder.CreateStore(value, ptr); +} + +void IRBuilder::create_store_to_array(const std::string& name, + llvm::Value* index, + llvm::Value* value) { + llvm::Value* element_ptr = create_inbounds_gep(name, index); + create_store(element_ptr, value); +} + +void IRBuilder::create_return(llvm::Value* return_value) { + if (return_value) + builder.CreateRet(return_value); + else + builder.CreateRetVoid(); +} + +void IRBuilder::create_scalar_or_vector_alloca(const std::string& name, + llvm::Type* element_or_scalar_type) { + // Even if generating vectorised code, some variables still need to be scalar. Particularly, the + // induction variable "id" and remainder loop variables (that start with "epilogue" prefix). 
+ llvm::Type* type; + if (vector_width > 1 && vectorize && name != kernel_id && name.rfind("epilogue", 0)) { + type = get_vector_type(element_or_scalar_type, vector_width); + } else { + type = element_or_scalar_type; + } + create_alloca(name, type); +} + +void IRBuilder::create_unary_op(llvm::Value* value, ast::UnaryOp op) { + if (op == ast::UOP_NEGATION) { + value_stack.push_back(builder.CreateFNeg(value)); + } else if (op == ast::UOP_NOT) { + value_stack.push_back(builder.CreateNot(value)); + } else { + throw std::runtime_error("Error: unsupported unary operator\n"); + } +} + +llvm::Value* IRBuilder::get_struct_member_ptr(llvm::Value* struct_variable, int member_index) { + ValueVector indices; + indices.push_back(llvm::ConstantInt::get(get_i32_type(), 0)); + indices.push_back(llvm::ConstantInt::get(get_i32_type(), member_index)); + return builder.CreateInBoundsGEP(struct_variable, indices); +} + +void IRBuilder::invert_mask() { + if (!mask) + throw std::runtime_error("Error: mask is not set\n"); + + // Create the vector with all `true` values. + create_boolean_constant(1); + llvm::Value* one = pop_last_value(); + + mask = builder.CreateXor(mask, one); +} + +llvm::Value* IRBuilder::load_to_or_store_from_array(const std::string& id_name, + llvm::Value* id_value, + llvm::Value* array, + llvm::Value* maybe_value_to_store) { + // First, calculate the address of the element in the array. + llvm::Value* element_ptr = create_inbounds_gep(array, id_value); + + // Find out if the vector code is generated. + bool generating_vector_ir = vector_width > 1 && vectorize; + + // If the vector code is generated, we need to distinguish between two cases. If the array is + // indexed indirectly (i.e. not by an induction variable `kernel_id`), create a gather + // instruction. + if (id_name != kernel_id && generating_vector_ir) { + return maybe_value_to_store ? 
builder.CreateMaskedScatter(maybe_value_to_store, + element_ptr, + llvm::Align(), + mask) + : builder.CreateMaskedGather(element_ptr, llvm::Align(), mask); + } + + llvm::Value* ptr; + if (generating_vector_ir) { + // If direct indexing is used during the vectorization, we simply bitcast the scalar pointer + // to a vector pointer + llvm::Type* vector_type = + llvm::PointerType::get(get_vector_type(element_ptr->getType()->getPointerElementType(), + vector_width), + /*AddressSpace=*/0); + ptr = builder.CreateBitCast(element_ptr, vector_type); + } else { + // Otherwise, scalar code is generated and hence return the element pointer. + ptr = element_ptr; + } + + if (maybe_value_to_store) { + create_store(ptr, maybe_value_to_store, /*masked=*/mask && generating_vector_ir); + return nullptr; + } else { + return create_load(ptr, /*masked=*/mask && generating_vector_ir); + } +} + +void IRBuilder::maybe_replicate_value(llvm::Value* value) { + // If the value should not be vectorised, or it is already a vector, add it to the stack. + if (!vectorize || vector_width == 1 || value->getType()->isVectorTy()) { + value_stack.push_back(value); + } else { + // Otherwise, we generate vectorized code inside the loop, so replicate the value to form a + // vector. 
+ llvm::Value* vector_value = builder.CreateVectorSplat(vector_width, value); + value_stack.push_back(vector_value); + } +} + + +/****************************************************************************************/ +/* LLVM block utilities */ +/****************************************************************************************/ + +llvm::BasicBlock* IRBuilder::create_block_and_set_insertion_point(llvm::Function* function, + llvm::BasicBlock* insert_before, + std::string name) { + llvm::BasicBlock* block = + llvm::BasicBlock::Create(builder.getContext(), name, function, insert_before); + builder.SetInsertPoint(block); + return block; +} + +void IRBuilder::create_br(llvm::BasicBlock* block) { + builder.CreateBr(block); +} + +void IRBuilder::create_br_and_set_insertion_point(llvm::BasicBlock* block) { + builder.CreateBr(block); + builder.SetInsertPoint(block); +} + +llvm::BranchInst* IRBuilder::create_cond_br(llvm::Value* condition, + llvm::BasicBlock* true_block, + llvm::BasicBlock* false_block) { + return builder.CreateCondBr(condition, true_block, false_block); +} + +llvm::BasicBlock* IRBuilder::get_current_block() { + return builder.GetInsertBlock(); +} + +void IRBuilder::set_insertion_point(llvm::BasicBlock* block) { + builder.SetInsertPoint(block); +} + +} // namespace codegen +} // namespace nmodl diff --git a/src/codegen/llvm/llvm_ir_builder.hpp b/src/codegen/llvm/llvm_ir_builder.hpp new file mode 100644 index 0000000000..7ba58f93a2 --- /dev/null +++ b/src/codegen/llvm/llvm_ir_builder.hpp @@ -0,0 +1,333 @@ +/************************************************************************* + * Copyright (C) 2018-2020 Blue Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. 
+ *************************************************************************/ + +#pragma once + +#include + +#include "codegen/llvm/codegen_llvm_helper_visitor.hpp" +#include "symtab/symbol_table.hpp" + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" + +namespace nmodl { +namespace codegen { + +/// Floating point bit widths. +static constexpr const unsigned single_precision = 32; +static constexpr const unsigned double_precision = 64; + +/// Some typedefs. +using ConstantVector = std::vector; +using MetadataVector = std::vector; +using TypeVector = std::vector; +using ValueVector = std::vector; + +/** + * \class IRBuilder + * \brief A helper class to generate LLVM IR for NMODL AST. + */ +class IRBuilder { + private: + /// Underlying LLVM IR builder. + llvm::IRBuilder<> builder; + + /// Stack to hold visited and processed values. + ValueVector value_stack; + + /// Pointer to the current function for which the code is generated. + llvm::Function* current_function; + + /// Symbol table of the NMODL AST. + symtab::SymbolTable* symbol_table; + + /// Insertion point for `alloca` instructions. + llvm::Instruction* alloca_ip; + + /// Flag to indicate that the generated IR should be vectorized. + bool vectorize; + + /// Precision of the floating-point numbers (32 or 64 bit). + unsigned fp_precision; + + /// The vector width used for the vectorized code. + unsigned vector_width; + + /// Use scalable vector types. + bool scalable; + + /// Masked value used to predicate vector instructions. + llvm::Value* mask; + + /// The name of induction variable used in kernel loops. + std::string kernel_id; + + /// Fast math flags for floating-point IR instructions. 
+ std::vector fast_math_flags; + + public: + IRBuilder(llvm::LLVMContext& context, + bool use_single_precision = false, + unsigned vector_width = 1, + std::vector fast_math_flags = {}, + bool scalable = false) + : builder(context) + , symbol_table(nullptr) + , current_function(nullptr) + , vectorize(false) + , alloca_ip(nullptr) + , scalable(scalable) + , fp_precision(use_single_precision ? single_precision : double_precision) + , vector_width(vector_width) + , mask(nullptr) + , kernel_id("") + , fast_math_flags(fast_math_flags) {} + + /// Transforms the fast math flags provided to the builder into LLVM's representation. + llvm::FastMathFlags transform_to_fmf(std::vector& flags) { + static const std::map set_flag = { + {"nnan", &llvm::FastMathFlags::setNoNaNs}, + {"ninf", &llvm::FastMathFlags::setNoInfs}, + {"nsz", &llvm::FastMathFlags::setNoSignedZeros}, + {"contract", &llvm::FastMathFlags::setAllowContract}, + {"afn", &llvm::FastMathFlags::setApproxFunc}, + {"reassoc", &llvm::FastMathFlags::setAllowReassoc}, + {"fast", &llvm::FastMathFlags::setFast}}; + llvm::FastMathFlags fmf; + for (const auto& flag: flags) { + (fmf.*(set_flag.at(flag)))(true); + } + return fmf; + } + + /// Initializes the builder with the symbol table and the kernel induction variable id. + void initialize(symtab::SymbolTable& symbol_table, std::string& kernel_id) { + if (!fast_math_flags.empty()) + builder.setFastMathFlags(transform_to_fmf(fast_math_flags)); + this->symbol_table = &symbol_table; + this->kernel_id = kernel_id; + } + + /// Explicitly sets the builder to produce scalar IR. + void generate_scalar_ir() { + vectorize = false; + } + + /// Indicates whether the builder generates vectorized IR. + bool vectorizing() { + return vectorize; + } + + /// Explicitly sets the builder to produce vectorized IR. + void generate_vector_ir() { + vectorize = true; + } + + /// Sets the current function for which LLVM IR is generated. 
+ void set_function(llvm::Function* function) { + current_function = function; + } + + /// Clears the stack of the values and unsets the current function. + void clear_function() { + value_stack.clear(); + current_function = nullptr; + alloca_ip = nullptr; + } + + /// Sets the value to be the mask for vector code generation. + void set_mask(llvm::Value* value) { + mask = value; + } + + /// Clears the mask for vector code generation. + void clear_mask() { + mask = nullptr; + } + + /// Indicates whether the vectorized IR is predicated. + bool generates_predicated_ir() { + return vectorize && mask; + } + + /// Generates LLVM IR to allocate the arguments of the function on the stack. + void allocate_function_arguments(llvm::Function* function, + const ast::CodegenVarWithTypeVector& nmodl_arguments); + + llvm::Value* create_alloca(const std::string& name, llvm::Type* type); + + /// Generates IR for allocating an array. + void create_array_alloca(const std::string& name, llvm::Type* element_type, int num_elements); + + /// Generates LLVM IR for the given binary operator. + void create_binary_op(llvm::Value* lhs, llvm::Value* rhs, ast::BinaryOp op); + + /// Generates LLVM IR for the bitcast instruction. + llvm::Value* create_bitcast(llvm::Value* value, llvm::Type* dst_type); + + /// Create a basic block and set the builder's insertion point to it. + llvm::BasicBlock* create_block_and_set_insertion_point( + llvm::Function* function, + llvm::BasicBlock* insert_before = nullptr, + std::string name = ""); + + /// Generates LLVM IR for unconditional branch. + void create_br(llvm::BasicBlock* block); + + /// Generates LLVM IR for unconditional branch and sets the insertion point to this block. + void create_br_and_set_insertion_point(llvm::BasicBlock* block); + + /// Generates LLVM IR for conditional branch. 
+ llvm::BranchInst* create_cond_br(llvm::Value* condition, + llvm::BasicBlock* true_block, + llvm::BasicBlock* false_block); + + /// Generates LLVM IR for the boolean constant. + void create_boolean_constant(int value); + + /// Generates LLVM IR for the floating-point constant. + void create_fp_constant(const std::string& value); + + /// Generates LLVM IR for a call to the function. + void create_function_call(llvm::Function* callee, + ValueVector& arguments, + bool use_result = true); + + /// Generates LLVM IR for the string value. + llvm::Value* create_global_string(const ast::String& node); + + /// Generates LLVM IR to transform the value into an index by possibly sign-extending it. + llvm::Value* create_index(llvm::Value* value); + + /// Generates an intrinsic that corresponds to the given name. + void create_intrinsic(const std::string& name, + ValueVector& argument_values, + TypeVector& argument_types); + + /// Generates LLVM IR for the integer constant. + void create_i32_constant(int value); + + /// Generates LLVM IR to load the value specified by its name and returns it. + llvm::Value* create_load(const std::string& name, bool masked = false); + + /// Generates LLVM IR to load the value from the pointer and returns it. + llvm::Value* create_load(llvm::Value* ptr, bool masked = false); + + /// Generates LLVM IR to load the element at the specified index from the given array name and + /// returns it. + llvm::Value* create_load_from_array(const std::string& name, llvm::Value* index); + + /// Generates LLVM IR to store the value to the location specified by the name. + void create_store(const std::string& name, llvm::Value* value, bool masked = false); + + /// Generates LLVM IR to store the value to the location specified by the pointer. + void create_store(llvm::Value* ptr, llvm::Value* value, bool masked = false); + + /// Generates LLVM IR to store the value to the array element, where array is specified by the + /// name. 
+ void create_store_to_array(const std::string& name, llvm::Value* index, llvm::Value* value); + + /// Generates LLVM IR return instructions. + void create_return(llvm::Value* return_value = nullptr); + + /// Generates IR for allocating a scalar or vector variable. + void create_scalar_or_vector_alloca(const std::string& name, + llvm::Type* element_or_scalar_type); + + /// Creates a call to llvm.vscale.i32(). + void create_vscale_call(llvm::Module& module); + + /// Generates LLVM IR for the given unary operator. + void create_unary_op(llvm::Value* value, ast::UnaryOp op); + + /// Creates a boolean (1-bit integer) type. + llvm::Type* get_boolean_type(); + + /// Returns current basic block. + llvm::BasicBlock* get_current_block(); + + /// Returns the name of the function for which LLVM IR is generated. + std::string get_current_function_name(); + + /// Creates a pointer to 8-bit integer type. + llvm::Type* get_i8_ptr_type(); + + /// Creates a 32-bit integer type. + llvm::Type* get_i32_type(); + + /// Creates a pointer to 32-bit integer type. + llvm::Type* get_i32_ptr_type(); + + /// Creates a 64-bit integer type. + llvm::Type* get_i64_type(); + + /// Creates a floating-point type. + llvm::Type* get_fp_type(); + + /// Creates a pointer to floating-point type. + llvm::Type* get_fp_ptr_type(); + + /// Creates a void type. + llvm::Type* get_void_type(); + + /// Generates LLVM IR to get the address of the struct's member at given index. Returns the + /// calculated value. + llvm::Value* get_struct_member_ptr(llvm::Value* struct_variable, int member_index); + + /// Creates a pointer to struct type with the given name and given members. + llvm::Type* get_struct_ptr_type(const std::string& struct_type_name, TypeVector& member_types); + + /// Inverts the mask for vector code generation by xoring it. + void invert_mask(); + + /// Generates IR that loads the elements of the array even during vectorization. 
If the value is + /// specified, then it is stored to the array at the given index. + llvm::Value* load_to_or_store_from_array(const std::string& id_name, + llvm::Value* id_value, + llvm::Value* array, + llvm::Value* maybe_value_to_store = nullptr); + + /// Lookups the value by its name in the current function's symbol table. + llvm::Value* lookup_value(const std::string& value_name); + + /// Generates IR to replicate the value if vectorizing the code. + void maybe_replicate_value(llvm::Value* value); + + /// Sets builder's insertion point to the given block. + void set_insertion_point(llvm::BasicBlock* block); + + /// Sets the necessary attributes for the kernel and its arguments. + void set_kernel_attributes(); + + /// Sets the loop metadata for the given branch from the loop. + void set_loop_metadata(llvm::BranchInst* branch); + + /// Pops the last visited value from the value stack. + llvm::Value* pop_last_value(); + + private: + /// Generates an inbounds GEP instruction for the given name and returns calculated address. + llvm::Value* create_inbounds_gep(const std::string& variable_name, llvm::Value* index); + + /// Generates an inbounds GEP instruction for the given value and returns calculated address. + llvm::Value* create_inbounds_gep(llvm::Value* variable, llvm::Value* index); + + /// Returns a scalar constant of the provided type. + template + llvm::Value* get_scalar_constant(llvm::Type* type, V value); + + /// Returns a vector constant of the provided type. + template + llvm::Value* get_vector_constant(llvm::Type* type, V value); + + /// Creates a Fixed or Scalable vector type. 
+ llvm::Type* get_vector_type(llvm::Type* element_type, unsigned min_num_elements); +}; + +} // namespace codegen +} // namespace nmodl diff --git a/src/codegen/llvm/main.cpp b/src/codegen/llvm/main.cpp new file mode 100644 index 0000000000..2f4e1f653d --- /dev/null +++ b/src/codegen/llvm/main.cpp @@ -0,0 +1,75 @@ +/************************************************************************* + * Copyright (C) 2018-2021 Blue Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. + *************************************************************************/ + +#include + +#include "ast/program.hpp" +#include "codegen/llvm/codegen_llvm_visitor.hpp" +#include "parser/nmodl_driver.hpp" +#include "test/benchmark/jit_driver.hpp" +#include "utils/logger.hpp" +#include "visitors/symtab_visitor.hpp" + +#include "llvm/Support/TargetRegistry.h" +#include "llvm/Support/TargetSelect.h" + +using namespace nmodl; +using namespace runner; + +int main(int argc, const char* argv[]) { + CLI::App app{ + "NMODL LLVM Runner : Executes functions from a MOD file via LLVM IR code generation"}; + + // Currently, only a single MOD file is supported, as well as an entry point with a double + // return type. While returning a double value is a general case in NMODL, it will be nice to + // have a more generic functionality. \todo: Add support for different return types (int, void). 
+ + std::string filename; + std::string entry_point_name = "main"; + + app.add_option("-f,--file,file", filename, "A single MOD file source") + ->required() + ->check(CLI::ExistingFile); + app.add_option("-e,--entry-point,entry-point", + entry_point_name, + "An entry point function from the MOD file"); + + CLI11_PARSE(app, argc, argv); + + logger->info("Parsing MOD file to AST"); + parser::NmodlDriver driver; + const auto& ast = driver.parse_file(filename); + + logger->info("Running Symtab Visitor"); + visitor::SymtabVisitor().visit_program(*ast); + + logger->info("Running LLVM Visitor"); + codegen::CodegenLLVMVisitor llvm_visitor(filename, /*output_dir=*/".", /*opt_passes=*/false); + llvm_visitor.visit_program(*ast); + std::unique_ptr module = llvm_visitor.get_module(); + + // Check if the entry-point is valid for JIT driver to execute. + auto func = module->getFunction(entry_point_name); + if (!func) + throw std::runtime_error("Error: entry-point is not found\n"); + + if (func->getNumOperands() != 0) + throw std::runtime_error("Error: entry-point functions with arguments are not supported\n"); + + if (!func->getReturnType()->isDoubleTy()) + throw std::runtime_error( + "Error: entry-point functions with non-double return type are not supported\n"); + + TestRunner runner(std::move(module)); + runner.initialize_driver(); + + // Since only double type is supported, provide explicit double type to the running function. 
+ auto r = runner.run_without_arguments(entry_point_name); + fprintf(stderr, "Result: %f\n", r); + + return 0; +} diff --git a/src/language/code_generator.cmake b/src/language/code_generator.cmake index 400b969a23..46dc01ea9f 100644 --- a/src/language/code_generator.cmake +++ b/src/language/code_generator.cmake @@ -65,6 +65,16 @@ set(AST_GENERATED_SOURCES ${PROJECT_BINARY_DIR}/src/ast/block_comment.hpp ${PROJECT_BINARY_DIR}/src/ast/boolean.hpp ${PROJECT_BINARY_DIR}/src/ast/breakpoint_block.hpp + ${PROJECT_BINARY_DIR}/src/ast/codegen_atomic_statement.hpp + ${PROJECT_BINARY_DIR}/src/ast/codegen_for_statement.hpp + ${PROJECT_BINARY_DIR}/src/ast/codegen_function.hpp + ${PROJECT_BINARY_DIR}/src/ast/codegen_instance_var.hpp + ${PROJECT_BINARY_DIR}/src/ast/codegen_return_statement.hpp + ${PROJECT_BINARY_DIR}/src/ast/codegen_struct.hpp + ${PROJECT_BINARY_DIR}/src/ast/codegen_var.hpp + ${PROJECT_BINARY_DIR}/src/ast/codegen_var_list_statement.hpp + ${PROJECT_BINARY_DIR}/src/ast/codegen_var_type.hpp + ${PROJECT_BINARY_DIR}/src/ast/codegen_var_with_type.hpp ${PROJECT_BINARY_DIR}/src/ast/compartment.hpp ${PROJECT_BINARY_DIR}/src/ast/conductance_hint.hpp ${PROJECT_BINARY_DIR}/src/ast/conserve.hpp @@ -108,6 +118,7 @@ set(AST_GENERATED_SOURCES ${PROJECT_BINARY_DIR}/src/ast/independent_definition.hpp ${PROJECT_BINARY_DIR}/src/ast/indexed_name.hpp ${PROJECT_BINARY_DIR}/src/ast/initial_block.hpp + ${PROJECT_BINARY_DIR}/src/ast/instance_struct.hpp ${PROJECT_BINARY_DIR}/src/ast/integer.hpp ${PROJECT_BINARY_DIR}/src/ast/kinetic_block.hpp ${PROJECT_BINARY_DIR}/src/ast/lag_statement.hpp @@ -185,6 +196,7 @@ set(AST_GENERATED_SOURCES ${PROJECT_BINARY_DIR}/src/ast/valence.hpp ${PROJECT_BINARY_DIR}/src/ast/var_name.hpp ${PROJECT_BINARY_DIR}/src/ast/verbatim.hpp + ${PROJECT_BINARY_DIR}/src/ast/void.hpp ${PROJECT_BINARY_DIR}/src/ast/watch.hpp ${PROJECT_BINARY_DIR}/src/ast/watch_statement.hpp ${PROJECT_BINARY_DIR}/src/ast/while_statement.hpp diff --git a/src/language/codegen.yaml 
b/src/language/codegen.yaml index 63762a9be0..3dc802c982 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -28,7 +28,51 @@ - Expression: children: - Number: + - Void: + nmodl: "VOID" + brief: "Represent void type in code generation" - Identifier: + children: + - CodegenVarType: + brief: "Represent type of the variable" + members: + - type: + brief: "Type of the ast node" + type: AstNodeType + - CodegenVar: + brief: "Represent variable used for code generation" + members: + - pointer: + brief: "If variable is pointer type" + type: int + - name: + brief: "Name of the variable" + type: Identifier + node_name: true + - CodegenVarWithType: + brief: "Represent variable used for code generation" + members: + - type: + brief: "Type of the variable" + type: CodegenVarType + suffix: {value: " "} + - is_pointer: + brief: "If variable is pointer type" + type: int + - name: + brief: "Name of the variable" + type: Identifier + node_name: true + - CodegenInstanceVar: + brief: "Represent instance variable" + members: + - instance_var: + brief: "Instance variable" + type: Name + suffix: {value: "->"} + - member_var: + brief: "Member variable within instance" + type: Identifier - Block: children: - NrnStateBlock: @@ -89,7 +133,41 @@ type: StatementBlock - finalize_block: brief: "Statement block to be executed after calling linear solver" - type: StatementBlock + type: StatementBlock + - CodegenFunction: + brief: "Function generated from FUNCTION or PROCEDURE block" + members: + - return_type: + brief: "Return type of the function" + type: CodegenVarType + suffix: {value: " "} + - name: + brief: "Name of the function" + type: Name + node_name: true + - arguments: + brief: "Vector of the parameters to the function" + type: CodegenVarWithType + vector: true + prefix: {value: "(", force: true} + suffix: {value: ")", force: true} + separator: ", " + - statement_block: + brief: "Body of the function" + type: StatementBlock + getter: {override: true} + - 
InstanceStruct: + nmodl: "INSTANCE_STRUCT " + members: + - codegen_vars: + brief: "Vector of CodegenVars" + type: CodegenVarWithType + vector: true + add: true + separator: "\\n " + prefix: {value: "{\\n ", force: true} + suffix: {value: "\\n}", force: true} + brief: "LLVM IR Struct that holds the mechanism instance's variables" - WrappedExpression: brief: "Wrap any other expression type" members: @@ -110,4 +188,92 @@ - node_to_solve: brief: "Block to be solved (callback node or solution node itself)" type: Expression + - CodegenStruct: + brief: "Represent a struct or class for code generation" + members: + - variable_statements: + brief: "member variables of the class/struct" + type: CodegenVarListStatement + vector: true + - functions: + brief: "member functions of the class/struct" + type: CodegenFunction + vector: true - Statement: + children: + - CodegenForStatement: + brief: "Represent for loop used for code generation" + nmodl: "for(" + members: + - initialization: + brief: "initialization expression for the loop" + type: Expression + optional: true + - condition: + brief: "condition expression for the loop" + type: Expression + optional: true + prefix: {value: "; "} + suffix: {value: "; "} + - increment: + brief: "increment or decrement expression for the loop" + type: Expression + optional: true + suffix: {value: ") "} + - statement_block: + brief: "body of the loop" + type: StatementBlock + getter: {override: true} + - CodegenReturnStatement: + brief: "Represent return statement for code generation" + nmodl: "return " + members: + - statement: + brief: "return statement" + type: Expression + optional: true + - CodegenVarListStatement: + brief: "Represent list of variables used for code generation" + members: + - var_type: + brief: "Type of the variables" + type: CodegenVarType + suffix: {value: " "} + - variables: + brief: "List of the variables to define" + type: CodegenVar + vector: true + separator: ", " + add: true + - CodegenAtomicStatement: + brief: 
"Represent atomic operation" + description: | + During code generation certain operations like ion updates, vec_rhs or + vec_d updates (for synapse) needs to be atomic operations if executed by + multiple threads. In case of SIMD, there are conflicts for `vec_d` and + `vec_rhs` for synapse types. Here are some statements from C++ backend: + + \code{.cpp} + vec_d[node_id] += g + vec_rhs[node_id] -= rhs + ion_ina[indexes[some_index]] += ina[id] + ion_cai[indexes[some_index]] = cai[id] // cai here is state variable + \endcode + + These operations will be represented by atomic statement node type: + * `vec_d[node_id]` : lhs + * `+=` : atomic_op + * `g` : rhs + + members: + - lhs: + brief: "Variable to be updated atomically" + type: Identifier + - atomic_op: + brief: "Operator" + type: BinaryOperator + prefix: {value: " "} + suffix: {value: " "} + - rhs: + brief: "Expression for atomic operation" + type: Expression diff --git a/src/language/nmodl.yaml b/src/language/nmodl.yaml index 0724f81e29..54da340b7b 100644 --- a/src/language/nmodl.yaml +++ b/src/language/nmodl.yaml @@ -1374,7 +1374,7 @@ type: Double - Statement: - brief: "TODO" + brief: "Base class to represent a statement in the NMODL" children: - UnitState: brief: "TODO" diff --git a/src/language/node_info.py b/src/language/node_info.py index f4fb599347..57833af229 100644 --- a/src/language/node_info.py +++ b/src/language/node_info.py @@ -29,6 +29,7 @@ "QueueType", "BAType", "UnitStateType", + "AstNodeType", } BASE_TYPES = {"std::string" } | INTEGRAL_TYPES @@ -167,6 +168,9 @@ STATEMENT_BLOCK_NODE = "StatementBlock" STRING_NODE = "String" UNIT_BLOCK = "UnitBlock" +AST_NODETYPE_NODE= "AstNodeType" +CODEGEN_VAR_TYPE_NODE = "CodegenVarType" +CODEGEN_VAR_WITH_TYPE_NODE = "CodegenVarWithType" # name of variable in prime node which represent order of derivative ORDER_VAR_NAME = "order" diff --git a/src/language/nodes.py b/src/language/nodes.py index a539b55647..cf7aa1f30b 100644 --- a/src/language/nodes.py +++ 
b/src/language/nodes.py @@ -147,6 +147,18 @@ def is_boolean_node(self): def is_name_node(self): return self.class_name == node_info.NAME_NODE + @property + def is_ast_nodetype_node(self): + return self.class_name == node_info.AST_NODETYPE_NODE + + @property + def is_codegen_var_type_node(self): + return self.class_name == node_info.CODEGEN_VAR_TYPE_NODE + + @property + def is_codegen_var_with_type_node(self): + return self.class_name == node_info.CODEGEN_VAR_WITH_TYPE_NODE + @property def is_enum_node(self): data_type = node_info.DATA_TYPES[self.class_name] diff --git a/src/language/templates/ast/ast_decl.hpp b/src/language/templates/ast/ast_decl.hpp index cbca65e692..546c7dcb40 100644 --- a/src/language/templates/ast/ast_decl.hpp +++ b/src/language/templates/ast/ast_decl.hpp @@ -12,7 +12,9 @@ #pragma once #include +#include #include +#include /// \file /// \brief Auto generated AST node types and aliases declaration @@ -50,6 +52,15 @@ enum class AstNodeType { /** @} */ // end of ast_type +static inline std::string to_string(AstNodeType type) { + {% for node in nodes %} + if(type == AstNodeType::{{ node.class_name|snake_case|upper }}) { + return "{{ node.class_name|snake_case|upper }}"; + } + {% endfor %} + throw std::runtime_error("Unhandled type in to_string(AstNodeType type)!"); +} + /** * @defgroup ast_vec_type AST Vector Type Aliases * @ingroup ast diff --git a/src/language/templates/visitors/json_visitor.cpp b/src/language/templates/visitors/json_visitor.cpp index e96bcbf10c..2a0c6d68a9 100644 --- a/src/language/templates/visitors/json_visitor.cpp +++ b/src/language/templates/visitors/json_visitor.cpp @@ -22,33 +22,40 @@ using namespace ast; {% for node in nodes %} void JSONVisitor::visit_{{ node.class_name|snake_case }}(const {{ node.class_name }}& node) { {% if node.has_children() %} - printer->push_block(node.get_node_type_name()); - if (embed_nmodl) { - printer->add_block_property("nmodl", to_nmodl(node)); - } - node.visit_children(*this); - {% if 
node.is_data_type_node %} + printer->push_block(node.get_node_type_name()); + if (embed_nmodl) { + printer->add_block_property("nmodl", to_nmodl(node)); + } + node.visit_children(*this); + {% if node.is_data_type_node %} {% if node.is_integer_node %} - if(!node.get_macro()) { - std::stringstream ss; - ss << node.eval(); - printer->add_node(ss.str()); - } + if(!node.get_macro()) { + std::stringstream ss; + ss << node.eval(); + printer->add_node(ss.str()); + } {% else %} - std::stringstream ss; - ss << node.eval(); - printer->add_node(ss.str()); + std::stringstream ss; + ss << node.eval(); + printer->add_node(ss.str()); {% endif %} {% endif %} - printer->pop_block(); + + {% if node.is_codegen_var_type_node %} + printer->add_node(ast::to_string(node.get_type())); + {% endif %} + + printer->pop_block(); + {% if node.is_program_node %} - if (node.get_parent() == nullptr) { - flush(); - } + if (node.get_parent() == nullptr) { + flush(); + } {% endif %} + {% else %} - (void)node; - printer->add_node("{{ node.class_name }}"); + (void)node; + printer->add_node("{{ node.class_name }}"); {% endif %} } diff --git a/src/language/templates/visitors/nmodl_visitor.cpp b/src/language/templates/visitors/nmodl_visitor.cpp index a69c3b0b26..01b470e70d 100644 --- a/src/language/templates/visitors/nmodl_visitor.cpp +++ b/src/language/templates/visitors/nmodl_visitor.cpp @@ -115,7 +115,15 @@ void NmodlPrintVisitor::visit_{{ node.class_name|snake_case}}(const {{ node.clas {% endif %} {% for child in node.children %} {% call guard(child.force_prefix, child.force_suffix) -%} - {% if child.is_base_type_node %} + + {% if node.is_codegen_var_with_type_node and child.varname == "is_pointer" %} + if(node.get_{{ child.varname }}()) { + printer->add_element("*"); + } + {% elif child.is_base_type_node %} + {% if child.is_ast_nodetype_node %} + printer->add_element(ast::to_string(node.get_{{child.varname}}())); + {% endif %} {% else %} {% if child.optional or child.is_statement_block_node %} 
if(node.get_{{ child.varname }}()) { diff --git a/src/main.cpp b/src/main.cpp index 60e933f052..c700703739 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -17,6 +17,12 @@ #include "codegen/codegen_cuda_visitor.hpp" #include "codegen/codegen_ispc_visitor.hpp" #include "codegen/codegen_omp_visitor.hpp" + +#ifdef NMODL_LLVM_BACKEND +#include "codegen/llvm/codegen_llvm_visitor.hpp" +#include "test/benchmark/llvm_benchmark.hpp" +#endif + #include "config/config.h" #include "parser/nmodl_driver.hpp" #include "pybind/pyembed.hpp" @@ -82,6 +88,9 @@ int main(int argc, const char* argv[]) { /// true if cuda code to be generated bool cuda_backend(false); + /// true if llvm code to be generated + bool llvm_backend(false); + /// true if sympy should be used for solving ODEs analytically bool sympy_analytic(false); @@ -155,6 +164,53 @@ int main(int argc, const char* argv[]) { /// floating point data type std::string data_type("double"); +#ifdef NMODL_LLVM_BACKEND + /// generate llvm IR + bool llvm_ir(false); + + /// use single precision floating-point types + bool llvm_float_type(false); + + /// run llvm optimisation passes + bool llvm_ir_opt_passes(false); + + /// generate IR for scalable vector ISAs + bool llvm_scalable_vectors(false); + + /// llvm vector width + int llvm_vec_width = 1; + + /// vector library name + std::string vector_library("none"); + + /// disable debug information generation for the IR + bool disable_debug_information(false); + + /// fast math flags for LLVM backend + std::vector llvm_fast_math_flags; + + /// run llvm benchmark + bool run_llvm_benchmark(false); + + /// optimisation level for IR generation + int llvm_opt_level_ir = 0; + + /// optimisation level for machine code generation + int llvm_opt_level_codegen = 0; + + /// list of shared libraries to link against in JIT + std::vector shared_lib_paths; + + /// the size of the instance struct for the benchmark + int instance_size = 10000; + + /// the number of repeated experiments for the benchmarking 
+ int num_experiments = 100; + + /// specify the backend for LLVM IR to target + std::string backend = "default"; +#endif + app.get_formatter()->column_width(40); app.set_help_all_flag("-H,--help-all", "Print this help message including all sub-commands"); @@ -258,6 +314,59 @@ int main(int argc, const char* argv[]) { optimize_ionvar_copies_codegen, "Optimize copies of ion variables ({})"_format(optimize_ionvar_copies_codegen))->ignore_case(); +#ifdef NMODL_LLVM_BACKEND + + // LLVM IR code generation options. + auto llvm_opt = app.add_subcommand("llvm", "LLVM code generation option")->ignore_case(); + llvm_opt->add_flag("--ir", + llvm_ir, + "Generate LLVM IR ({})"_format(llvm_ir))->ignore_case(); + llvm_opt->add_flag("--disable-debug-info", + disable_debug_information, + "Disable debug information ({})"_format(disable_debug_information))->ignore_case(); + llvm_opt->add_flag("--opt", + llvm_ir_opt_passes, + "Run few common LLVM IR optimisation passes ({})"_format(llvm_ir_opt_passes))->ignore_case(); + llvm_opt->add_flag("--single-precision", + llvm_float_type, + "Use single precision floating-point types ({})"_format(llvm_float_type))->ignore_case(); + llvm_opt->add_flag("--scalable", + llvm_scalable_vectors, + "Generate scalable vector IR ({})"_format(llvm_scalable_vectors))->ignore_case(); + llvm_opt->add_option("--vector-width", + llvm_vec_width, + "LLVM explicit vectorisation width ({})"_format(llvm_vec_width))->ignore_case(); + llvm_opt->add_option("--veclib", + vector_library, + "Vector library for maths functions ({})"_format(vector_library))->check(CLI::IsMember({"Accelerate", "libsystem_m", "libmvec", "MASSV", "SLEEF", "SVML", "none"})); + llvm_opt->add_option("--fmf", + llvm_fast_math_flags, + "Fast math flags for floating-point optimizations (none)")->check(CLI::IsMember({"afn", "arcp", "contract", "ninf", "nnan", "nsz", "reassoc", "fast"})); + + // LLVM IR benchmark options. 
+ auto benchmark_opt = app.add_subcommand("benchmark", "LLVM benchmark option")->ignore_case(); + benchmark_opt->add_flag("--run", + run_llvm_benchmark, + "Run LLVM benchmark ({})"_format(run_llvm_benchmark))->ignore_case(); + benchmark_opt->add_option("--opt-level-ir", + llvm_opt_level_ir, + "LLVM IR optimisation level (O{})"_format(llvm_opt_level_ir))->ignore_case()->check(CLI::IsMember({"0", "1", "2", "3"})); + benchmark_opt->add_option("--opt-level-codegen", + llvm_opt_level_codegen, + "Machine code optimisation level (O{})"_format(llvm_opt_level_codegen))->ignore_case()->check(CLI::IsMember({"0", "1", "2", "3"})); + benchmark_opt->add_option("--libs", shared_lib_paths, "Shared libraries to link IR against") + ->ignore_case() + ->check(CLI::ExistingFile); + benchmark_opt->add_option("--instance-size", + instance_size, + "Instance struct size ({})"_format(instance_size))->ignore_case(); + benchmark_opt->add_option("--repeat", + num_experiments, + "Number of experiments for benchmarking ({})"_format(num_experiments))->ignore_case(); + benchmark_opt->add_option("--backend", + backend, + "Target's backend ({})"_format(backend))->ignore_case()->check(CLI::IsMember({"avx2", "default", "sse2"})); +#endif // clang-format on CLI11_PARSE(app, argc, argv); @@ -286,15 +395,24 @@ int main(int argc, const char* argv[]) { } }; + /// write ast to nmodl + const auto ast_to_json = [json_ast](ast::Program& ast, const std::string& filepath) { + if (json_ast) { + JSONVisitor(filepath).write(ast); + logger->info("AST to JSON transformation written to {}", filepath); + } + }; + for (const auto& file: mod_files) { logger->info("Processing {}", file); const auto modfile = utils::remove_extension(utils::base_name(file)); /// create file path for nmodl file - auto filepath = [scratch_dir, modfile](const std::string& suffix) { + auto filepath = [scratch_dir, modfile](const std::string& suffix, const std::string& ext) { static int count = 0; - return "{}/{}.{}.{}.mod"_format(scratch_dir, 
modfile, std::to_string(count++), suffix); + return "{}/{}.{}.{}.{}"_format( + scratch_dir, modfile, std::to_string(count++), suffix, ext); }; /// driver object creates lexer and parser, just call parser method @@ -320,7 +438,7 @@ int main(int argc, const char* argv[]) { { logger->info("Running CVode to cnexp visitor"); AfterCVodeToCnexpVisitor().visit_program(*ast); - ast_to_nmodl(*ast, filepath("after_cvode_to_cnexp")); + ast_to_nmodl(*ast, filepath("after_cvode_to_cnexp", "mod")); } /// Rename variables that match ISPC compiler double constants @@ -328,7 +446,7 @@ int main(int argc, const char* argv[]) { logger->info("Running ISPC variables rename visitor"); IspcRenameVisitor(ast).visit_program(*ast); SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("ispc_double_rename")); + ast_to_nmodl(*ast, filepath("ispc_double_rename", "mod")); } /// GLOBAL to RANGE rename visitor @@ -341,7 +459,7 @@ int main(int argc, const char* argv[]) { logger->info("Running GlobalToRange visitor"); GlobalToRangeVisitor(ast).visit_program(*ast); SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("global_to_range")); + ast_to_nmodl(*ast, filepath("global_to_range", "mod")); } /// LOCAL to ASSIGNED visitor @@ -350,7 +468,7 @@ int main(int argc, const char* argv[]) { PerfVisitor().visit_program(*ast); LocalToAssignedVisitor().visit_program(*ast); SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("local_to_assigned")); + ast_to_nmodl(*ast, filepath("local_to_assigned", "mod")); } { @@ -376,31 +494,26 @@ int main(int argc, const char* argv[]) { symtab->print(std::cout); } - ast_to_nmodl(*ast, filepath("ast")); - - if (json_ast) { - auto file = scratch_dir + "/" + modfile + ".ast.json"; - logger->info("Writing AST into {}", file); - JSONVisitor(file).write(*ast); - } + ast_to_nmodl(*ast, filepath("ast", "mod")); + ast_to_json(*ast, filepath("ast", "json")); if (verbatim_rename) { logger->info("Running 
verbatim rename visitor"); VerbatimVarRenameVisitor().visit_program(*ast); - ast_to_nmodl(*ast, filepath("verbatim_rename")); + ast_to_nmodl(*ast, filepath("verbatim_rename", "mod")); } if (nmodl_const_folding) { logger->info("Running nmodl constant folding visitor"); ConstantFolderVisitor().visit_program(*ast); - ast_to_nmodl(*ast, filepath("constfold")); + ast_to_nmodl(*ast, filepath("constfold", "mod")); } if (nmodl_unroll) { logger->info("Running nmodl loop unroll visitor"); LoopUnrollVisitor().visit_program(*ast); ConstantFolderVisitor().visit_program(*ast); - ast_to_nmodl(*ast, filepath("unroll")); + ast_to_nmodl(*ast, filepath("unroll", "mod")); SymtabVisitor(update_symtab).visit_program(*ast); } @@ -412,7 +525,7 @@ int main(int argc, const char* argv[]) { auto kineticBlockVisitor = KineticBlockVisitor(); kineticBlockVisitor.visit_program(*ast); SymtabVisitor(update_symtab).visit_program(*ast); - const auto filename = filepath("kinetic"); + const auto filename = filepath("kinetic", "mod"); ast_to_nmodl(*ast, filename); if (nmodl_ast && kineticBlockVisitor.get_conserve_statement_count()) { logger->warn( @@ -425,7 +538,7 @@ int main(int argc, const char* argv[]) { logger->info("Running STEADYSTATE visitor"); SteadystateVisitor().visit_program(*ast); SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("steadystate")); + ast_to_nmodl(*ast, filepath("steadystate", "mod")); } /// Parsing units fron "nrnunits.lib" and mod files @@ -442,14 +555,14 @@ int main(int argc, const char* argv[]) { if (nmodl_inline) { logger->info("Running nmodl inline visitor"); InlineVisitor().visit_program(*ast); - ast_to_nmodl(*ast, filepath("inline")); + ast_to_nmodl(*ast, filepath("inline", "mod")); } if (local_rename) { logger->info("Running local variable rename visitor"); LocalVarRenameVisitor().visit_program(*ast); SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("local_rename")); + ast_to_nmodl(*ast, 
filepath("local_rename", "mod")); } if (nmodl_localize) { @@ -458,14 +571,14 @@ int main(int argc, const char* argv[]) { LocalizeVisitor(localize_verbatim).visit_program(*ast); LocalVarRenameVisitor().visit_program(*ast); SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("localize")); + ast_to_nmodl(*ast, filepath("localize", "mod")); } if (sympy_conductance) { logger->info("Running sympy conductance visitor"); SympyConductanceVisitor().visit_program(*ast); SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("sympy_conductance")); + ast_to_nmodl(*ast, filepath("sympy_conductance", "mod")); } if (sympy_analytic || sparse_solver_exists(*ast)) { @@ -476,19 +589,19 @@ int main(int argc, const char* argv[]) { logger->info("Running sympy solve visitor"); SympySolverVisitor(sympy_pade, sympy_cse).visit_program(*ast); SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("sympy_solve")); + ast_to_nmodl(*ast, filepath("sympy_solve", "mod")); } { logger->info("Running cnexp visitor"); NeuronSolveVisitor().visit_program(*ast); - ast_to_nmodl(*ast, filepath("cnexp")); + ast_to_nmodl(*ast, filepath("cnexp", "mod")); } { SolveBlockVisitor().visit_program(*ast); SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("solveblock")); + ast_to_nmodl(*ast, filepath("solveblock", "mod")); } if (json_perfstat) { @@ -548,6 +661,38 @@ int main(int argc, const char* argv[]) { optimize_ionvar_copies_codegen); visitor.visit_program(*ast); } + +#ifdef NMODL_LLVM_BACKEND + if (llvm_ir || run_llvm_benchmark) { + logger->info("Running LLVM backend code generator"); + CodegenLLVMVisitor visitor(modfile, + output_dir, + llvm_ir_opt_passes, + llvm_float_type, + llvm_vec_width, + vector_library, + !disable_debug_information, + llvm_fast_math_flags, + llvm_scalable_vectors); + visitor.visit_program(*ast); + ast_to_nmodl(*ast, filepath("llvm", "mod")); + ast_to_json(*ast, filepath("llvm", 
"json")); + + if (run_llvm_benchmark) { + logger->info("Running LLVM benchmark"); + benchmark::LLVMBenchmark benchmark(visitor, + modfile, + output_dir, + shared_lib_paths, + num_experiments, + instance_size, + backend, + llvm_opt_level_ir, + llvm_opt_level_codegen); + benchmark.run(ast); + } + } +#endif } } diff --git a/src/visitors/inline_visitor.cpp b/src/visitors/inline_visitor.cpp index cb723f0f1c..d628233dd7 100644 --- a/src/visitors/inline_visitor.cpp +++ b/src/visitors/inline_visitor.cpp @@ -298,6 +298,8 @@ void InlineVisitor::visit_statement_block(StatementBlock& node) { /** Visit all wrapped expressions which can contain function calls. * If a function call is replaced then the wrapped expression is * also replaced with new variable node from the inlining result. + * Note that we use `VarName` so that LHS of assignment expression + * is `VarName`, similar to parser. */ void InlineVisitor::visit_wrapped_expression(WrappedExpression& node) { node.visit_children(*this); @@ -306,7 +308,9 @@ void InlineVisitor::visit_wrapped_expression(WrappedExpression& node) { auto expression = dynamic_cast(e.get()); if (replaced_fun_calls.find(expression) != replaced_fun_calls.end()) { auto var = replaced_fun_calls[expression]; - node.set_expression(std::make_shared(new String(var))); + node.set_expression(std::make_shared(new Name(new String(var)), + /*at=*/nullptr, + /*index=*/nullptr)); } } } diff --git a/test/benchmark/CMakeLists.txt b/test/benchmark/CMakeLists.txt new file mode 100644 index 0000000000..4441d53251 --- /dev/null +++ b/test/benchmark/CMakeLists.txt @@ -0,0 +1,17 @@ +# ============================================================================= +# llvm benchmark sources +# ============================================================================= +set(LLVM_BENCHMARK_SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/jit_driver.cpp ${CMAKE_CURRENT_SOURCE_DIR}/jit_driver.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/llvm_benchmark.cpp 
${CMAKE_CURRENT_SOURCE_DIR}/llvm_benchmark.hpp) + +# ============================================================================= +# LLVM benchmark library +# ============================================================================= +include_directories(${LLVM_INCLUDE_DIRS}) +add_library(llvm_benchmark STATIC ${LLVM_BENCHMARK_SOURCE_FILES}) +add_dependencies(llvm_benchmark lexer util visitor) + +if(NMODL_ENABLE_JIT_EVENT_LISTENERS) + target_compile_definitions(llvm_benchmark PUBLIC NMODL_HAVE_JIT_EVENT_LISTENERS) +endif() diff --git a/test/benchmark/jit_driver.cpp b/test/benchmark/jit_driver.cpp new file mode 100644 index 0000000000..a2d8df63f4 --- /dev/null +++ b/test/benchmark/jit_driver.cpp @@ -0,0 +1,263 @@ +/************************************************************************* + * Copyright (C) 2018-2020 Blue Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. + *************************************************************************/ + +#include "jit_driver.hpp" +#include "codegen/llvm/codegen_llvm_visitor.hpp" +#include "utils/common_utils.hpp" + +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/ExecutionEngine/JITEventListener.h" +#include "llvm/ExecutionEngine/ObjectCache.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/LLJIT.h" +#include "llvm/ExecutionEngine/Orc/ObjectTransformLayer.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/IR/AssemblyAnnotationWriter.h" +#include "llvm/InitializePasses.h" +#include "llvm/Support/Host.h" +#include "llvm/Support/TargetRegistry.h" +#include "llvm/Support/TargetSelect.h" +#include 
"llvm/Support/ToolOutputFile.h" +#include "llvm/Transforms/IPO/PassManagerBuilder.h" + +namespace nmodl { +namespace runner { + +/****************************************************************************************/ +/* Utilities for JIT driver */ +/****************************************************************************************/ + +/// Initialises some LLVM optimisation passes. +static void initialise_optimisation_passes() { + auto& registry = *llvm::PassRegistry::getPassRegistry(); + llvm::initializeCore(registry); + llvm::initializeTransformUtils(registry); + llvm::initializeScalarOpts(registry); + llvm::initializeInstCombine(registry); + llvm::initializeAnalysis(registry); +} + +/// Populates pass managers with passes for the given optimisation levels. +static void populate_pms(llvm::legacy::FunctionPassManager& func_pm, + llvm::legacy::PassManager& module_pm, + int opt_level, + int size_level, + llvm::TargetMachine* tm) { + // First, set the pass manager builder with some basic optimisation information. + llvm::PassManagerBuilder pm_builder; + pm_builder.OptLevel = opt_level; + pm_builder.SizeLevel = size_level; + pm_builder.DisableUnrollLoops = opt_level == 0; + + // If target machine is defined, then initialise the TargetTransformInfo for the target. + if (tm) { + module_pm.add(createTargetTransformInfoWrapperPass(tm->getTargetIRAnalysis())); + func_pm.add(createTargetTransformInfoWrapperPass(tm->getTargetIRAnalysis())); + } + + // Populate pass managers. + pm_builder.populateModulePassManager(module_pm); + pm_builder.populateFunctionPassManager(func_pm); +} + +/// Runs the function and module passes on the provided module. 
+static void run_optimisation_passes(llvm::Module& module, + llvm::legacy::FunctionPassManager& func_pm, + llvm::legacy::PassManager& module_pm) { + func_pm.doInitialization(); + auto& functions = module.getFunctionList(); + for (auto& function: functions) { + llvm::verifyFunction(function); + func_pm.run(function); + } + func_pm.doFinalization(); + module_pm.run(module); +} + +/// Optimises the given LLVM IR module. +static void optimise_module(llvm::Module& module, + int opt_level, + llvm::TargetMachine* tm = nullptr) { + llvm::legacy::FunctionPassManager func_pm(&module); + llvm::legacy::PassManager module_pm; + populate_pms(func_pm, module_pm, opt_level, /*size_level=*/0, tm); + run_optimisation_passes(module, func_pm, module_pm); +} + +/// Sets the target triple and the data layout of the module. +static void set_triple_and_data_layout(llvm::Module& module, const std::string& features) { + // Get the default target triple for the host. + auto target_triple = llvm::sys::getDefaultTargetTriple(); + std::string error_msg; + auto* target = llvm::TargetRegistry::lookupTarget(target_triple, error_msg); + if (!target) + throw std::runtime_error("Error " + error_msg + "\n"); + + // Get the CPU information and set a target machine to create the data layout. + std::string cpu(llvm::sys::getHostCPUName()); + std::unique_ptr tm( + target->createTargetMachine(target_triple, cpu, features, {}, {})); + if (!tm) + throw std::runtime_error("Error: could not create the target machine\n"); + + // Set data layout and the target triple to the module. + module.setDataLayout(tm->createDataLayout()); + module.setTargetTriple(target_triple); +} + +/// Creates llvm::TargetMachine with certain CPU features turned on/off. +static std::unique_ptr create_target( + llvm::orc::JITTargetMachineBuilder* tm_builder, + const std::string& features, + int opt_level) { + // First, look up the target. 
+ std::string error_msg; + auto target_triple = tm_builder->getTargetTriple().getTriple(); + auto* target = llvm::TargetRegistry::lookupTarget(target_triple, error_msg); + if (!target) + throw std::runtime_error("Error " + error_msg + "\n"); + + // Create default target machine with provided features. + auto tm = target->createTargetMachine(target_triple, + llvm::sys::getHostCPUName().str(), + features, + tm_builder->getOptions(), + tm_builder->getRelocationModel(), + tm_builder->getCodeModel(), + static_cast(opt_level), + /*JIT=*/true); + if (!tm) + throw std::runtime_error("Error: could not create the target machine\n"); + + return std::unique_ptr(tm); +} + +/****************************************************************************************/ +/* JIT driver */ +/****************************************************************************************/ + +void JITDriver::init(std::string features, + std::vector lib_paths, + BenchmarkInfo* benchmark_info) { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + initialise_optimisation_passes(); + + // Set the target triple and the data layout for the module. + set_triple_and_data_layout(*module, features); + auto data_layout = module->getDataLayout(); + + // If benchmarking, enable listeners to use GDB, perf or VTune. Note that LLVM should be built + // with listeners on (e.g. -DLLVM_USE_PERF=ON). + if (benchmark_info) { + gdb_event_listener = llvm::JITEventListener::createGDBRegistrationListener(); +#if defined(NMODL_HAVE_JIT_EVENT_LISTENERS) + perf_event_listener = llvm::JITEventListener::createPerfJITEventListener(); + intel_event_listener = llvm::JITEventListener::createIntelJITEventListener(); +#endif + } + + // Create object linking function callback. + auto object_linking_layer_creator = [&](llvm::orc::ExecutionSession& session, + const llvm::Triple& triple) { + // Create linking layer. 
+ auto layer = std::make_unique(session, []() { + return std::make_unique(); + }); + + // Register event listeners if they exist. + if (gdb_event_listener) + layer->registerJITEventListener(*gdb_event_listener); + if (perf_event_listener) + layer->registerJITEventListener(*perf_event_listener); + if (intel_event_listener) + layer->registerJITEventListener(*intel_event_listener); + + for (const auto& lib_path: lib_paths) { + // For every library path, create a corresponding memory buffer. + auto memory_buffer = llvm::MemoryBuffer::getFile(lib_path); + if (!memory_buffer) + throw std::runtime_error("Unable to create memory buffer for " + lib_path); + + // Create a new JIT library instance for this session and resolve symbols. + auto& jd = session.createBareJITDylib(std::string(lib_path)); + auto loaded = + llvm::orc::DynamicLibrarySearchGenerator::Load(lib_path.data(), + data_layout.getGlobalPrefix()); + + if (!loaded) + throw std::runtime_error("Unable to load " + lib_path); + jd.addGenerator(std::move(*loaded)); + cantFail(layer->add(jd, std::move(*memory_buffer))); + } + + return layer; + }; + + // Create IR compile function callback. + auto compile_function_creator = [&](llvm::orc::JITTargetMachineBuilder tm_builder) + -> llvm::Expected> { + // Create target machine with some features possibly turned off. + int opt_level_codegen = benchmark_info ? benchmark_info->opt_level_codegen : 0; + auto tm = create_target(&tm_builder, features, opt_level_codegen); + + // Optimise the LLVM IR module and save it to .ll file if benchmarking. 
+ if (benchmark_info) { + optimise_module(*module, benchmark_info->opt_level_ir, tm.get()); + + std::error_code error_code; + std::unique_ptr out = + std::make_unique(benchmark_info->output_dir + "/" + + benchmark_info->filename + "_opt.ll", + error_code, + llvm::sys::fs::OF_Text); + if (error_code) + throw std::runtime_error("Error: " + error_code.message()); + + std::unique_ptr annotator; + module->print(out->os(), annotator.get()); + out->keep(); + } + + return std::make_unique(std::move(tm)); + }; + + // Set the JIT instance. + auto jit_instance = cantFail(llvm::orc::LLJITBuilder() + .setCompileFunctionCreator(compile_function_creator) + .setObjectLinkingLayerCreator(object_linking_layer_creator) + .create()); + + // Add a ThreadSafeModule to the driver. + llvm::orc::ThreadSafeModule tsm(std::move(module), std::make_unique()); + cantFail(jit_instance->addIRModule(std::move(tsm))); + jit = std::move(jit_instance); + + // Resolve symbols. + llvm::orc::JITDylib& sym_tab = jit->getMainJITDylib(); + sym_tab.addGenerator(cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( + data_layout.getGlobalPrefix()))); + + // Optionally, dump the binary to the object file. 
+ if (benchmark_info) { + std::string object_file = benchmark_info->filename + ".o"; + if (utils::file_exists(object_file)) { + int status = remove(object_file.c_str()); + if (status) { + throw std::runtime_error("Can not remove object file " + object_file); + } + } + jit->getObjTransformLayer().setTransform( + llvm::orc::DumpObjects(benchmark_info->output_dir, benchmark_info->filename)); + } +} +} // namespace runner +} // namespace nmodl diff --git a/test/benchmark/jit_driver.hpp b/test/benchmark/jit_driver.hpp new file mode 100644 index 0000000000..afb1317cd8 --- /dev/null +++ b/test/benchmark/jit_driver.hpp @@ -0,0 +1,173 @@ +/************************************************************************* + * Copyright (C) 2018-2020 Blue Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. + *************************************************************************/ + +#pragma once + +/** + * \dir + * \brief Implementation of LLVM's JIT-based execution engine to run functions from MOD files + * + * \file + * \brief \copybrief nmodl::runner::JITDriver + */ + +#include "llvm/ExecutionEngine/JITEventListener.h" +#include "llvm/ExecutionEngine/Orc/LLJIT.h" + +namespace nmodl { +namespace runner { + +/// A struct to hold the information for benchmarking. +struct BenchmarkInfo { + /// Object filename to dump. + std::string filename; + + /// Object file output directory. + std::string output_dir; + + /// Optimisation level for generated IR. + int opt_level_ir; + + /// Optimisation level for machine code generation. + int opt_level_codegen; +}; + +/** + * \class JITDriver + * \brief Driver to execute a MOD file function via LLVM IR backend. + */ +class JITDriver { + private: + std::unique_ptr context = std::make_unique(); + + std::unique_ptr jit; + + /// LLVM IR module to execute. + std::unique_ptr module; + + /// GDB event listener. 
+ llvm::JITEventListener* gdb_event_listener = nullptr; + + /// perf event listener. + llvm::JITEventListener* perf_event_listener = nullptr; + + /// Intel event listener. + llvm::JITEventListener* intel_event_listener = nullptr; + + public: + explicit JITDriver(std::unique_ptr m) + : module(std::move(m)) {} + + /// Initializes the JIT driver. + void init(std::string features = "", + std::vector lib_paths = {}, + BenchmarkInfo* benchmark_info = nullptr); + + /// Lookups the entry-point without arguments in the JIT and executes it, returning the result. + template + ReturnType execute_without_arguments(const std::string& entry_point) { + auto expected_symbol = jit->lookup(entry_point); + if (!expected_symbol) + throw std::runtime_error("Error: entry-point symbol not found in JIT\n"); + + auto (*res)() = (ReturnType(*)())(intptr_t) expected_symbol->getAddress(); + ReturnType result = res(); + return result; + } + + /// Lookups the entry-point with an argument in the JIT and executes it, returning the result. + template + ReturnType execute_with_arguments(const std::string& entry_point, ArgType arg) { + auto expected_symbol = jit->lookup(entry_point); + if (!expected_symbol) + throw std::runtime_error("Error: entry-point symbol not found in JIT\n"); + + auto (*res)(ArgType) = (ReturnType(*)(ArgType))(intptr_t) expected_symbol->getAddress(); + ReturnType result = res(arg); + return result; + } +}; + +/** + * \class BaseRunner + * \brief A base runner class that provides functionality to execute an + * entry point in the LLVM IR module. + */ +class BaseRunner { + protected: + std::unique_ptr driver; + + explicit BaseRunner(std::unique_ptr m) + : driver(std::make_unique(std::move(m))) {} + + public: + /// Sets up the JIT driver. + virtual void initialize_driver() = 0; + + /// Runs the entry-point function without arguments. 
+ template + ReturnType run_without_arguments(const std::string& entry_point) { + return driver->template execute_without_arguments(entry_point); + } + + /// Runs the entry-point function with a pointer to the data as an argument. + template + ReturnType run_with_argument(const std::string& entry_point, ArgType arg) { + return driver->template execute_with_arguments(entry_point, arg); + } +}; + +/** + * \class TestRunner + * \brief A simple runner for testing purposes. + */ +class TestRunner: public BaseRunner { + public: + explicit TestRunner(std::unique_ptr m) + : BaseRunner(std::move(m)) {} + + virtual void initialize_driver() { + driver->init(); + } +}; + +/** + * \class BenchmarkRunner + * \brief A runner with benchmarking functionality. It takes user-specified CPU + * features into account, as well as it can link against shared libraries. + */ +class BenchmarkRunner: public BaseRunner { + private: + /// Benchmarking information passed to JIT driver. + BenchmarkInfo benchmark_info; + + /// CPU features specified by the user. + std::string features; + + /// Shared libraries' paths to link against. 
+ std::vector shared_lib_paths; + + public: + BenchmarkRunner(std::unique_ptr m, + std::string filename, + std::string output_dir, + std::string features = "", + std::vector lib_paths = {}, + int opt_level_ir = 0, + int opt_level_codegen = 0) + : BaseRunner(std::move(m)) + , benchmark_info{filename, output_dir, opt_level_ir, opt_level_codegen} + , features(features) + , shared_lib_paths(lib_paths) {} + + virtual void initialize_driver() { + driver->init(features, shared_lib_paths, &benchmark_info); + } +}; + +} // namespace runner +} // namespace nmodl diff --git a/test/benchmark/llvm_benchmark.cpp b/test/benchmark/llvm_benchmark.cpp new file mode 100644 index 0000000000..e48df0d457 --- /dev/null +++ b/test/benchmark/llvm_benchmark.cpp @@ -0,0 +1,152 @@ +/************************************************************************* + * Copyright (C) 2018-2021 Blue Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. + *************************************************************************/ + +#include +#include + +#include "codegen/llvm/codegen_llvm_visitor.hpp" +#include "llvm_benchmark.hpp" +#include "test/benchmark/jit_driver.hpp" +#include "llvm/Support/Host.h" + +#include "test/unit/codegen/codegen_data_helper.hpp" + + +namespace nmodl { +namespace benchmark { + +/// Precision for the timing measurements. +static constexpr int PRECISION = 9; + +/// Get the host CPU features in the format: +/// +feature,+feature,-feature,+feature,... +/// where `+` indicates that the feature is enabled. 
+static std::vector get_cpu_features() { + std::string cpu(llvm::sys::getHostCPUName()); + + llvm::SubtargetFeatures features; + llvm::StringMap host_features; + if (llvm::sys::getHostCPUFeatures(host_features)) { + for (auto& f: host_features) + features.AddFeature(f.first(), f.second); + } + return features.getFeatures(); +} + + +void LLVMBenchmark::disable(const std::string& feature, std::vector& host_features) { + for (auto& host_feature: host_features) { + if (feature == host_feature.substr(1)) { + host_feature[0] = '-'; + logger->info("{}", host_feature); + return; + } + } +} + +void LLVMBenchmark::run(const std::shared_ptr& node) { + // create functions + generate_llvm(node); + // Finally, run the benchmark and log the measurements. + run_benchmark(node); +} + +void LLVMBenchmark::generate_llvm(const std::shared_ptr& node) { + // First, visit the AST to build the LLVM IR module and wrap the kernel function calls. + auto start = std::chrono::high_resolution_clock::now(); + llvm_visitor.wrap_kernel_functions(); + auto end = std::chrono::high_resolution_clock::now(); + + // Log the time taken to visit the AST and build LLVM IR. + std::chrono::duration diff = end - start; + logger->info("Created LLVM IR module from NMODL AST in {} sec", diff.count()); +} + +void LLVMBenchmark::run_benchmark(const std::shared_ptr& node) { + // Set the codegen data helper and find the kernels. + auto codegen_data = codegen::CodegenDataHelper(node, llvm_visitor.get_instance_struct_ptr()); + std::vector kernel_names; + llvm_visitor.find_kernel_names(kernel_names); + + // Get feature's string and turn them off depending on the backend. + std::vector features = get_cpu_features(); + logger->info("Backend: {}", backend); + if (backend == "avx2") { + // Disable SSE. 
+ logger->info("Disabling features:"); + disable("sse", features); + disable("sse2", features); + disable("sse3", features); + disable("sse4.1", features); + disable("sse4.2", features); + } else if (backend == "sse2") { + // Disable AVX. + logger->info("Disabling features:"); + disable("avx", features); + disable("avx2", features); + } + + std::string features_str = llvm::join(features.begin(), features.end(), ","); + std::unique_ptr m = llvm_visitor.get_module(); + + // Create the benchmark runner and initialize it. + std::string filename = "v" + std::to_string(llvm_visitor.get_vector_width()) + "_" + + mod_filename; + runner::BenchmarkRunner runner(std::move(m), + filename, + output_dir, + features_str, + shared_libs, + opt_level_ir, + opt_level_codegen); + runner.initialize_driver(); + + // Benchmark every kernel. + for (const auto& kernel_name: kernel_names) { + // For every kernel run the benchmark `num_experiments` times. + double time_min = std::numeric_limits::max(); + double time_max = 0.0; + double time_sum = 0.0; + double time_squared_sum = 0.0; + for (int i = 0; i < num_experiments; ++i) { + // Initialise the data. + auto instance_data = codegen_data.create_data(instance_size, /*seed=*/1); + + // Log instance size once. + if (i == 0) { + double size_mbs = instance_data.num_bytes / (1024.0 * 1024.0); + logger->info("Benchmarking kernel '{}' with {} MBs dataset", kernel_name, size_mbs); + } + + // Record the execution time of the kernel. + std::string wrapper_name = "__" + kernel_name + "_wrapper"; + auto start = std::chrono::high_resolution_clock::now(); + runner.run_with_argument(kernel_name, instance_data.base_ptr); + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration diff = end - start; + + // Log the time taken for each run. + logger->info("Experiment {} compute time = {:.6f} sec", i, diff.count()); + + // Update statistics. 
+ time_sum += diff.count(); + time_squared_sum += diff.count() * diff.count(); + time_min = std::min(time_min, diff.count()); + time_max = std::max(time_max, diff.count()); + } + // Log the average time taken for the kernel. + double time_mean = time_sum / num_experiments; + logger->info("Average compute time = {:.6f}", time_mean); + logger->info("Compute time variance = {:g}", + time_squared_sum / num_experiments - time_mean * time_mean); + logger->info("Minimum compute time = {:.6f}", time_min); + logger->info("Maximum compute time = {:.6f}\n", time_max); + } +} + +} // namespace benchmark +} // namespace nmodl diff --git a/test/benchmark/llvm_benchmark.hpp b/test/benchmark/llvm_benchmark.hpp new file mode 100644 index 0000000000..9696191172 --- /dev/null +++ b/test/benchmark/llvm_benchmark.hpp @@ -0,0 +1,94 @@ +/************************************************************************* + * Copyright (C) 2018-2021 Blue Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. + *************************************************************************/ + +#pragma once + +#include + +#include "codegen/llvm/codegen_llvm_visitor.hpp" +#include "utils/logger.hpp" + +namespace nmodl { +namespace benchmark { + +/** + * \class LLVMBenchmark + * \brief A wrapper to execute MOD file kernels via LLVM IR backend, and + * benchmark compile-time and runtime. + */ +class LLVMBenchmark { + private: + /// LLVM visitor. + codegen::CodegenLLVMVisitor& llvm_visitor; + + /// Source MOD file name. + std::string mod_filename; + + /// The output directory for logs and other files. + std::string output_dir; + + /// Paths to shared libraries. + std::vector shared_libs; + + /// The number of experiments to repeat. + int num_experiments; + + /// The size of the instance struct for benchmarking. 
+ int instance_size; + + /// Benchmarking backend + std::string backend; + + /// Optimisation level for LLVM IR transformations. + int opt_level_ir; + + /// Optimisation level for machine code generation. + int opt_level_codegen; + + /// Filestream for dumping logs to the file. + std::ofstream ofs; + + public: + LLVMBenchmark(codegen::CodegenLLVMVisitor& llvm_visitor, + const std::string& mod_filename, + const std::string& output_dir, + std::vector shared_libs, + int num_experiments, + int instance_size, + const std::string& backend, + int opt_level_ir, + int opt_level_codegen) + : llvm_visitor(llvm_visitor) + , mod_filename(mod_filename) + , output_dir(output_dir) + , shared_libs(shared_libs) + , num_experiments(num_experiments) + , instance_size(instance_size) + , backend(backend) + , opt_level_ir(opt_level_ir) + , opt_level_codegen(opt_level_codegen) {} + + /// Runs the benchmark. + void run(const std::shared_ptr& node); + + private: + /// Disables the specified feature in the target. + void disable(const std::string& feature, std::vector& host_features); + + /// Visits the AST to construct the LLVM IR module. + void generate_llvm(const std::shared_ptr& node); + + /// Runs the main body of the benchmark, executing the compute kernels. + void run_benchmark(const std::shared_ptr& node); + + /// Sets the log output stream (file or console). 
+ void set_log_output(); +}; + + +} // namespace benchmark +} // namespace nmodl diff --git a/test/integration/mod/procedure.mod b/test/integration/mod/procedure.mod new file mode 100644 index 0000000000..daa4ad33ad --- /dev/null +++ b/test/integration/mod/procedure.mod @@ -0,0 +1,37 @@ +NEURON { + SUFFIX procedure_test + THREADSAFE +} + +PROCEDURE hello_world() { + printf("Hello World") +} + +PROCEDURE simple_sum(x, y) { + LOCAL z + z = x + y +} + +PROCEDURE complex_sum(v) { + LOCAL alpha, beta, sum + { + alpha = .1 * exp(-(v+40)) + beta = 4 * exp(-(v+65)/18) + sum = alpha + beta + } +} + +PROCEDURE loop_proc(v, t) { + LOCAL i + i = 0 + WHILE(i < 10) { + printf("Hello World") + i = i + 1 + } +} + +FUNCTION square(x) { + LOCAL res + res = x * x + square = res +} diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 04e33614cd..91721010e9 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -60,6 +60,11 @@ add_executable(testcodegen codegen/main.cpp codegen/codegen_ispc.cpp codegen/cod target_link_libraries(testmodtoken lexer util) target_link_libraries(testlexer lexer util) +target_link_libraries(testprinter printer util) +target_link_libraries(testsymtab symtab lexer util) +target_link_libraries(testunitlexer lexer util) +target_link_libraries(testunitparser lexer test_util config) + target_link_libraries( testparser visitor @@ -69,6 +74,7 @@ target_link_libraries( test_util printer ${NMODL_WRAPPER_LIBS}) + target_link_libraries( testvisitor visitor @@ -78,6 +84,7 @@ target_link_libraries( test_util printer ${NMODL_WRAPPER_LIBS}) + target_link_libraries( testcodegen codegen @@ -88,10 +95,44 @@ target_link_libraries( test_util printer ${NMODL_WRAPPER_LIBS}) -target_link_libraries(testprinter printer util) -target_link_libraries(testsymtab symtab lexer util) -target_link_libraries(testunitlexer lexer util) -target_link_libraries(testunitparser lexer test_util config) + +if(NMODL_ENABLE_LLVM) + 
include_directories(${LLVM_INCLUDE_DIRS} codegen) + + add_library(benchmark_data STATIC codegen/codegen_data_helper.cpp) + add_dependencies(benchmark_data lexer) + + add_executable(testllvm visitor/main.cpp codegen/codegen_llvm_ir.cpp + codegen/codegen_data_helper.cpp codegen/codegen_llvm_instance_struct.cpp) + add_executable(test_llvm_runner visitor/main.cpp codegen/codegen_data_helper.cpp + codegen/codegen_llvm_execution.cpp) + target_link_libraries( + testllvm + llvm_codegen + codegen + visitor + symtab + lexer + util + test_util + printer + ${NMODL_WRAPPER_LIBS} + ${LLVM_LIBS_TO_LINK}) + target_link_libraries( + test_llvm_runner + llvm_codegen + llvm_benchmark + codegen + visitor + symtab + lexer + util + test_util + printer + ${NMODL_WRAPPER_LIBS} + ${LLVM_LIBS_TO_LINK}) + set(CODEGEN_TEST testllvm) +endif() # ============================================================================= # Use catch_discover instead of add_test for granular test report if CMAKE ver is greater than 3.9, @@ -102,7 +143,6 @@ if(NOT LINK_AGAINST_PYTHON) list(APPEND testvisitor_env "NMODL_PYLIB=$ENV{NMODL_PYLIB}") list(APPEND testvisitor_env "NMODL_WRAPLIB=${PROJECT_BINARY_DIR}/lib/nmodl/libpywrapper${CMAKE_SHARED_LIBRARY_SUFFIX}") - endif() foreach( @@ -117,8 +157,8 @@ foreach( testnewton testfast_math testunitlexer - testunitparser) - + testunitparser + ${CODEGEN_TEST}) if(${CMAKE_VERSION} VERSION_GREATER "3.10") if(${test_name} STREQUAL "testvisitor") catch_discover_tests(${test_name} TEST_PREFIX "${test_name}/" PROPERTIES ENVIRONMENT diff --git a/test/unit/codegen/codegen_data_helper.cpp b/test/unit/codegen/codegen_data_helper.cpp new file mode 100644 index 0000000000..a0ee6ec957 --- /dev/null +++ b/test/unit/codegen/codegen_data_helper.cpp @@ -0,0 +1,195 @@ +#include + +#include "ast/codegen_var_type.hpp" +#include "codegen/llvm/codegen_llvm_helper_visitor.hpp" + +#include "codegen_data_helper.hpp" + +namespace nmodl { +namespace codegen { + +// scalar variables with default 
values +const double default_nthread_dt_value = 0.025; +const double default_nthread_t_value = 100.0; +const double default_celsius_value = 34.0; +const int default_second_order_value = 0; + +// cleanup all members and struct base pointer +CodegenInstanceData::~CodegenInstanceData() { + // first free num_ptr_members members which are pointers + for (size_t i = 0; i < num_ptr_members; i++) { + free(members[i]); + } + // and then pointer to container struct + free(base_ptr); +} + +/** + * \todo : various things can be improved here + * - if variable is voltage then initialization range could be -65 to +65 + * - if variable is double or float then those could be initialize with + * "some" floating point value between range like 1.0 to 100.0. Note + * it would be nice to have unique values to avoid errors like division + * by zero. We have simple implementation that is taking care of this. + * - if variable is integer then initialization range must be between + * 0 and num_elements. In practice, num_elements is number of instances + * of a particular mechanism. This would be <= number of compartments + * in the cell. For now, just initialize integer variables from 0 to + * num_elements - 1. 
+ */ +void initialize_variable(const std::shared_ptr& var, + void* ptr, + size_t initial_value, + size_t num_elements) { + ast::AstNodeType type = var->get_type()->get_type(); + const std::string& name = var->get_name()->get_node_name(); + + if (type == ast::AstNodeType::DOUBLE) { + const auto& generated_double_data = generate_dummy_data(initial_value, + num_elements); + double* data = (double*) ptr; + for (size_t i = 0; i < num_elements; i++) { + data[i] = generated_double_data[i]; + } + } else if (type == ast::AstNodeType::FLOAT) { + const auto& generated_float_data = generate_dummy_data(initial_value, num_elements); + float* data = (float*) ptr; + for (size_t i = 0; i < num_elements; i++) { + data[i] = generated_float_data[i]; + } + } else if (type == ast::AstNodeType::INTEGER) { + const auto& generated_int_data = generate_dummy_data(initial_value, num_elements); + int* data = (int*) ptr; + for (size_t i = 0; i < num_elements; i++) { + data[i] = generated_int_data[i]; + } + } else { + throw std::runtime_error("Unhandled data type during initialize_variable"); + }; +} + +CodegenInstanceData CodegenDataHelper::create_data(size_t num_elements, size_t seed) { + // alignment with 64-byte to generate aligned loads/stores + const unsigned NBYTE_ALIGNMENT = 64; + + // get variable information + const auto& variables = instance->get_codegen_vars(); + + // start building data + CodegenInstanceData data; + data.num_elements = num_elements; + + // base pointer to instance object + void* base = nullptr; + + // max size of each member : pointer / double has maximum size + size_t member_size = std::max(sizeof(double), sizeof(double*)); + + // allocate instance object with memory alignment + posix_memalign(&base, NBYTE_ALIGNMENT, member_size * variables.size()); + data.base_ptr = base; + data.num_bytes += member_size * variables.size(); + + size_t offset = 0; + void* ptr = base; + size_t variable_index = 0; + + // allocate each variable and allocate memory at particular offset 
in base pointer + for (auto& var: variables) { + // only process until first non-pointer variable + if (!var->get_is_pointer()) { + break; + } + + // check type of variable and it's size + size_t member_size = 0; + ast::AstNodeType type = var->get_type()->get_type(); + if (type == ast::AstNodeType::DOUBLE) { + member_size = sizeof(double); + } else if (type == ast::AstNodeType::FLOAT) { + member_size = sizeof(float); + } else if (type == ast::AstNodeType::INTEGER) { + member_size = sizeof(int); + } + + // allocate memory and setup a pointer + void* member; + posix_memalign(&member, NBYTE_ALIGNMENT, member_size * num_elements); + + // integer values are often offsets so they must start from + // 0 to num_elements-1 to avoid out of bound accesses. + int initial_value = variable_index; + if (type == ast::AstNodeType::INTEGER) { + initial_value = 0; + } + initialize_variable(var, member, initial_value, num_elements); + data.num_bytes += member_size * num_elements; + + // copy address at specific location in the struct + memcpy(ptr, &member, sizeof(double*)); + + data.offsets.push_back(offset); + data.members.push_back(member); + data.num_ptr_members++; + + // all pointer types are of same size, so just use double* + offset += sizeof(double*); + ptr = (char*) base + offset; + + variable_index++; + } + + // we are now switching from pointer type to next member type (e.g. 
double) + // ideally we should use padding but switching from double* to double should + // already meet alignment requirements + for (auto& var: variables) { + // process only scalar elements + if (var->get_is_pointer()) { + continue; + } + ast::AstNodeType type = var->get_type()->get_type(); + const std::string& name = var->get_name()->get_node_name(); + + // some default values for standard parameters + double value = 0; + if (name == naming::NTHREAD_DT_VARIABLE) { + value = default_nthread_dt_value; + } else if (name == naming::NTHREAD_T_VARIABLE) { + value = default_nthread_t_value; + } else if (name == naming::CELSIUS_VARIABLE) { + value = default_celsius_value; + } else if (name == CodegenLLVMHelperVisitor::NODECOUNT_VAR) { + value = num_elements; + } else if (name == naming::SECOND_ORDER_VARIABLE) { + value = default_second_order_value; + } + + if (type == ast::AstNodeType::DOUBLE) { + *((double*) ptr) = value; + data.offsets.push_back(offset); + data.members.push_back(ptr); + offset += sizeof(double); + ptr = (char*) base + offset; + } else if (type == ast::AstNodeType::FLOAT) { + *((float*) ptr) = float(value); + data.offsets.push_back(offset); + data.members.push_back(ptr); + offset += sizeof(float); + ptr = (char*) base + offset; + } else if (type == ast::AstNodeType::INTEGER) { + *((int*) ptr) = int(value); + data.offsets.push_back(offset); + data.members.push_back(ptr); + offset += sizeof(int); + ptr = (char*) base + offset; + } else { + throw std::runtime_error( + "Unhandled type while allocating data in CodegenDataHelper::create_data()"); + } + } + + return data; +} + +} // namespace codegen +} // namespace nmodl diff --git a/test/unit/codegen/codegen_data_helper.hpp b/test/unit/codegen/codegen_data_helper.hpp new file mode 100644 index 0000000000..76c4f422d9 --- /dev/null +++ b/test/unit/codegen/codegen_data_helper.hpp @@ -0,0 +1,113 @@ +/************************************************************************* + * Copyright (C) 2018-2021 Blue 
Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. + *************************************************************************/ + +#pragma once + +#include + +#include "ast/ast.hpp" + +/// \file +/// \brief Generate test data for testing and benchmarking compute kernels + +namespace nmodl { +namespace codegen { + +/// common scalar variables +extern const double default_nthread_dt_value; +extern const double default_nthread_t_value; +extern const double default_celsius_value; +extern const int default_second_order_value; + +/** + * \class CodegenInstanceData + * \brief Wrapper class to pack data allocate for instance + */ +struct CodegenInstanceData { + /// base pointer which can be type casted + /// to instance struct at run time + void* base_ptr = nullptr; + + /// length of each member of pointer type + size_t num_elements = 0; + + /// number of pointer members + size_t num_ptr_members = 0; + + /// offset relative to base_ptr to locate + /// each member variable in instance struct + std::vector offsets; + + /// pointer to array allocated for each member variable + /// i.e. *(base_ptr + offsets[0]) will be members[0] + std::vector members; + + /// size in bytes + size_t num_bytes = 0; + + // cleanup all memory allocated for type and member variables + ~CodegenInstanceData(); +}; + + +/** + * Generate vector of dummy data according to the template type specified + * + * For double or float type: generate vector starting from `initial_value` + * with an increment of 1e-5. The increment can be any other + * value but 1e-5 is chosen because when we benchmark with + * a million elements then the values are in the range of + * . 
+ * For int type: generate vector starting from initial_value with an + * increments of 1 + * + * \param inital_value Base value for initializing the data + * \param num_elements Number of element of the generated vector + * \return std::vector of dummy data for testing purposes + */ +template +std::vector generate_dummy_data(size_t initial_value, size_t num_elements) { + std::vector data(num_elements); + T increment; + if (std::is_same::value) { + increment = 1; + } else { + increment = 1e-5; + } + for (size_t i = 0; i < num_elements; i++) { + data[i] = initial_value + increment * i; + } + return data; +} + +/** + * \class CodegenDataHelper + * \brief Helper to allocate and initialize data for benchmarking + * + * The `ast::InstanceStruct` is has different number of member + * variables for different MOD files and hence we can't instantiate + * it at compile time. This class helps to inspect the variables + * information gathered from AST and allocate memory block that + * can be type cast to the `ast::InstanceStruct` corresponding + * to the MOD file. + */ +class CodegenDataHelper { + std::shared_ptr program; + std::shared_ptr instance; + + public: + CodegenDataHelper() = delete; + CodegenDataHelper(const std::shared_ptr& program, + const std::shared_ptr& instance) + : program(program) + , instance(instance) {} + + CodegenInstanceData create_data(size_t num_elements, size_t seed); +}; + +} // namespace codegen +} // namespace nmodl diff --git a/test/unit/codegen/codegen_llvm_execution.cpp b/test/unit/codegen/codegen_llvm_execution.cpp new file mode 100644 index 0000000000..aa77a4e493 --- /dev/null +++ b/test/unit/codegen/codegen_llvm_execution.cpp @@ -0,0 +1,611 @@ +/************************************************************************* + * Copyright (C) 2018-2020 Blue Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. 
+ *************************************************************************/ + +#include + +#include "ast/program.hpp" +#include "codegen/llvm/codegen_llvm_visitor.hpp" +#include "codegen_data_helper.hpp" +#include "parser/nmodl_driver.hpp" +#include "test/benchmark/jit_driver.hpp" +#include "visitors/checkparent_visitor.hpp" +#include "visitors/neuron_solve_visitor.hpp" +#include "visitors/solve_block_visitor.hpp" +#include "visitors/symtab_visitor.hpp" + +using namespace nmodl; +using namespace runner; +using namespace visitor; +using nmodl::parser::NmodlDriver; + +static double EPSILON = 1e-15; + +//============================================================================= +// Utilities for testing. +//============================================================================= + +struct InstanceTestInfo { + codegen::CodegenInstanceData* instance; + codegen::InstanceVarHelper helper; + int num_elements; +}; + +template +bool check_instance_variable(InstanceTestInfo& instance_info, + std::vector& expected, + const std::string& variable_name) { + std::vector actual; + int variable_index = instance_info.helper.get_variable_index(variable_name); + actual.assign(static_cast(instance_info.instance->members[variable_index]), + static_cast(instance_info.instance->members[variable_index]) + + instance_info.num_elements); + + // While we are comparing double types as well, for simplicity the test cases are hand-crafted + // so that no floating-point arithmetic is really involved. 
+ return actual == expected; +} + +template +void initialise_instance_variable(InstanceTestInfo& instance_info, + std::vector& data, + const std::string& variable_name) { + int variable_index = instance_info.helper.get_variable_index(variable_name); + T* data_start = static_cast(instance_info.instance->members[variable_index]); + for (int i = 0; i < instance_info.num_elements; ++i) + *(data_start + i) = data[i]; +} + +//============================================================================= +// Simple functions: no optimisations +//============================================================================= + +SCENARIO("Arithmetic expression", "[llvm][runner]") { + GIVEN("Functions with some arithmetic expressions") { + std::string nmodl_text = R"( + FUNCTION exponential() { + LOCAL i + i = 1 + exponential = exp(i) + } + + FUNCTION constant() { + constant = 10 + } + + FUNCTION arithmetic() { + LOCAL x, y + x = 3 + y = 7 + arithmetic = x * y / (x + y) + } + + FUNCTION bar() { + LOCAL i, j + i = 2 + j = i + 2 + bar = 2 * 3 + j + } + + FUNCTION function_call() { + foo() + function_call = bar() / constant() + } + + PROCEDURE foo() {} + + FUNCTION with_argument(x) { + with_argument = x + } + + FUNCTION loop() { + LOCAL i, j, sum, result + result = 0 + j = 0 + WHILE (j < 2) { + i = 0 + sum = 0 + WHILE (i < 10) { + sum = sum + i + i = i + 1 + } + j = j + 1 + result = result + sum + } + loop = result + } + )"; + + + NmodlDriver driver; + const auto& ast = driver.parse_string(nmodl_text); + + SymtabVisitor().visit_program(*ast); + codegen::CodegenLLVMVisitor llvm_visitor(/*mod_filename=*/"unknown", + /*output_dir=*/".", + /*opt_passes=*/false); + llvm_visitor.visit_program(*ast); + + std::unique_ptr m = llvm_visitor.get_module(); + TestRunner runner(std::move(m)); + runner.initialize_driver(); + + THEN("functions are evaluated correctly") { + auto exp_result = runner.run_without_arguments("exponential"); + REQUIRE(fabs(exp_result - 2.718281828459045) < EPSILON); + + 
auto constant_result = runner.run_without_arguments("constant"); + REQUIRE(fabs(constant_result - 10.0) < EPSILON); + + auto arithmetic_result = runner.run_without_arguments("arithmetic"); + REQUIRE(fabs(arithmetic_result - 2.1) < EPSILON); + + auto function_call_result = runner.run_without_arguments("function_call"); + REQUIRE(fabs(function_call_result - 1.0) < EPSILON); + + double data = 10.0; + auto with_argument_result = runner.run_with_argument("with_argument", + data); + REQUIRE(fabs(with_argument_result - 10.0) < EPSILON); + + auto loop_result = runner.run_without_arguments("loop"); + REQUIRE(fabs(loop_result - 90.0) < EPSILON); + } + } +} + +//============================================================================= +// Simple functions: with optimisations +//============================================================================= + +SCENARIO("Optimised arithmetic expression", "[llvm][runner]") { + GIVEN("Functions with some arithmetic expressions") { + std::string nmodl_text = R"( + FUNCTION exponential() { + LOCAL i + i = 1 + exponential = exp(i) + } + + FUNCTION constant() { + constant = 10 * 2 - 100 / 50 * 5 + } + + FUNCTION arithmetic() { + LOCAL x, y + x = 3 + y = 7 + arithmetic = x * y / (x + y) + } + + FUNCTION conditionals() { + LOCAL x, y, z + x = 100 + y = -100 + z = 0 + if (x == 200) { + conditionals = 1 + } else if (x == 400) { + conditionals = 2 + } else if (x == 100) { + if (y == -100 && z != 0) { + conditionals = 3 + } else { + if (y < -99 && z == 0) { + conditionals = 4 + } else { + conditionals = 5 + } + } + } else { + conditionals = 6 + } + } + + FUNCTION bar() { + LOCAL i, j + i = 2 + j = i + 2 + bar = 2 * 3 + j + } + + FUNCTION function_call() { + foo() + function_call = bar() / constant() + } + + PROCEDURE foo() {} + + )"; + + + NmodlDriver driver; + const auto& ast = driver.parse_string(nmodl_text); + + SymtabVisitor().visit_program(*ast); + codegen::CodegenLLVMVisitor llvm_visitor(/*mod_filename=*/"unknown", + 
/*output_dir=*/".", + /*opt_passes=*/true); + llvm_visitor.visit_program(*ast); + + std::unique_ptr m = llvm_visitor.get_module(); + TestRunner runner(std::move(m)); + runner.initialize_driver(); + + THEN("optimizations preserve function results") { + // Check exponential is turned into a constant. + auto exp_result = runner.run_without_arguments("exponential"); + REQUIRE(fabs(exp_result - 2.718281828459045) < EPSILON); + + // Check constant folding. + auto constant_result = runner.run_without_arguments("constant"); + REQUIRE(fabs(constant_result - 10.0) < EPSILON); + + // Check nested conditionals + auto conditionals_result = runner.run_without_arguments("conditionals"); + REQUIRE(fabs(conditionals_result - 4.0) < EPSILON); + + // Check constant folding. + auto arithmetic_result = runner.run_without_arguments("arithmetic"); + REQUIRE(fabs(arithmetic_result - 2.1) < EPSILON); + + auto function_call_result = runner.run_without_arguments("function_call"); + REQUIRE(fabs(function_call_result - 1.0) < EPSILON); + } + } +} + +//============================================================================= +// State scalar kernel. +//============================================================================= + +SCENARIO("Simple scalar kernel", "[llvm][runner]") { + GIVEN("Simple MOD file with a state update") { + std::string nmodl_text = R"( + NEURON { + SUFFIX test + NONSPECIFIC_CURRENT i + RANGE x0, x1 + } + + STATE { + x + } + + ASSIGNED { + v + x0 + x1 + } + + BREAKPOINT { + SOLVE states METHOD cnexp + i = 0 + } + + DERIVATIVE states { + x = (x0 - x) / x1 + } + )"; + + + NmodlDriver driver; + const auto& ast = driver.parse_string(nmodl_text); + + // Run passes on the AST to generate LLVM. 
+ SymtabVisitor().visit_program(*ast); + NeuronSolveVisitor().visit_program(*ast); + SolveBlockVisitor().visit_program(*ast); + codegen::CodegenLLVMVisitor llvm_visitor(/*mod_filename=*/"unknown", + /*output_dir=*/".", + /*opt_passes=*/false, + /*use_single_precision=*/false, + /*vector_width=*/1); + llvm_visitor.visit_program(*ast); + llvm_visitor.wrap_kernel_functions(); + + // Create the instance struct data. + int num_elements = 4; + const auto& generated_instance_struct = llvm_visitor.get_instance_struct_ptr(); + auto codegen_data = codegen::CodegenDataHelper(ast, generated_instance_struct); + auto instance_data = codegen_data.create_data(num_elements, /*seed=*/1); + + // Fill the instance struct data with some values. + std::vector x = {1.0, 2.0, 3.0, 4.0}; + std::vector x0 = {5.0, 5.0, 5.0, 5.0}; + std::vector x1 = {1.0, 1.0, 1.0, 1.0}; + + InstanceTestInfo instance_info{&instance_data, + llvm_visitor.get_instance_var_helper(), + num_elements}; + initialise_instance_variable(instance_info, x, "x"); + initialise_instance_variable(instance_info, x0, "x0"); + initialise_instance_variable(instance_info, x1, "x1"); + + // Set up the JIT runner. + std::unique_ptr module = llvm_visitor.get_module(); + TestRunner runner(std::move(module)); + runner.initialize_driver(); + + THEN("Values in struct have changed according to the formula") { + runner.run_with_argument("__nrn_state_test_wrapper", + instance_data.base_ptr); + std::vector x_expected = {4.0, 3.0, 2.0, 1.0}; + REQUIRE(check_instance_variable(instance_info, x_expected, "x")); + } + } +} + +//============================================================================= +// State vectorised kernel with optimisations on. 
+//============================================================================= + +SCENARIO("Simple vectorised kernel", "[llvm][runner]") { + GIVEN("Simple MOD file with a state update") { + std::string nmodl_text = R"( + NEURON { + SUFFIX test + NONSPECIFIC_CURRENT i + RANGE x0, x1 + } + + STATE { + x y + } + + ASSIGNED { + v + x0 + x1 + } + + BREAKPOINT { + SOLVE states METHOD cnexp + i = 0 + } + + DERIVATIVE states { + x = (x0 - x) / x1 + y = v + } + )"; + + + NmodlDriver driver; + const auto& ast = driver.parse_string(nmodl_text); + + // Run passes on the AST to generate LLVM. + SymtabVisitor().visit_program(*ast); + NeuronSolveVisitor().visit_program(*ast); + SolveBlockVisitor().visit_program(*ast); + codegen::CodegenLLVMVisitor llvm_visitor(/*mod_filename=*/"unknown", + /*output_dir=*/".", + /*opt_passes=*/true, + /*use_single_precision=*/false, + /*vector_width=*/4); + llvm_visitor.visit_program(*ast); + llvm_visitor.wrap_kernel_functions(); + + // Create the instance struct data. + int num_elements = 10; + const auto& generated_instance_struct = llvm_visitor.get_instance_struct_ptr(); + auto codegen_data = codegen::CodegenDataHelper(ast, generated_instance_struct); + auto instance_data = codegen_data.create_data(num_elements, /*seed=*/1); + + // Fill the instance struct data with some values for unit testing. 
+ std::vector x = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0}; + std::vector x0 = {11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0}; + std::vector x1 = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + + std::vector voltage = {3.0, 4.0, 7.0, 1.0, 2.0, 5.0, 8.0, 6.0, 10.0, 9.0}; + std::vector node_index = {3, 4, 0, 1, 5, 7, 2, 6, 9, 8}; + + InstanceTestInfo instance_info{&instance_data, + llvm_visitor.get_instance_var_helper(), + num_elements}; + initialise_instance_variable(instance_info, x, "x"); + initialise_instance_variable(instance_info, x0, "x0"); + initialise_instance_variable(instance_info, x1, "x1"); + + initialise_instance_variable(instance_info, voltage, "voltage"); + initialise_instance_variable(instance_info, node_index, "node_index"); + + // Set up the JIT runner. + std::unique_ptr module = llvm_visitor.get_module(); + TestRunner runner(std::move(module)); + runner.initialize_driver(); + + THEN("Values in struct have changed according to the formula") { + runner.run_with_argument("__nrn_state_test_wrapper", + instance_data.base_ptr); + // Check that the main and remainder loops correctly change the data stored in x. + std::vector x_expected = {10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0}; + REQUIRE(check_instance_variable(instance_info, x_expected, "x")); + + // Check that the gather load produces correct results in y: + // y[id] = voltage[node_index[id]] + std::vector y_expected = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0}; + REQUIRE(check_instance_variable(instance_info, y_expected, "y")); + } + } +} + +//============================================================================= +// Vectorised kernel with ion writes. 
+//============================================================================= + +SCENARIO("Vectorised kernel with scatter instruction", "[llvm][runner]") { + GIVEN("Simple MOD file with ion writes") { + std::string nmodl_text = R"( + NEURON { + SUFFIX test + USEION ca WRITE cai + } + + BREAKPOINT { + SOLVE states METHOD cnexp + } + + DERIVATIVE states { + : increment cai to test scatter + cai = cai + 1 + } + )"; + + + NmodlDriver driver; + const auto& ast = driver.parse_string(nmodl_text); + + // Run passes on the AST to generate LLVM. + SymtabVisitor().visit_program(*ast); + NeuronSolveVisitor().visit_program(*ast); + SolveBlockVisitor().visit_program(*ast); + codegen::CodegenLLVMVisitor llvm_visitor(/*mod_filename=*/"unknown", + /*output_dir=*/".", + /*opt_passes=*/false, + /*use_single_precision=*/false, + /*vector_width=*/2); + llvm_visitor.visit_program(*ast); + llvm_visitor.wrap_kernel_functions(); + + // Create the instance struct data. + int num_elements = 5; + const auto& generated_instance_struct = llvm_visitor.get_instance_struct_ptr(); + auto codegen_data = codegen::CodegenDataHelper(ast, generated_instance_struct); + auto instance_data = codegen_data.create_data(num_elements, /*seed=*/1); + + // Fill the instance struct data with some values. + std::vector cai = {1.0, 2.0, 3.0, 4.0, 5.0}; + std::vector ion_cai = {1.0, 2.0, 3.0, 4.0, 5.0}; + std::vector ion_cai_index = {4, 2, 3, 0, 1}; + + InstanceTestInfo instance_info{&instance_data, + llvm_visitor.get_instance_var_helper(), + num_elements}; + initialise_instance_variable(instance_info, cai, "cai"); + initialise_instance_variable(instance_info, ion_cai, "ion_cai"); + initialise_instance_variable(instance_info, ion_cai_index, "ion_cai_index"); + + // Set up the JIT runner. 
+ std::unique_ptr module = llvm_visitor.get_module(); + TestRunner runner(std::move(module)); + runner.initialize_driver(); + + THEN("Ion values in struct have been updated correctly") { + runner.run_with_argument("__nrn_state_test_wrapper", + instance_data.base_ptr); + // cai[id] = ion_cai[ion_cai_index[id]] + // cai[id] += 1 + std::vector cai_expected = {6.0, 4.0, 5.0, 2.0, 3.0}; + REQUIRE(check_instance_variable(instance_info, cai_expected, "cai")); + + // ion_cai[ion_cai_index[id]] = cai[id] + std::vector ion_cai_expected = {2.0, 3.0, 4.0, 5.0, 6.0}; + REQUIRE(check_instance_variable(instance_info, ion_cai_expected, "ion_cai")); + } + } +} + +//============================================================================= +// Vectorised kernel with control flow. +//============================================================================= + +SCENARIO("Vectorised kernel with simple control flow", "[llvm][runner]") { + GIVEN("Simple MOD file with if statement") { + std::string nmodl_text = R"( + NEURON { + SUFFIX test + } + + STATE { + w x y z + } + + BREAKPOINT { + SOLVE states METHOD cnexp + } + + DERIVATIVE states { + IF (v > 0) { + w = v * w + } + + IF (x < 0) { + x = 7 + } + + IF (0 <= y && y < 10 || z == 0) { + y = 2 * y + } ELSE { + z = z - y + } + + } + )"; + + + NmodlDriver driver; + const auto& ast = driver.parse_string(nmodl_text); + + // Run passes on the AST to generate LLVM. + SymtabVisitor().visit_program(*ast); + NeuronSolveVisitor().visit_program(*ast); + SolveBlockVisitor().visit_program(*ast); + codegen::CodegenLLVMVisitor llvm_visitor(/*mod_filename=*/"unknown", + /*output_dir=*/".", + /*opt_passes=*/false, + /*use_single_precision=*/false, + /*vector_width=*/2); + llvm_visitor.visit_program(*ast); + llvm_visitor.wrap_kernel_functions(); + + // Create the instance struct data. 
+ int num_elements = 5; + const auto& generated_instance_struct = llvm_visitor.get_instance_struct_ptr(); + auto codegen_data = codegen::CodegenDataHelper(ast, generated_instance_struct); + auto instance_data = codegen_data.create_data(num_elements, /*seed=*/1); + + // Fill the instance struct data with some values. + std::vector x = {-1.0, 2.0, -3.0, 4.0, -5.0}; + std::vector y = {11.0, 2.0, -3.0, 4.0, 100.0}; + std::vector z = {0.0, 1.0, 20.0, 0.0, 40.0}; + + std::vector w = {10.0, 20.0, 30.0, 40.0, 50.0}; + std::vector voltage = {-1.0, 2.0, -1.0, 2.0, -1.0}; + std::vector node_index = {1, 2, 3, 4, 0}; + + InstanceTestInfo instance_info{&instance_data, + llvm_visitor.get_instance_var_helper(), + num_elements}; + initialise_instance_variable(instance_info, w, "w"); + initialise_instance_variable(instance_info, voltage, "voltage"); + initialise_instance_variable(instance_info, node_index, "node_index"); + + initialise_instance_variable(instance_info, x, "x"); + initialise_instance_variable(instance_info, y, "y"); + initialise_instance_variable(instance_info, z, "z"); + + // Set up the JIT runner. 
+ std::unique_ptr module = llvm_visitor.get_module(); + TestRunner runner(std::move(module)); + runner.initialize_driver(); + + THEN("Masked instructions are generated") { + runner.run_with_argument("__nrn_state_test_wrapper", + instance_data.base_ptr); + std::vector w_expected = {20.0, 20.0, 60.0, 40.0, 50.0}; + REQUIRE(check_instance_variable(instance_info, w_expected, "w")); + + std::vector x_expected = {7.0, 2.0, 7.0, 4.0, 7.0}; + REQUIRE(check_instance_variable(instance_info, x_expected, "x")); + + std::vector y_expected = {22.0, 4.0, -3.0, 8.0, 100.0}; + std::vector z_expected = {0.0, 1.0, 23.0, 0.0, -60.0}; + REQUIRE(check_instance_variable(instance_info, y_expected, "y")); + REQUIRE(check_instance_variable(instance_info, z_expected, "z")); + } + } +} diff --git a/test/unit/codegen/codegen_llvm_instance_struct.cpp b/test/unit/codegen/codegen_llvm_instance_struct.cpp new file mode 100644 index 0000000000..e77b6844ae --- /dev/null +++ b/test/unit/codegen/codegen_llvm_instance_struct.cpp @@ -0,0 +1,177 @@ +/************************************************************************* + * Copyright (C) 2018-2020 Blue Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. 
+ *************************************************************************/ + +#include + +#include "ast/all.hpp" +#include "ast/program.hpp" +#include "codegen/llvm/codegen_llvm_visitor.hpp" +#include "codegen_data_helper.hpp" +#include "parser/nmodl_driver.hpp" +#include "visitors/checkparent_visitor.hpp" +#include "visitors/neuron_solve_visitor.hpp" +#include "visitors/solve_block_visitor.hpp" +#include "visitors/symtab_visitor.hpp" + +using namespace nmodl; +using namespace codegen; +using namespace visitor; +using nmodl::parser::NmodlDriver; + +//============================================================================= +// Utility to get initialized Struct Instance data +//============================================================================= + +codegen::CodegenInstanceData generate_instance_data(const std::string& text, + bool opt = false, + bool use_single_precision = false, + int vector_width = 1, + size_t num_elements = 100, + size_t seed = 1) { + NmodlDriver driver; + const auto& ast = driver.parse_string(text); + + // Generate full AST and solve the BREAKPOINT block to be able to generate the Instance Struct + SymtabVisitor().visit_program(*ast); + NeuronSolveVisitor().visit_program(*ast); + + codegen::CodegenLLVMVisitor llvm_visitor(/*mod_filename=*/"test", + /*output_dir=*/".", + opt, + use_single_precision, + vector_width); + llvm_visitor.visit_program(*ast); + llvm_visitor.dump_module(); + const auto& generated_instance_struct = llvm_visitor.get_instance_struct_ptr(); + auto codegen_data = codegen::CodegenDataHelper(ast, generated_instance_struct); + auto instance_data = codegen_data.create_data(num_elements, seed); + return instance_data; +} + +template +bool compare(void* instance_struct_data_ptr, const std::vector& generated_data) { + std::vector instance_struct_vector; + std::cout << "Generated data size: " << generated_data.size() << std::endl; + instance_struct_vector.assign(static_cast(instance_struct_data_ptr), + 
static_cast(instance_struct_data_ptr) + + generated_data.size()); + for (auto value: instance_struct_vector) { + std::cout << value << std::endl; + } + return instance_struct_vector == generated_data; +} + +//============================================================================= +// Simple Instance Struct creation +//============================================================================= + +SCENARIO("Instance Struct creation", "[visitor][llvm][instance_struct]") { + GIVEN("Instantiate simple Instance Struct") { + std::string nmodl_text = R"( + NEURON { + SUFFIX test + USEION na READ ena + RANGE minf, mtau + } + + STATE { + m + } + + ASSIGNED { + v (mV) + celsius (degC) + ena (mV) + minf + mtau + } + + BREAKPOINT { + SOLVE states METHOD cnexp + } + + DERIVATIVE states { + m' = (minf-m)/mtau + } + )"; + + + THEN("instance struct elements are properly initialized") { + const size_t num_elements = 10; + constexpr static double seed = 42; + auto instance_data = generate_instance_data(nmodl_text, + /*opt=*/false, + /*use_single_precision=*/true, + /*vector_width*/ 1, + num_elements, + seed); + size_t minf_index = 0; + size_t mtau_index = 1; + size_t m_index = 2; + size_t Dm_index = 3; + size_t ena_index = 4; + size_t v_unused_index = 5; + size_t g_unused_index = 6; + size_t ion_ena_index = 7; + size_t ion_ena_index_index = 8; + size_t voltage_index = 9; + size_t node_index_index = 10; + size_t t_index = 11; + size_t dt_index = 12; + size_t celsius_index = 13; + size_t secondorder_index = 14; + size_t node_count_index = 15; + // Check if the various instance struct fields are properly initialized + REQUIRE(compare(instance_data.members[minf_index], + generate_dummy_data(minf_index, num_elements))); + REQUIRE(compare(instance_data.members[ena_index], + generate_dummy_data(ena_index, num_elements))); + REQUIRE(compare(instance_data.members[ion_ena_index], + generate_dummy_data(ion_ena_index, num_elements))); + // index variables are offsets, they start from 0 + 
REQUIRE(compare(instance_data.members[ion_ena_index_index], + generate_dummy_data(0, num_elements))); + REQUIRE(compare(instance_data.members[node_index_index], + generate_dummy_data(0, num_elements))); + + REQUIRE(*static_cast(instance_data.members[t_index]) == + default_nthread_t_value); + REQUIRE(*static_cast(instance_data.members[node_count_index]) == num_elements); + + // Hard code TestInstanceType struct + struct TestInstanceType { + double* minf; + double* mtau; + double* m; + double* Dm; + double* ena; + double* v_unused; + double* g_unused; + double* ion_ena; + int* ion_ena_index; + double* voltage; + int* node_index; + double t; + double dt; + double celsius; + int secondorder; + int node_count; + }; + // Test if TestInstanceType struct is properly initialized + // Cast void ptr instance_data.base_ptr to TestInstanceType* + TestInstanceType* instance = (TestInstanceType*) instance_data.base_ptr; + REQUIRE(compare(instance->minf, generate_dummy_data(minf_index, num_elements))); + REQUIRE(compare(instance->ena, generate_dummy_data(ena_index, num_elements))); + REQUIRE(compare(instance->ion_ena, + generate_dummy_data(ion_ena_index, num_elements))); + REQUIRE(compare(instance->node_index, generate_dummy_data(0, num_elements))); + REQUIRE(instance->t == default_nthread_t_value); + REQUIRE(instance->celsius == default_celsius_value); + REQUIRE(instance->secondorder == default_second_order_value); + } + } +} diff --git a/test/unit/codegen/codegen_llvm_ir.cpp b/test/unit/codegen/codegen_llvm_ir.cpp new file mode 100644 index 0000000000..fa0a649f2d --- /dev/null +++ b/test/unit/codegen/codegen_llvm_ir.cpp @@ -0,0 +1,1527 @@ +/************************************************************************* + * Copyright (C) 2018-2020 Blue Brain Project + * + * This file is part of NMODL distributed under the terms of the GNU + * Lesser General Public License. See top-level LICENSE file for details. 
+ *************************************************************************/ + +#include +#include + +#include "test/unit/utils/test_utils.hpp" + +#include "ast/program.hpp" +#include "ast/statement_block.hpp" +#include "codegen/llvm/codegen_llvm_helper_visitor.hpp" +#include "codegen/llvm/codegen_llvm_visitor.hpp" +#include "parser/nmodl_driver.hpp" +#include "visitors/checkparent_visitor.hpp" +#include "visitors/inline_visitor.hpp" +#include "visitors/neuron_solve_visitor.hpp" +#include "visitors/solve_block_visitor.hpp" +#include "visitors/symtab_visitor.hpp" +#include "visitors/visitor_utils.hpp" + +using namespace nmodl; +using namespace codegen; +using namespace visitor; + +using namespace test_utils; + +using nmodl::parser::NmodlDriver; + +//============================================================================= +// Utility to get LLVM module as a string +//============================================================================= + +std::string run_llvm_visitor(const std::string& text, + bool opt = false, + bool use_single_precision = false, + int vector_width = 1, + std::string vec_lib = "none", + std::vector fast_math_flags = {}, + bool nmodl_inline = false) { + NmodlDriver driver; + const auto& ast = driver.parse_string(text); + + SymtabVisitor().visit_program(*ast); + if (nmodl_inline) { + InlineVisitor().visit_program(*ast); + } + NeuronSolveVisitor().visit_program(*ast); + SolveBlockVisitor().visit_program(*ast); + + codegen::CodegenLLVMVisitor llvm_visitor(/*mod_filename=*/"unknown", + /*output_dir=*/".", + opt, + use_single_precision, + vector_width, + vec_lib, + /*add_debug_information=*/false, + fast_math_flags); + + llvm_visitor.visit_program(*ast); + return llvm_visitor.dump_module(); +} + +//============================================================================= +// Utility to get specific NMODL AST nodes +//============================================================================= + +std::vector> run_llvm_visitor_helper( + 
const std::string& text, + int vector_width, + const std::vector& nodes_to_collect) { + NmodlDriver driver; + const auto& ast = driver.parse_string(text); + + SymtabVisitor().visit_program(*ast); + SolveBlockVisitor().visit_program(*ast); + CodegenLLVMHelperVisitor(vector_width).visit_program(*ast); + + const auto& nodes = collect_nodes(*ast, nodes_to_collect); + + return nodes; +} + +//============================================================================= +// BinaryExpression and Double +//============================================================================= + +SCENARIO("Binary expression", "[visitor][llvm]") { + GIVEN("Procedure with addition of its arguments") { + std::string nmodl_text = R"( + PROCEDURE add(a, b) { + LOCAL i + i = a + b + } + )"; + + THEN("variables are loaded and add instruction is created") { + std::string module_string = + run_llvm_visitor(nmodl_text, /*opt=*/false, /*use_single_precision=*/true); + std::smatch m; + + std::regex rhs(R"(%1 = load float, float\* %b)"); + std::regex lhs(R"(%2 = load float, float\* %a)"); + std::regex res(R"(%3 = fadd float %2, %1)"); + + // Check the float values are loaded correctly and added. + REQUIRE(std::regex_search(module_string, m, rhs)); + REQUIRE(std::regex_search(module_string, m, lhs)); + REQUIRE(std::regex_search(module_string, m, res)); + } + } + + GIVEN("Procedure with multiple binary operators") { + std::string nmodl_text = R"( + PROCEDURE multiple(a, b) { + LOCAL i + i = (a - b) / (a + b) + } + )"; + + THEN("variables are processed from rhs first") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check rhs. + std::regex rr(R"(%1 = load double, double\* %b)"); + std::regex rl(R"(%2 = load double, double\* %a)"); + std::regex x(R"(%3 = fadd double %2, %1)"); + REQUIRE(std::regex_search(module_string, m, rr)); + REQUIRE(std::regex_search(module_string, m, rl)); + REQUIRE(std::regex_search(module_string, m, x)); + + // Check lhs. 
+ std::regex lr(R"(%4 = load double, double\* %b)"); + std::regex ll(R"(%5 = load double, double\* %a)"); + std::regex y(R"(%6 = fsub double %5, %4)"); + REQUIRE(std::regex_search(module_string, m, lr)); + REQUIRE(std::regex_search(module_string, m, ll)); + REQUIRE(std::regex_search(module_string, m, y)); + + // Check result. + std::regex res(R"(%7 = fdiv double %6, %3)"); + REQUIRE(std::regex_search(module_string, m, res)); + } + } + + GIVEN("Procedure with assignment") { + std::string nmodl_text = R"( + PROCEDURE assignment() { + LOCAL i + i = 2 + } + )"; + + THEN("double constant is stored into i") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check store immediate is created. + std::regex allocation(R"(%i = alloca double)"); + std::regex assignment(R"(store double 2.0*e\+00, double\* %i)"); + REQUIRE(std::regex_search(module_string, m, allocation)); + REQUIRE(std::regex_search(module_string, m, assignment)); + } + } + + GIVEN("Function with power operator") { + std::string nmodl_text = R"( + FUNCTION power() { + LOCAL i, j + i = 2 + j = 4 + power = i ^ j + } + )"; + + THEN("'pow' intrinsic is created") { + std::string module_string = + run_llvm_visitor(nmodl_text, /*opt=*/false, /*use_single_precision=*/true); + std::smatch m; + + // Check 'pow' intrinsic. 
+ std::regex declaration(R"(declare float @llvm\.pow\.f32\(float, float\))"); + std::regex pow(R"(call float @llvm\.pow\.f32\(float %.*, float %.*\))"); + REQUIRE(std::regex_search(module_string, m, declaration)); + REQUIRE(std::regex_search(module_string, m, pow)); + } + } +} + +//============================================================================= +// Define +//============================================================================= + +SCENARIO("Define", "[visitor][llvm]") { + GIVEN("Procedure with array variable of length specified by DEFINE") { + std::string nmodl_text = R"( + DEFINE N 100 + + PROCEDURE foo() { + LOCAL x[N] + } + )"; + + THEN("macro is expanded and array is allocated") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check stack allocations for i and j + std::regex array(R"(%x = alloca \[100 x double\])"); + REQUIRE(std::regex_search(module_string, m, array)); + } + } +} + +//============================================================================= +// If/Else statements and comparison operators +//============================================================================= + +SCENARIO("Comparison", "[visitor][llvm]") { + GIVEN("Procedure with comparison operators") { + std::string nmodl_text = R"( + PROCEDURE foo(x) { + if (x < 10) { + + } else if (x >= 10 && x <= 100) { + + } else if (x == 120) { + + } else if (!(x != 200)) { + + } + } + )"; + + THEN("correct LLVM instructions are produced") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check less than. + std::regex lt(R"(fcmp olt double %(.+), 1\.000000e\+01)"); + REQUIRE(std::regex_search(module_string, m, lt)); + + // Check greater or equal than and logical and. 
+ std::regex ge(R"(fcmp ole double %(.+), 1\.000000e\+02)"); + std::regex logical_and(R"(and i1 %(.+), %(.+))"); + REQUIRE(std::regex_search(module_string, m, ge)); + REQUIRE(std::regex_search(module_string, m, logical_and)); + + // Check equals. + std::regex eq(R"(fcmp oeq double %(.+), 1\.200000e\+02)"); + REQUIRE(std::regex_search(module_string, m, eq)); + + // Check not equals. + std::regex ne(R"(fcmp one double %(.+), 2\.000000e\+02)"); + REQUIRE(std::regex_search(module_string, m, ne)); + } + } +} + +SCENARIO("If/Else", "[visitor][llvm]") { + GIVEN("Function with only if statement") { + std::string nmodl_text = R"( + FUNCTION foo(y) { + LOCAL x + x = 100 + if (y == 20) { + x = 20 + } + foo = x + y + } + )"; + + THEN("correct LLVM instructions are produced") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + std::regex cond_br( + "br i1 %2, label %3, label %4\n" + "\n" + "3:.*\n" + " store double 2\\.000000e\\+01, double\\* %x.*\n" + " br label %4\n" + "\n" + "4:"); + REQUIRE(std::regex_search(module_string, m, cond_br)); + } + } + + GIVEN("Function with both if and else statements") { + std::string nmodl_text = R"( + FUNCTION sign(x) { + LOCAL s + if (x < 0) { + s = -1 + } else { + s = 1 + } + sign = s + } + )"; + + THEN("correct LLVM instructions are produced") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + std::regex if_else_br( + "br i1 %2, label %3, label %4\n" + "\n" + "3:.*\n" + " store double -1\\.000000e\\+00, double\\* %s.*\n" + " br label %5\n" + "\n" + "4:.*\n" + " store double 1\\.000000e\\+00, double\\* %s.*\n" + " br label %5\n" + "\n" + "5:"); + REQUIRE(std::regex_search(module_string, m, if_else_br)); + } + } + + GIVEN("Function with both if and else if statements") { + std::string nmodl_text = R"( + FUNCTION bar(x) { + LOCAL s + s = -1 + if (x <= 0) { + s = 0 + } else if (0 < x && x <= 1) { + s = 1 + } + bar = s + } + )"; + + THEN("correct LLVM instructions are produced") { + 
std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + std::regex if_else_if( + "br i1 %2, label %3, label %4\n" + "\n" + "3:.*\n" + " .*\n" + " br label %12\n" + "\n" + "4:.*\n" + " .*\n" + " .*\n" + " .*\n" + " .*\n" + " %.+ = and i1 %.+, %.+\n" + " br i1 %.+, label %10, label %11\n" + "\n" + "10:.*\n" + " .*\n" + " br label %11\n" + "\n" + "11:.*\n" + " br label %12\n" + "\n" + "12:"); + REQUIRE(std::regex_search(module_string, m, if_else_if)); + } + } + + GIVEN("Function with if, else if anf else statements") { + std::string nmodl_text = R"( + FUNCTION bar(x) { + LOCAL s + if (x <= 0) { + s = 0 + } else if (0 < x && x <= 1) { + s = 1 + } else { + s = 100 + } + bar = s + } + )"; + + THEN("correct LLVM instructions are produced") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + std::regex if_else_if_else( + "br i1 %2, label %3, label %4\n" + "\n" + "3:.*\n" + " .*\n" + " br label %13\n" + "\n" + "4:.*\n" + " .*\n" + " .*\n" + " .*\n" + " .*\n" + " %9 = and i1 %.+, %.+\n" + " br i1 %9, label %10, label %11\n" + "\n" + "10:.*\n" + " .*\n" + " br label %12\n" + "\n" + "11:.*\n" + " .*\n" + " br label %12\n" + "\n" + "12:.*\n" + " br label %13\n" + "\n" + "13:"); + REQUIRE(std::regex_search(module_string, m, if_else_if_else)); + } + } +} + +//============================================================================= +// FunctionBlock +//============================================================================= + +SCENARIO("Function", "[visitor][llvm]") { + GIVEN("Simple function with arguments") { + std::string nmodl_text = R"( + FUNCTION foo(x) { + foo = x + } + )"; + + THEN("function is produced with arguments allocated on stack and a return instruction") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check function signature. The return type should be the default double type. 
+ std::regex function_signature(R"(define double @foo\(double %x[0-9].*\) \{)"); + REQUIRE(std::regex_search(module_string, m, function_signature)); + + // Check that function arguments are allocated on the local stack. + std::regex alloca_instr(R"(%x = alloca double)"); + std::regex store_instr(R"(store double %x[0-9].*, double\* %x)"); + REQUIRE(std::regex_search(module_string, m, alloca_instr)); + REQUIRE(std::regex_search(module_string, m, store_instr)); + + // Check the return variable has also been allocated. + std::regex ret_instr(R"(%ret_foo = alloca double)"); + + // Check that the return value has been loaded and passed to terminator. + std::regex loaded(R"(%2 = load double, double\* %ret_foo)"); + std::regex terminator(R"(ret double %2)"); + REQUIRE(std::regex_search(module_string, m, loaded)); + REQUIRE(std::regex_search(module_string, m, terminator)); + } + } +} + +//============================================================================= +// FunctionCall +//============================================================================= + +SCENARIO("Function call", "[visitor][llvm]") { + GIVEN("A call to procedure") { + std::string nmodl_text = R"( + PROCEDURE bar() {} + FUNCTION foo() { + bar() + } + )"; + + THEN("an int call instruction is created") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check for call instruction. + std::regex call(R"(call i32 @bar\(\))"); + REQUIRE(std::regex_search(module_string, m, call)); + } + } + + GIVEN("A call to function declared below the caller") { + std::string nmodl_text = R"( + FUNCTION foo(x) { + foo = 4 * bar() + } + FUNCTION bar() { + bar = 5 + } + )"; + + THEN("a correct call instruction is created") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check for call instruction. 
+ std::regex call(R"(%[0-9]+ = call double @bar\(\))"); + REQUIRE(std::regex_search(module_string, m, call)); + } + } + + GIVEN("A call to function with arguments") { + std::string nmodl_text = R"( + FUNCTION foo(x, y) { + foo = 4 * x - y + } + FUNCTION bar(i) { + bar = foo(i, 4) + } + )"; + + THEN("arguments are processed before the call and passed to call instruction") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check correct arguments. + std::regex i(R"(%1 = load double, double\* %i)"); + std::regex call(R"(call double @foo\(double %1, double 4.000000e\+00\))"); + REQUIRE(std::regex_search(module_string, m, i)); + REQUIRE(std::regex_search(module_string, m, call)); + } + } + + GIVEN("A call to external method") { + std::string nmodl_text = R"( + FUNCTION nmodl_ceil(x) { + nmodl_ceil = ceil(x) + } + + FUNCTION nmodl_cos(x) { + nmodl_cos = cos(x) + } + + FUNCTION nmodl_exp(x) { + nmodl_exp = exp(x) + } + + FUNCTION nmodl_fabs(x) { + nmodl_fabs = fabs(x) + } + + FUNCTION nmodl_floor(x) { + nmodl_floor = floor(x) + } + + FUNCTION nmodl_log(x) { + nmodl_log = log(x) + } + + FUNCTION nmodl_log10(x) { + nmodl_log10 = log10(x) + } + + FUNCTION nmodl_pow(x, y) { + nmodl_pow = pow(x, y) + } + + FUNCTION nmodl_sin(x) { + nmodl_sin = sin(x) + } + + FUNCTION nmodl_sqrt(x) { + nmodl_sqrt = sqrt(x) + } + )"; + + THEN("LLVM intrinsic corresponding to this method is created") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check for intrinsic declarations. 
+ std::regex ceil(R"(declare double @llvm\.ceil\.f64\(double\))"); + std::regex cos(R"(declare double @llvm\.cos\.f64\(double\))"); + std::regex exp(R"(declare double @llvm\.exp\.f64\(double\))"); + std::regex fabs(R"(declare double @llvm\.fabs\.f64\(double\))"); + std::regex floor(R"(declare double @llvm\.floor\.f64\(double\))"); + std::regex log(R"(declare double @llvm\.log\.f64\(double\))"); + std::regex log10(R"(declare double @llvm\.log10\.f64\(double\))"); + std::regex pow(R"(declare double @llvm\.pow\.f64\(double, double\))"); + std::regex sin(R"(declare double @llvm\.sin\.f64\(double\))"); + std::regex sqrt(R"(declare double @llvm\.sqrt\.f64\(double\))"); + REQUIRE(std::regex_search(module_string, m, ceil)); + REQUIRE(std::regex_search(module_string, m, cos)); + REQUIRE(std::regex_search(module_string, m, exp)); + REQUIRE(std::regex_search(module_string, m, fabs)); + REQUIRE(std::regex_search(module_string, m, floor)); + REQUIRE(std::regex_search(module_string, m, log)); + REQUIRE(std::regex_search(module_string, m, log10)); + REQUIRE(std::regex_search(module_string, m, pow)); + REQUIRE(std::regex_search(module_string, m, sin)); + REQUIRE(std::regex_search(module_string, m, sqrt)); + + // Check the correct call is made. 
+ std::regex ceil_call(R"(call double @llvm\.ceil\.f64\(double %[0-9]+\))"); + std::regex cos_call(R"(call double @llvm\.cos\.f64\(double %[0-9]+\))"); + std::regex exp_call(R"(call double @llvm\.exp\.f64\(double %[0-9]+\))"); + std::regex fabs_call(R"(call double @llvm\.fabs\.f64\(double %[0-9]+\))"); + std::regex floor_call(R"(call double @llvm\.floor\.f64\(double %[0-9]+\))"); + std::regex log_call(R"(call double @llvm\.log\.f64\(double %[0-9]+\))"); + std::regex log10_call(R"(call double @llvm\.log10\.f64\(double %[0-9]+\))"); + std::regex pow_call(R"(call double @llvm\.pow\.f64\(double %[0-9]+, double %[0-9]+\))"); + std::regex sin_call(R"(call double @llvm\.sin\.f64\(double %[0-9]+\))"); + std::regex sqrt_call(R"(call double @llvm\.sqrt\.f64\(double %[0-9]+\))"); + REQUIRE(std::regex_search(module_string, m, ceil_call)); + REQUIRE(std::regex_search(module_string, m, cos_call)); + REQUIRE(std::regex_search(module_string, m, exp_call)); + REQUIRE(std::regex_search(module_string, m, fabs_call)); + REQUIRE(std::regex_search(module_string, m, floor_call)); + REQUIRE(std::regex_search(module_string, m, log_call)); + REQUIRE(std::regex_search(module_string, m, log10_call)); + REQUIRE(std::regex_search(module_string, m, pow_call)); + REQUIRE(std::regex_search(module_string, m, sin_call)); + REQUIRE(std::regex_search(module_string, m, sqrt_call)); + } + } + + GIVEN("A call to printf") { + std::string nmodl_text = R"( + PROCEDURE bar() { + LOCAL i + i = 0 + printf("foo") + printf("bar %d", i) + } + )"; + + THEN("printf is declared and global string values are created") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check for global string values. 
+ std::regex str1( + R"(@[0-9]+ = private unnamed_addr constant \[6 x i8\] c\"\\22foo\\22\\00\")"); + std::regex str2( + R"(@[0-9]+ = private unnamed_addr constant \[9 x i8\] c\"\\22bar %d\\22\\00\")"); + REQUIRE(std::regex_search(module_string, m, str1)); + REQUIRE(std::regex_search(module_string, m, str2)); + + // Check for printf declaration. + std::regex declaration(R"(declare i32 @printf\(i8\*, \.\.\.\))"); + REQUIRE(std::regex_search(module_string, m, declaration)); + + // Check the correct calls are made. + std::regex call1( + R"(call i32 \(i8\*, \.\.\.\) @printf\(i8\* getelementptr inbounds \(\[6 x i8\], \[6 x i8\]\* @[0-9]+, i32 0, i32 0\)\))"); + std::regex call2( + R"(call i32 \(i8\*, \.\.\.\) @printf\(i8\* getelementptr inbounds \(\[9 x i8\], \[9 x i8\]\* @[0-9]+, i32 0, i32 0\), double %[0-9]+\))"); + REQUIRE(std::regex_search(module_string, m, call1)); + REQUIRE(std::regex_search(module_string, m, call2)); + } + } + + GIVEN("A call to function with the wrong number of arguments") { + std::string nmodl_text = R"( + FUNCTION foo(x, y) { + foo = 4 * x - y + } + FUNCTION bar(i) { + bar = foo(i) + } + )"; + + THEN("a runtime error is thrown") { + REQUIRE_THROWS_AS(run_llvm_visitor(nmodl_text), std::runtime_error); + } + } +} + +//============================================================================= +// IndexedName +//============================================================================= + +SCENARIO("Indexed name", "[visitor][llvm]") { + GIVEN("Procedure with a local array variable") { + std::string nmodl_text = R"( + PROCEDURE foo() { + LOCAL x[2] + } + )"; + + THEN("array is allocated") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + std::regex array(R"(%x = alloca \[2 x double\])"); + REQUIRE(std::regex_search(module_string, m, array)); + } + } + + GIVEN("Procedure with a local array assignment") { + std::string nmodl_text = R"( + PROCEDURE foo() { + LOCAL x[2] + x[10 - 10] = 1 + x[1] = 3 + } + )"; + + 
THEN("element is stored to the array") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check GEPs are created correctly to get the addresses of array elements. + std::regex GEP1( + R"(%1 = getelementptr inbounds \[2 x double\], \[2 x double\]\* %x, i64 0, i64 0)"); + std::regex GEP2( + R"(%2 = getelementptr inbounds \[2 x double\], \[2 x double\]\* %x, i64 0, i64 1)"); + REQUIRE(std::regex_search(module_string, m, GEP1)); + REQUIRE(std::regex_search(module_string, m, GEP2)); + + // Check the value is stored to the correct addresses. + std::regex store1(R"(store double 1.000000e\+00, double\* %1)"); + std::regex store2(R"(store double 3.000000e\+00, double\* %2)"); + REQUIRE(std::regex_search(module_string, m, store1)); + REQUIRE(std::regex_search(module_string, m, store2)); + } + } + + GIVEN("Procedure with a assignment of array element") { + std::string nmodl_text = R"( + PROCEDURE foo() { + LOCAL x[2], y + x[1] = 3 + y = x[1] + } + )"; + + THEN("array element is stored to the variable") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check GEP is created correctly to pint at array element. + std::regex GEP( + R"(%2 = getelementptr inbounds \[2 x double\], \[2 x double\]\* %x, i64 0, i64 1)"); + REQUIRE(std::regex_search(module_string, m, GEP)); + + // Check the value is loaded from the pointer. + std::regex load(R"(%3 = load double, double\* %2)"); + REQUIRE(std::regex_search(module_string, m, load)); + + // Check the value is stored to the the variable. 
+ std::regex store(R"(store double %3, double\* %y)"); + REQUIRE(std::regex_search(module_string, m, store)); + } + } +} + +//============================================================================= +// LocalList and LocalVar +//============================================================================= + +SCENARIO("Local variable", "[visitor][llvm]") { + GIVEN("Procedure with some local variables") { + std::string nmodl_text = R"( + PROCEDURE local() { + LOCAL i, j + } + )"; + + THEN("local variables are allocated on the stack") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check stack allocations for i and j + std::regex i(R"(%i = alloca double)"); + std::regex j(R"(%j = alloca double)"); + REQUIRE(std::regex_search(module_string, m, i)); + REQUIRE(std::regex_search(module_string, m, j)); + } + } +} + +//============================================================================= +// ProcedureBlock +//============================================================================= + +SCENARIO("Procedure", "[visitor][llvm]") { + GIVEN("Empty procedure with no arguments") { + std::string nmodl_text = R"( + PROCEDURE empty() {} + )"; + + THEN("a function returning 0 integer is produced") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check procedure has empty body with a dummy 0 allocation. 
+ std::regex signature(R"(define i32 @empty)"); + std::regex alloc(R"(%ret_empty = alloca i32)"); + std::regex store(R"(store i32 0, i32\* %ret_empty)"); + std::regex load(R"(%1 = load i32, i32\* %ret_empty)"); + std::regex ret(R"(ret i32 %1)"); + REQUIRE(std::regex_search(module_string, m, signature)); + REQUIRE(std::regex_search(module_string, m, alloc)); + REQUIRE(std::regex_search(module_string, m, store)); + REQUIRE(std::regex_search(module_string, m, load)); + REQUIRE(std::regex_search(module_string, m, ret)); + } + } + + GIVEN("Empty procedure with arguments") { + std::string nmodl_text = R"( + PROCEDURE with_argument(x) {} + )"; + + THEN("int function is produced with arguments allocated on stack") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check procedure signature. + std::regex function_signature(R"(define i32 @with_argument\(double %x[0-9].*\) \{)"); + REQUIRE(std::regex_search(module_string, m, function_signature)); + + // Check dummy return. + std::regex dummy_alloca(R"(%ret_with_argument = alloca i32)"); + std::regex dummy_store(R"(store i32 0, i32\* %ret_with_argument)"); + std::regex dummy_load(R"(%1 = load i32, i32\* %ret_with_argument)"); + std::regex ret(R"(ret i32 %1)"); + REQUIRE(std::regex_search(module_string, m, dummy_alloca)); + REQUIRE(std::regex_search(module_string, m, dummy_store)); + REQUIRE(std::regex_search(module_string, m, dummy_load)); + REQUIRE(std::regex_search(module_string, m, ret)); + + // Check that procedure arguments are allocated on the local stack. 
+ std::regex alloca_instr(R"(%x = alloca double)"); + std::regex store_instr(R"(store double %x[0-9].*, double\* %x)"); + REQUIRE(std::regex_search(module_string, m, alloca_instr)); + REQUIRE(std::regex_search(module_string, m, store_instr)); + } + } +} + +//============================================================================= +// UnaryExpression +//============================================================================= + +SCENARIO("Unary expression", "[visitor][llvm]") { + GIVEN("Procedure with negation") { + std::string nmodl_text = R"( + PROCEDURE negation(a) { + LOCAL i + i = -a + } + )"; + + THEN("fneg instruction is created") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + std::regex allocation(R"(%1 = load double, double\* %a)"); + REQUIRE(std::regex_search(module_string, m, allocation)); + + // llvm v9 and llvm v11 implementation for negation + std::regex negation_v9(R"(%2 = fsub double -0.000000e\+00, %1)"); + std::regex negation_v11(R"(fneg double %1)"); + bool result = std::regex_search(module_string, m, negation_v9) || + std::regex_search(module_string, m, negation_v11); + REQUIRE(result == true); + } + } +} + +//============================================================================= +// WhileStatement +//============================================================================= + +SCENARIO("While", "[visitor][llvm]") { + GIVEN("Procedure with a simple while loop") { + std::string nmodl_text = R"( + FUNCTION loop() { + LOCAL i + i = 0 + WHILE (i < 10) { + i = i + 1 + } + loop = 0 + } + )"; + + THEN("correct loop is created") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + std::regex loop( + " br label %1\n" + "\n" + "1:.*\n" + " %2 = load double, double\\* %i.*\n" + " %3 = fcmp olt double %2, 1\\.000000e\\+01\n" + " br i1 %3, label %4, label %7\n" + "\n" + "4:.*\n" + " %5 = load double, double\\* %i.*\n" + " %6 = fadd double %5, 1\\.000000e\\+00\n" + " store 
double %6, double\\* %i.*\n" + " br label %1\n" + "\n" + "7:.*\n" + " store double 0\\.000000e\\+00, double\\* %ret_loop.*\n"); + // Check that 3 blocks are created: header, body and exit blocks. Also, there must be + // a backedge from the body to the header. + REQUIRE(std::regex_search(module_string, m, loop)); + } + } +} + +//============================================================================= +// State scalar kernel +//============================================================================= + +SCENARIO("Scalar state kernel", "[visitor][llvm]") { + GIVEN("A neuron state update") { + std::string nmodl_text = R"( + NEURON { + SUFFIX hh + NONSPECIFIC_CURRENT il + RANGE minf, mtau, gl, el + } + + STATE { + m + } + + ASSIGNED { + v (mV) + minf + mtau (ms) + } + + BREAKPOINT { + SOLVE states METHOD cnexp + il = gl * (v - el) + } + + DERIVATIVE states { + m = (minf-m) / mtau + } + )"; + + THEN("a kernel with instance struct as an argument and a FOR loop is created") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check the struct type with correct attributes and the kernel declaration. + std::regex struct_type( + "%.*__instance_var__type = type \\{ double\\*, double\\*, double\\*, double\\*, " + "double\\*, double\\*, double\\*, i32\\*, double, double, double, i32, i32 \\}"); + std::regex kernel_declaration( + R"(define void @nrn_state_hh\(%.*__instance_var__type\* noalias nocapture readonly .*\) #0)"); + REQUIRE(std::regex_search(module_string, m, struct_type)); + REQUIRE(std::regex_search(module_string, m, kernel_declaration)); + + // Check kernel attributes. + std::regex kernel_attributes(R"(attributes #0 = \{ nofree nounwind \})"); + REQUIRE(std::regex_search(module_string, m, kernel_attributes)); + + // Check for correct variables initialisation and a branch to condition block. 
+ std::regex id_initialisation(R"(%id = alloca i32)"); + std::regex node_id_initialisation(R"(%node_id = alloca i32)"); + std::regex v_initialisation(R"(%v = alloca double)"); + std::regex br(R"(br label %for\.cond)"); + REQUIRE(std::regex_search(module_string, m, id_initialisation)); + REQUIRE(std::regex_search(module_string, m, node_id_initialisation)); + REQUIRE(std::regex_search(module_string, m, v_initialisation)); + REQUIRE(std::regex_search(module_string, m, br)); + + // Check condition block: id < mech->node_count, and a conditional branch to loop body + // or exit. + std::regex condition( + " %.* = load %.*__instance_var__type\\*, %.*__instance_var__type\\*\\* %.*,.*\n" + " %.* = getelementptr inbounds %.*__instance_var__type, " + "%.*__instance_var__type\\* " + "%.*, i32 0, i32 [0-9]+\n" + " %.* = load i32, i32\\* %.*,.*\n" + " %.* = load i32, i32\\* %id,.*\n" + " %.* = icmp slt i32 %.*, %.*"); + std::regex cond_br(R"(br i1 %.*, label %for\.body, label %for\.exit)"); + REQUIRE(std::regex_search(module_string, m, condition)); + REQUIRE(std::regex_search(module_string, m, cond_br)); + + // Check that loop metadata is attached to the scalar kernel. + std::regex loop_metadata(R"(!llvm\.loop !0)"); + std::regex loop_metadata_self_reference(R"(!0 = distinct !\{!0, !1\})"); + std::regex loop_metadata_disable_vectorization( + R"(!1 = !\{!\"llvm\.loop\.vectorize\.enable\", i1 false\})"); + REQUIRE(std::regex_search(module_string, m, loop_metadata)); + REQUIRE(std::regex_search(module_string, m, loop_metadata_self_reference)); + REQUIRE(std::regex_search(module_string, m, loop_metadata_disable_vectorization)); + + // Check for correct loads from the struct with GEPs. 
+ std::regex load_from_struct( + " %.* = load %.*__instance_var__type\\*, %.*__instance_var__type\\*\\* %.*\n" + " %.* = getelementptr inbounds %.*__instance_var__type, " + "%.*__instance_var__type\\* %.*, i32 0, i32 [0-9]+\n" + " %.* = load i32, i32\\* %id,.*\n" + " %.* = sext i32 %.* to i64\n" + " %.* = load (i32|double)\\*, (i32|double)\\*\\* %.*\n" + " %.* = getelementptr inbounds (i32|double), (i32|double)\\* %.*, i64 %.*\n" + " %.* = load (i32|double), (i32|double)\\* %.*"); + REQUIRE(std::regex_search(module_string, m, load_from_struct)); + + // Check induction variable is incremented in increment block. + std::regex increment( + "for.inc:.*\n" + " %.* = load i32, i32\\* %id,.*\n" + " %.* = add i32 %.*, 1\n" + " store i32 %.*, i32\\* %id,.*\n" + " br label %for\\.cond"); + REQUIRE(std::regex_search(module_string, m, increment)); + + // Check exit block. + std::regex exit( + "for\\.exit[0-9]*:.*\n" + " ret void"); + REQUIRE(std::regex_search(module_string, m, exit)); + } + } +} + +//============================================================================= +// Gather for vectorised kernel +//============================================================================= + +SCENARIO("Vectorised simple kernel", "[visitor][llvm]") { + GIVEN("An indirect indexing of voltage") { + std::string nmodl_text = R"( + NEURON { + SUFFIX hh + NONSPECIFIC_CURRENT i + } + + STATE {} + + ASSIGNED { + v (mV) + } + + BREAKPOINT { + SOLVE states METHOD cnexp + i = 2 + } + + DERIVATIVE states {} + )"; + + THEN("a gather instructions is created") { + std::string module_string = run_llvm_visitor(nmodl_text, + /*opt=*/false, + /*use_single_precision=*/false, + /*vector_width=*/4); + std::smatch m; + + // Check that no loop metadata is attached. + std::regex loop_metadata(R"(!llvm\.loop !.*)"); + REQUIRE(!std::regex_search(module_string, m, loop_metadata)); + + // Check gather intrinsic is correctly declared. 
+ std::regex declaration( + R"(declare <4 x double> @llvm\.masked\.gather\.v4f64\.v4p0f64\(<4 x double\*>, i32 immarg, <4 x i1>, <4 x double>\) )"); + REQUIRE(std::regex_search(module_string, m, declaration)); + + // Check that the indices vector is created correctly and extended to i64. + std::regex index_load(R"(load <4 x i32>, <4 x i32>\* %node_id)"); + std::regex sext(R"(sext <4 x i32> %.* to <4 x i64>)"); + REQUIRE(std::regex_search(module_string, m, index_load)); + REQUIRE(std::regex_search(module_string, m, sext)); + + // Check that the access to `voltage` is performed via gather instruction. + // v = mech->voltage[node_id] + std::regex gather( + "call <4 x double> @llvm\\.masked\\.gather\\.v4f64\\.v4p0f64\\(" + "<4 x double\\*> %.*, i32 1, <4 x i1> , <4 x " + "double> undef\\)"); + REQUIRE(std::regex_search(module_string, m, gather)); + } + } +} + +//============================================================================= +// Scatter for vectorised kernel +//============================================================================= + +SCENARIO("Vectorised simple kernel with ion writes", "[visitor][llvm]") { + GIVEN("An indirect indexing of ca ion") { + std::string nmodl_text = R"( + NEURON { + SUFFIX hh + USEION ca WRITE cai + } + + BREAKPOINT { + SOLVE states METHOD cnexp + } + + DERIVATIVE states {} + )"; + + THEN("a scatter instructions is created") { + std::string module_string = run_llvm_visitor(nmodl_text, + /*opt=*/false, + /*use_single_precision=*/false, + /*vector_width=*/4); + std::smatch m; + + // Check scatter intrinsic is correctly declared. + std::regex declaration( + R"(declare void @llvm\.masked\.scatter\.v4f64\.v4p0f64\(<4 x double>, <4 x double\*>, i32 immarg, <4 x i1>\))"); + REQUIRE(std::regex_search(module_string, m, declaration)); + + // Check that the indices vector is created correctly and extended to i64. 
+ std::regex index_load(R"(load <4 x i32>, <4 x i32>\* %ion_cai_id)"); + std::regex sext(R"(sext <4 x i32> %.* to <4 x i64>)"); + REQUIRE(std::regex_search(module_string, m, index_load)); + REQUIRE(std::regex_search(module_string, m, sext)); + + // Check that store to `ion_cai` is performed via scatter instruction. + // ion_cai[ion_cai_id] = cai[id] + std::regex scatter( + "call void @llvm\\.masked\\.scatter\\.v4f64\\.v4p0f64\\(<4 x double> %.*, <4 x " + "double\\*> %.*, i32 1, <4 x i1> \\)"); + REQUIRE(std::regex_search(module_string, m, scatter)); + } + } +} + +//============================================================================= +// Vectorised kernel with simple control flow +//============================================================================= + +SCENARIO("Vectorised simple kernel with control flow", "[visitor][llvm]") { + GIVEN("A single if/else statement") { + std::string nmodl_text = R"( + NEURON { + SUFFIX test + } + + STATE { + y + } + + BREAKPOINT { + SOLVE states METHOD cnexp + } + + DERIVATIVE states { + IF (y < 0) { + y = y + 7 + } ELSE { + y = v + } + } + )"; + + THEN("masked load and stores are created") { + std::string module_string = run_llvm_visitor(nmodl_text, + /*opt=*/false, + /*use_single_precision=*/true, + /*vector_width=*/8); + std::smatch m; + + // Check masked load/store intrinsics are correctly declared. + std::regex masked_load( + R"(declare <8 x float> @llvm\.masked\.load\.v8f32\.p0v8f32\(<8 x float>\*, i32 immarg, <8 x i1>, <8 x float>\))"); + std::regex masked_store( + R"(declare void @llvm.masked\.store\.v8f32\.p0v8f32\(<8 x float>, <8 x float>\*, i32 immarg, <8 x i1>\))"); + REQUIRE(std::regex_search(module_string, m, masked_load)); + REQUIRE(std::regex_search(module_string, m, masked_store)); + + // Check true direction instructions are predicated with mask. 
+ // IF (mech->y[id] < 0) { + // mech->y[id] = mech->y[id] + 7 + std::regex mask(R"(%30 = fcmp olt <8 x float> %.*, zeroinitializer)"); + std::regex true_load( + R"(call <8 x float> @llvm\.masked\.load\.v8f32\.p0v8f32\(<8 x float>\* %.*, i32 1, <8 x i1> %30, <8 x float> undef\))"); + std::regex true_store( + R"(call void @llvm\.masked\.store\.v8f32\.p0v8f32\(<8 x float> %.*, <8 x float>\* %.*, i32 1, <8 x i1> %30\))"); + REQUIRE(std::regex_search(module_string, m, mask)); + REQUIRE(std::regex_search(module_string, m, true_load)); + REQUIRE(std::regex_search(module_string, m, true_store)); + + // Check false direction instructions are predicated with inverted mask. + // } ELSE { + // mech->y[id] = v + // } + std::regex inverted_mask( + R"(%47 = xor <8 x i1> %30, )"); + std::regex false_load( + R"(call <8 x float> @llvm\.masked\.load\.v8f32\.p0v8f32\(<8 x float>\* %v, i32 1, <8 x i1> %47, <8 x float> undef\))"); + std::regex false_store( + R"(call void @llvm\.masked\.store\.v8f32\.p0v8f32\(<8 x float> %.*, <8 x float>\* %.*, i32 1, <8 x i1> %47\))"); + } + } +} + +//============================================================================= +// Derivative block : test optimization +//============================================================================= + +SCENARIO("Scalar derivative block", "[visitor][llvm][derivative]") { + GIVEN("After LLVM helper visitor transformations") { + std::string nmodl_text = R"( + NEURON { + SUFFIX hh + NONSPECIFIC_CURRENT il + RANGE minf, mtau + } + STATE { + m + } + ASSIGNED { + v (mV) + minf + mtau (ms) + } + BREAKPOINT { + SOLVE states METHOD cnexp + il = 2 + } + DERIVATIVE states { + m = (minf-m)/mtau + } + )"; + + std::string expected_loop = R"( + for(id = 0; idnode_count; id = id+1) { + node_id = mech->node_index[id] + v = mech->voltage[node_id] + mech->m[id] = (mech->minf[id]-mech->m[id])/mech->mtau[id] + })"; + + THEN("a single scalar loops is constructed") { + auto result = run_llvm_visitor_helper(nmodl_text, + 
/*vector_width=*/1, + {ast::AstNodeType::CODEGEN_FOR_STATEMENT}); + REQUIRE(result.size() == 1); + + auto main_loop = reindent_text(to_nmodl(result[0])); + REQUIRE(main_loop == reindent_text(expected_loop)); + } + } +} + +SCENARIO("Vectorised derivative block", "[visitor][llvm][derivative]") { + GIVEN("After LLVM helper visitor transformations") { + std::string nmodl_text = R"( + NEURON { + SUFFIX hh + NONSPECIFIC_CURRENT il + RANGE minf, mtau + } + STATE { + m + } + ASSIGNED { + v (mV) + minf + mtau (ms) + } + BREAKPOINT { + SOLVE states METHOD cnexp + il = 2 + } + DERIVATIVE states { + m = (minf-m)/mtau + } + )"; + + std::string expected_main_loop = R"( + for(id = 0; idnode_count-7; id = id+8) { + node_id = mech->node_index[id] + v = mech->voltage[node_id] + mech->m[id] = (mech->minf[id]-mech->m[id])/mech->mtau[id] + })"; + std::string expected_epilogue_loop = R"( + for(; idnode_count; id = id+1) { + epilogue_node_id = mech->node_index[id] + epilogue_v = mech->voltage[epilogue_node_id] + mech->m[id] = (mech->minf[id]-mech->m[id])/mech->mtau[id] + })"; + + + THEN("vector and epilogue scalar loops are constructed") { + auto result = run_llvm_visitor_helper(nmodl_text, + /*vector_width=*/8, + {ast::AstNodeType::CODEGEN_FOR_STATEMENT}); + REQUIRE(result.size() == 2); + + auto main_loop = reindent_text(to_nmodl(result[0])); + REQUIRE(main_loop == reindent_text(expected_main_loop)); + + auto epilogue_loop = reindent_text(to_nmodl(result[1])); + REQUIRE(epilogue_loop == reindent_text(expected_epilogue_loop)); + } + } +} + +//============================================================================= +// Vector library calls. 
+//============================================================================= + +SCENARIO("Vector library calls", "[visitor][llvm][vector_lib]") { + GIVEN("A vector LLVM intrinsic") { + std::string nmodl_text = R"( + NEURON { + SUFFIX hh + NONSPECIFIC_CURRENT il + } + STATE { + m + } + ASSIGNED { + v (mV) + } + BREAKPOINT { + SOLVE states METHOD cnexp + il = 2 + } + DERIVATIVE states { + m = exp(m) + } + )"; + + THEN("it is replaced with an appropriate vector library call") { + std::smatch m; + + // Check exponential intrinsic is created. + std::string no_library_module_str = run_llvm_visitor(nmodl_text, + /*opt=*/false, + /*use_single_precision=*/false, + /*vector_width=*/2); + std::regex exp_decl(R"(declare <2 x double> @llvm\.exp\.v2f64\(<2 x double>\))"); + std::regex exp_call(R"(call <2 x double> @llvm\.exp\.v2f64\(<2 x double> .*\))"); + REQUIRE(std::regex_search(no_library_module_str, m, exp_decl)); + REQUIRE(std::regex_search(no_library_module_str, m, exp_call)); + +#if LLVM_VERSION_MAJOR >= 13 + // Check exponential calls are replaced with calls to SVML library. + std::string svml_library_module_str = run_llvm_visitor(nmodl_text, + /*opt=*/false, + /*use_single_precision=*/false, + /*vector_width=*/2, + /*vec_lib=*/"SVML"); + std::regex svml_exp_decl(R"(declare <2 x double> @__svml_exp2\(<2 x double>\))"); + std::regex svml_exp_call(R"(call <2 x double> @__svml_exp2\(<2 x double> .*\))"); + REQUIRE(std::regex_search(svml_library_module_str, m, svml_exp_decl)); + REQUIRE(std::regex_search(svml_library_module_str, m, svml_exp_call)); + REQUIRE(!std::regex_search(svml_library_module_str, m, exp_call)); + + // Check that supported exponential calls are replaced with calls to MASSV library (i.e. + // operating on vector of width 2). 
+ std::string massv2_library_module_str = run_llvm_visitor(nmodl_text, + /*opt=*/false, + /*use_single_precision=*/false, + /*vector_width=*/2, + /*vec_lib=*/"MASSV"); + std::regex massv2_exp_decl(R"(declare <2 x double> @__expd2_P8\(<2 x double>\))"); + std::regex massv2_exp_call(R"(call <2 x double> @__expd2_P8\(<2 x double> .*\))"); + REQUIRE(std::regex_search(massv2_library_module_str, m, massv2_exp_decl)); + REQUIRE(std::regex_search(massv2_library_module_str, m, massv2_exp_call)); + REQUIRE(!std::regex_search(massv2_library_module_str, m, exp_call)); + + // Check no replacement for MASSV happens for non-supported vector widths. + std::string massv4_library_module_str = run_llvm_visitor(nmodl_text, + /*opt=*/false, + /*use_single_precision=*/false, + /*vector_width=*/4, + /*vec_lib=*/"MASSV"); + std::regex exp4_call(R"(call <4 x double> @llvm\.exp\.v4f64\(<4 x double> .*\))"); + REQUIRE(std::regex_search(massv4_library_module_str, m, exp4_call)); + + // Check correct replacement of @llvm.exp.v4f32 into @vexpf when using Accelerate. + std::string accelerate_library_module_str = + run_llvm_visitor(nmodl_text, + /*opt=*/false, + /*use_single_precision=*/true, + /*vector_width=*/4, + /*vec_lib=*/"Accelerate"); + std::regex accelerate_exp_decl(R"(declare <4 x float> @vexpf\(<4 x float>\))"); + std::regex accelerate_exp_call(R"(call <4 x float> @vexpf\(<4 x float> .*\))"); + std::regex fexp_call(R"(call <4 x float> @llvm\.exp\.v4f32\(<4 x float> .*\))"); + REQUIRE(std::regex_search(accelerate_library_module_str, m, accelerate_exp_decl)); + REQUIRE(std::regex_search(accelerate_library_module_str, m, accelerate_exp_call)); + REQUIRE(!std::regex_search(accelerate_library_module_str, m, fexp_call)); + + // Check correct replacement of @llvm.exp.v2f64 into @_ZGV?N?v_exp when using SLEEF. 
+ std::string sleef_library_module_str = run_llvm_visitor(nmodl_text, + /*opt=*/false, + /*use_single_precision=*/false, + /*vector_width=*/2, + /*vec_lib=*/"SLEEF"); +#if defined(__arm64__) || defined(__aarch64__) + std::regex sleef_exp_decl(R"(declare <2 x double> @_ZGVnN2v_exp\(<2 x double>\))"); + std::regex sleef_exp_call(R"(call <2 x double> @_ZGVnN2v_exp\(<2 x double> .*\))"); +#else + std::regex sleef_exp_decl(R"(declare <2 x double> @_ZGVbN2v_exp\(<2 x double>\))"); + std::regex sleef_exp_call(R"(call <2 x double> @_ZGVbN2v_exp\(<2 x double> .*\))"); +#endif + REQUIRE(std::regex_search(sleef_library_module_str, m, sleef_exp_decl)); + REQUIRE(std::regex_search(sleef_library_module_str, m, sleef_exp_call)); + REQUIRE(!std::regex_search(sleef_library_module_str, m, fexp_call)); + + // Check the replacements when using Darwin's libsystem_m. + std::string libsystem_m_library_module_str = + run_llvm_visitor(nmodl_text, + /*opt=*/false, + /*use_single_precision=*/true, + /*vector_width=*/4, + /*vec_lib=*/"libsystem_m"); + std::regex libsystem_m_exp_decl(R"(declare <4 x float> @_simd_exp_f4\(<4 x float>\))"); + std::regex libsystem_m_exp_call(R"(call <4 x float> @_simd_exp_f4\(<4 x float> .*\))"); + REQUIRE(std::regex_search(libsystem_m_library_module_str, m, libsystem_m_exp_decl)); + REQUIRE(std::regex_search(libsystem_m_library_module_str, m, libsystem_m_exp_call)); + REQUIRE(!std::regex_search(libsystem_m_library_module_str, m, fexp_call)); +#endif + } + } +} + +//============================================================================= +// Fast math flags +//============================================================================= + +SCENARIO("Fast math flags", "[visitor][llvm]") { + GIVEN("A function to produce fma and specified math flags") { + std::string nmodl_text = R"( + FUNCTION foo(a, b, c) { + foo = (a * b) + c + } + )"; + + THEN("instructions are generated with the flags set") { + std::string module_string = + run_llvm_visitor(nmodl_text, + 
/*opt=*/true, + /*use_single_precision=*/false, + /*vector_width=*/1, + /*vec_lib=*/"none", + /*fast_math_flags=*/{"nnan", "contract", "afn"}); + std::smatch m; + + // Check flags for produced 'fmul' and 'fadd' instructions. + std::regex fmul(R"(fmul nnan contract afn double %.*, %.*)"); + std::regex fadd(R"(fadd nnan contract afn double %.*, %.*)"); + REQUIRE(std::regex_search(module_string, m, fmul)); + REQUIRE(std::regex_search(module_string, m, fadd)); + } + } +} + +//============================================================================= +// Optimization : dead code removal +//============================================================================= + +SCENARIO("Dead code removal", "[visitor][llvm][opt]") { + GIVEN("Procedure using local variables, without any side effects") { + std::string nmodl_text = R"( + PROCEDURE add(a, b) { + LOCAL i + i = a + b + } + )"; + + THEN("with optimisation enabled, all ops are eliminated") { + std::string module_string = run_llvm_visitor(nmodl_text, true); + std::smatch m; + + // Check if the values are optimised out + std::regex empty_proc( + R"(define i32 @add\(double %a[0-9].*, double %b[0-9].*\) \{\n(\s)*ret i32 0\n\})"); + REQUIRE(std::regex_search(module_string, m, empty_proc)); + } + } +} + +//============================================================================= +// Inlining: remove inline code blocks +//============================================================================= + +SCENARIO("Removal of inlined functions and procedures", "[visitor][llvm][inline]") { + GIVEN("Simple breakpoint block calling a function and a procedure") { + std::string nmodl_text = R"( + NEURON { + SUFFIX test_inline + RANGE a, b, s + } + ASSIGNED { + a + b + s + } + PROCEDURE test_add(a, b) { + LOCAL i + i = a + b + } + FUNCTION test_sub(a, b) { + test_sub = a - b + } + BREAKPOINT { + SOLVE states METHOD cnexp + } + DERIVATIVE states { + a = 1 + b = 2 + test_add(a, b) + s = test_sub(a, b) + } + )"; + + THEN("when the 
code is inlined the procedure and function blocks are removed") { + std::string module_string = run_llvm_visitor(nmodl_text, + /*opt=*/false, + /*use_single_precision=*/false, + /*vector_width=*/1, + /*vec_lib=*/"none", + /*fast_math_flags=*/{}, + /*nmodl_inline=*/true); + std::smatch m; + + // Check if the procedure and function declarations are removed + std::regex add_proc(R"(define i32 @test_add\(double %a[0-9].*, double %b[0-9].*\))"); + REQUIRE(!std::regex_search(module_string, m, add_proc)); + std::regex sub_func(R"(define double @test_sub\(double %a[0-9].*, double %b[0-9].*\))"); + REQUIRE(!std::regex_search(module_string, m, sub_func)); + } + } +}