Skip to content

Commit

Permalink
[LLVM][GPU] Added CUDADriver to execute benchmark on GPU (#829)
Browse files Browse the repository at this point in the history
- Added CUDADriver to compile LLVM IR string generated from CodegenLLVMVisitor to PTX string and then execute it using CUDA API
- Ability to select the compilation GPU architecture and then set the proper GPU architecture based on the GPU that is going to be used
- Link `libdevice` math library with GPU LLVM module
- Handles kernel and wrapper functions attributes properly for GPU execution (wrapper function is `kernel` and kernel attribute is `device`)
- Small fixes in InstanceStruct declaration and setup to allocate the pointer variables properly, including the shadow variables
- Adds tests in the CI that run small benchmarks in CPU and GPU on BB5
- Adds replacement of `log` math function for SLEEF and libdevice, `pow` and `fabs` for libdevice
- Adds GPU execution ability in PyJIT
- Small improvement in PyJIT benchmark python script to handle arguments and GPU execution
- Separated benchmark info from benchmark driver
- Added hh and expsyn mod files in benchmarking tests
  • Loading branch information
iomaganaris committed May 12, 2022
1 parent 1c85efc commit be30888
Show file tree
Hide file tree
Showing 29 changed files with 1,045 additions and 125 deletions.
37 changes: 25 additions & 12 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ trigger cvf:
variables:
SPACK_PACKAGE: nmodl
SPACK_PACKAGE_SPEC: ~legacy-unit+python+llvm
SPACK_EXTRA_MODULES: llvm
SPACK_INSTALL_EXTRA_FLAGS: -v

spack_setup:
Expand All @@ -45,14 +44,6 @@ build:intel:
variables:
SPACK_PACKAGE_COMPILER: intel

build:gcc:
extends:
- .spack_build
- .spack_nmodl
variables:
SPACK_PACKAGE_COMPILER: gcc
SPACK_PACKAGE_DEPENDENCIES: ^bison%gcc^flex%gcc^py-jinja2%gcc^py-sympy%gcc^py-pyyaml%gcc

.nmodl_tests:
variables:
# https://github.com/BlueBrain/nmodl/issues/737
Expand All @@ -64,8 +55,30 @@ test:intel:
- .nmodl_tests
needs: ["build:intel"]

test:gcc:
.benchmark_config:
variables:
bb5_ntasks: 1
bb5_cpus_per_task: 1
bb5_memory: 16G
bb5_exclusive: full
bb5_constraint: gpu_32g # CascadeLake CPU & V100 GPU node

.build_allocation:
variables:
bb5_ntasks: 2 # so we block 16 cores
bb5_cpus_per_task: 8 # ninja -j {this}
bb5_memory: 76G # ~16*384/80

build_cuda:gcc:
extends: [.spack_build, .build_allocation]
variables:
SPACK_PACKAGE: nmodl
SPACK_PACKAGE_SPEC: ~legacy-unit+python+llvm+llvm_cuda
SPACK_INSTALL_EXTRA_FLAGS: -v
SPACK_PACKAGE_COMPILER: gcc

test_benchmark:gcc:
extends:
- .benchmark_config
- .ctest
- .nmodl_tests
needs: ["build:gcc"]
needs: ["build_cuda:gcc"]
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ option(NMODL_ENABLE_PYTHON_BINDINGS "Enable pybind11 based python bindings" ON)
option(NMODL_ENABLE_LEGACY_UNITS "Use original faraday, R, etc. instead of 2019 nist constants" OFF)
option(NMODL_ENABLE_LLVM "Enable LLVM based code generation" ON)
option(NMODL_ENABLE_LLVM_GPU "Enable LLVM based GPU code generation" ON)
option(NMODL_ENABLE_LLVM_CUDA "Enable LLVM CUDA backend to run GPU benchmark" OFF)
option(NMODL_ENABLE_JIT_EVENT_LISTENERS "Enable JITEventListener for Perf and Vtune" OFF)

if(NMODL_ENABLE_LEGACY_UNITS)
Expand Down Expand Up @@ -162,6 +163,7 @@ if(NMODL_ENABLE_LLVM)
if(NMODL_ENABLE_LLVM_CUDA)
enable_language(CUDA)
find_package(CUDAToolkit)
include_directories(${CUDAToolkit_INCLUDE_DIRS})
add_definitions(-DNMODL_LLVM_CUDA_BACKEND)
endif()
endif()
Expand Down
25 changes: 24 additions & 1 deletion INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ To build the project from source, a modern C++ compiler with C++14 support is ne

- flex (>=2.6)
- bison (>=3.0)
- CMake (>=3.15)
- CMake (>=3.17)
- Python (>=3.6)
- Python packages : jinja2 (>=2.10), pyyaml (>=3.13), pytest (>=4.0.0), sympy (>=1.3), textwrap

Expand Down Expand Up @@ -141,6 +141,29 @@ export NMODL_WRAPLIB=/opt/nmodl/lib/libpywrapper.so
**Note**: In order for all unit tests to function correctly when building without linking against libpython we must
set `NMODL_PYLIB` before running cmake!

### Using CUDA backend to run benchmarks

`NMODL` supports generating code and compiling it for execution on an `NVIDIA` GPU via its benchmark infrastructure using the `LLVM` backend. To enable the `CUDA` backend to compile and execute the GPU code we need to set the following `CMake` flag during compilation of `NMODL`:
```
-DNMODL_ENABLE_LLVM_CUDA=ON
```

To find the need `CUDA` libraries (`cudart` and `nvrtc`) it's needed to have CUDA Toolkit installed on your system.
This can be done by installing the CUDA Toolkit from the [CUDA Toolkit website](https://developer.nvidia.com/cuda-downloads) or by installing the `CUDA` spack package and loading the corresponding module.

Then given a supported MOD file you can execute the benchmark on GPU in you supported NVIDIA GPU by running the following command:
```
./bin/nmodl <file>.mod llvm --no-debug --ir --opt-level-ir 3 gpu --target-arch "sm_80" --name "nvptx64" --math-library libdevice benchmark --run --libs "${CUDA_ROOT}/nvvm/libdevice/libdevice.10.bc" --opt-level-codegen 3 --instance-size 10000000 --repeat 2 --grid-dim-x 4096 --block-dim-x 256
```
The above command executes the benchmark on a GPU with `Compute Architecture` `sm_80` and links the generated code to the `libdevice` optimized math library provided by `NVIDIA`.
Using the above command you can also select the optimization level of the generated code, the instance size of the generated data, the number of repetitions and the grid and block dimensions for the GPU execution.

**Note**: In order for the CUDA backend to be able to compile and execute the generated code on GPU the CUDA Toolkit version installed needs to have the same version as the `CUDA` installed by the NVIDIA driver in the system that will be used to run the benchmark.
You can find the CUDA Toolkit version by running the following command:
```
nvidia-smi
```
and noting the `CUDA Version` stated there. For example if `CUDA Version` reported by `nvidia-smi` is CUDA 11.4 you need to install the `CUDA Toolkit 11.4.*` to be able to compile and execute the GPU code.

## Testing the Installed Module

Expand Down
3 changes: 0 additions & 3 deletions src/codegen/codegen_driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ struct CodeGenConfig {
/// true if cuda code to be generated
bool cuda_backend = false;

/// true if llvm code to be generated
bool llvm_backend = false;

/// true if sympy should be used for solving ODEs analytically
bool sympy_analytic = false;

Expand Down
18 changes: 11 additions & 7 deletions src/codegen/llvm/codegen_llvm_helper_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,6 @@ std::shared_ptr<ast::InstanceStruct> CodegenLLVMHelperVisitor::create_instance_s
add_var_with_type(VOLTAGE_VAR, FLOAT_TYPE, /*is_pointer=*/1);
add_var_with_type(NODE_INDEX_VAR, INTEGER_TYPE, /*is_pointer=*/1);

// add dt, t, celsius
add_var_with_type(naming::NTHREAD_T_VARIABLE, FLOAT_TYPE, /*is_pointer=*/0);
add_var_with_type(naming::NTHREAD_DT_VARIABLE, FLOAT_TYPE, /*is_pointer=*/0);
add_var_with_type(naming::CELSIUS_VARIABLE, FLOAT_TYPE, /*is_pointer=*/0);
add_var_with_type(naming::SECOND_ORDER_VARIABLE, INTEGER_TYPE, /*is_pointer=*/0);
add_var_with_type(naming::MECH_NODECOUNT_VAR, INTEGER_TYPE, /*is_pointer=*/0);

// As we do not have `NrnThread` object as an argument, we store points to rhs
// and d to in the instance struct as well. Also need their respective shadow variables
// in case of point process mechanism.
Expand All @@ -256,6 +249,17 @@ std::shared_ptr<ast::InstanceStruct> CodegenLLVMHelperVisitor::create_instance_s
add_var_with_type(naming::NTHREAD_RHS_SHADOW, FLOAT_TYPE, /*is_pointer=*/1);
add_var_with_type(naming::NTHREAD_D_SHADOW, FLOAT_TYPE, /*is_pointer=*/1);

// NOTE: All the pointer variables should be declared before the scalar variables otherwise
// the allocation of memory for the variables in the InstanceStruct and their offsets will be
// wrong

// add dt, t, celsius
add_var_with_type(naming::NTHREAD_T_VARIABLE, FLOAT_TYPE, /*is_pointer=*/0);
add_var_with_type(naming::NTHREAD_DT_VARIABLE, FLOAT_TYPE, /*is_pointer=*/0);
add_var_with_type(naming::CELSIUS_VARIABLE, FLOAT_TYPE, /*is_pointer=*/0);
add_var_with_type(naming::SECOND_ORDER_VARIABLE, INTEGER_TYPE, /*is_pointer=*/0);
add_var_with_type(naming::MECH_NODECOUNT_VAR, INTEGER_TYPE, /*is_pointer=*/0);

return std::make_shared<ast::InstanceStruct>(codegen_vars);
}

Expand Down
78 changes: 61 additions & 17 deletions src/codegen/llvm/codegen_llvm_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ namespace codegen {
/* Helper routines */
/****************************************************************************************/

static std::string get_wrapper_name(const std::string& kernel_name) {
return "__" + kernel_name + "_wrapper";
}

/// A utility to check for supported Statement AST nodes.
static bool is_supported_statement(const ast::Statement& statement) {
return statement.is_codegen_atomic_statement() || statement.is_codegen_for_statement() ||
Expand Down Expand Up @@ -55,15 +59,36 @@ static bool can_vectorize(const ast::CodegenForStatement& statement, symtab::Sym
return unsupported.empty() && supported.size() <= 1;
}

void CodegenLLVMVisitor::annotate_kernel_with_nvvm(llvm::Function* kernel) {
void CodegenLLVMVisitor::annotate_kernel_with_nvvm(llvm::Function* kernel,
const std::string& annotation = "kernel") {
llvm::Metadata* metadata[] = {llvm::ValueAsMetadata::get(kernel),
llvm::MDString::get(*context, "kernel"),
llvm::MDString::get(*context, annotation),
llvm::ValueAsMetadata::get(
llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), 1))};
llvm::MDNode* node = llvm::MDNode::get(*context, metadata);
module->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(node);
}

void CodegenLLVMVisitor::annotate_wrapper_kernels_with_nvvm() {
// First clear all the nvvm annotations from the module
auto module_named_metadata = module->getNamedMetadata("nvvm.annotations");
module->eraseNamedMetadata(module_named_metadata);

// Then each kernel should be annotated as "device" function and wrappers should be annotated as
// "kernel" functions
std::vector<std::string> kernel_names;
find_kernel_names(kernel_names);

for (const auto& kernel_name: kernel_names) {
// Get the kernel function.
auto kernel = module->getFunction(kernel_name);
// Get the kernel wrapper function.
auto kernel_wrapper = module->getFunction(get_wrapper_name(kernel_name));
annotate_kernel_with_nvvm(kernel, "device");
annotate_kernel_with_nvvm(kernel_wrapper, "kernel");
}
}

llvm::Value* CodegenLLVMVisitor::accept_and_get(const std::shared_ptr<ast::Node>& node) {
node->accept(*this);
return ir_builder.pop_last_value();
Expand Down Expand Up @@ -402,12 +427,17 @@ void CodegenLLVMVisitor::wrap_kernel_functions() {
auto kernel = module->getFunction(kernel_name);

// Create a wrapper void function that takes a void pointer as a single argument.
llvm::Type* i32_type = ir_builder.get_i32_type();
llvm::Type* return_type;
if (platform.is_gpu()) {
return_type = ir_builder.get_void_type();
} else {
return_type = ir_builder.get_i32_type();
}
llvm::Type* void_ptr_type = ir_builder.get_i8_ptr_type();
llvm::Function* wrapper_func = llvm::Function::Create(
llvm::FunctionType::get(i32_type, {void_ptr_type}, /*isVarArg=*/false),
llvm::FunctionType::get(return_type, {void_ptr_type}, /*isVarArg=*/false),
llvm::Function::ExternalLinkage,
"__" + kernel_name + "_wrapper",
get_wrapper_name(kernel_name),
*module);

// Optionally, add debug information for the wrapper function.
Expand All @@ -425,9 +455,23 @@ void CodegenLLVMVisitor::wrap_kernel_functions() {
args.push_back(bitcasted);
ir_builder.create_function_call(kernel, args, /*use_result=*/false);

// Create a 0 return value and a return instruction.
ir_builder.create_i32_constant(0);
ir_builder.create_return(ir_builder.pop_last_value());
// create return instructions and annotate wrapper with certain attributes depending on
// the backend type
if (platform.is_gpu()) {
// return void
ir_builder.create_return();
} else {
// Create a 0 return value and a return instruction.
ir_builder.create_i32_constant(0);
ir_builder.create_return(ir_builder.pop_last_value());
ir_builder.set_function(wrapper_func);
ir_builder.set_kernel_attributes();
}
ir_builder.clear_function();
}
// for GPU we need to first clear all the annotations and then reapply them
if (platform.is_gpu()) {
annotate_wrapper_kernels_with_nvvm();
}
}

Expand Down Expand Up @@ -823,9 +867,6 @@ void CodegenLLVMVisitor::visit_program(const ast::Program& node) {

// Handle GPU optimizations (CUDA platfroms only for now).
if (platform.is_gpu()) {
if (!platform.is_CUDA_gpu())
throw std::runtime_error("Error: unsupported GPU architecture!\n");

// We only support CUDA backends anyway, so this works for now.
utils::initialise_nvptx_passes();

Expand All @@ -839,15 +880,12 @@ void CodegenLLVMVisitor::visit_program(const ast::Program& node) {
logger->debug("Dumping generated IR...\n" + dump_module());
}

// If the output directory is specified, save the IR to .ll file.
if (output_dir != ".") {
utils::save_ir_to_ll_file(*module, output_dir + "/" + mod_filename);
}

// Setup CodegenHelper for C++ wrapper file
setup(node);
// Print C++ wrapper file
print_wrapper_routines();
print_target_file();
// Print LLVM IR module to <mod_filename>.ll file
utils::save_ir_to_ll_file(*module, output_dir + "/" + mod_filename);
}

void CodegenLLVMVisitor::print_mechanism_range_var_structure() {
Expand Down Expand Up @@ -960,6 +998,12 @@ void CodegenLLVMVisitor::print_instance_variable_setup() {
// Pass ml->nodeindices pointer to node_index
printer->add_line("inst->node_index = ml->nodeindices;");

// Setup rhs, d and their shadow vectors
printer->add_line(fmt::format("inst->{} = nt->_actual_rhs;", naming::NTHREAD_RHS));
printer->add_line(fmt::format("inst->{} = nt->_actual_d;", naming::NTHREAD_D));
printer->add_line(fmt::format("inst->{} = nt->_shadow_rhs;", naming::NTHREAD_RHS_SHADOW));
printer->add_line(fmt::format("inst->{} = nt->_shadow_d;", naming::NTHREAD_D_SHADOW));

// Setup global variables
printer->add_line("inst->{0} = nt->{0};"_format(naming::NTHREAD_T_VARIABLE));
printer->add_line("inst->{0} = nt->{0};"_format(naming::NTHREAD_DT_VARIABLE));
Expand Down
12 changes: 6 additions & 6 deletions src/codegen/llvm/codegen_llvm_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,6 @@ class CodegenLLVMVisitor: public CodegenCVisitor {
return str;
}

void print_target_file() const {
target_printer->add_multi_line(dump_module());
}

/// Fills the container with the names of kernel functions from the MOD file.
void find_kernel_names(std::vector<std::string>& container);

Expand Down Expand Up @@ -303,8 +299,12 @@ class CodegenLLVMVisitor: public CodegenCVisitor {
void print_compute_functions() override;

private:
// Annotates kernel function with NVVM metadata.
void annotate_kernel_with_nvvm(llvm::Function* kernel);
/// Annotates kernel function with NVVM metadata.
void annotate_kernel_with_nvvm(llvm::Function* kernel, const std::string& annotation);

/// Handles NVVM function annotations when we create the wrapper functions. All original kernels
/// should be "device" functions and wrappers "kernel" functions
void annotate_wrapper_kernels_with_nvvm();

/// Accepts the given AST node and returns the processed value.
llvm::Value* accept_and_get(const std::shared_ptr<ast::Node>& node);
Expand Down
35 changes: 24 additions & 11 deletions src/codegen/llvm/llvm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,9 @@ void initialise_nvptx_passes() {
initialise_optimisation_passes();
}

void optimise_module_for_nvptx(codegen::Platform& platform,
llvm::Module& module,
int opt_level,
std::string& target_asm) {
std::unique_ptr<llvm::TargetMachine> create_CUDA_target_machine(const codegen::Platform& platform,
llvm::Module& module) {
// CUDA target machine we generating code for.
std::unique_ptr<llvm::TargetMachine> tm;
std::string platform_name = platform.get_name();

// Target and layout information.
Expand Down Expand Up @@ -111,9 +108,30 @@ void optimise_module_for_nvptx(codegen::Platform& platform,
if (!target)
throw std::runtime_error("Error: " + error_msg + "\n");

std::unique_ptr<llvm::TargetMachine> tm;
tm.reset(target->createTargetMachine(triple, subtarget, features, {}, {}));
if (!tm)
throw std::runtime_error("Error: creating target machine failed! Aborting.");
return tm;
}

std::string get_module_ptx(llvm::TargetMachine& tm, llvm::Module& module) {
std::string target_asm;
llvm::raw_string_ostream stream(target_asm);
llvm::buffer_ostream pstream(stream);
llvm::legacy::PassManager codegen_pm;

tm.addPassesToEmitFile(codegen_pm, pstream, nullptr, llvm::CGFT_AssemblyFile);
codegen_pm.run(module);
return target_asm;
}

void optimise_module_for_nvptx(const codegen::Platform& platform,
llvm::Module& module,
int opt_level,
std::string& target_asm) {
// Create target machine for CUDA GPU
auto tm = create_CUDA_target_machine(platform, module);

// Create pass managers.
llvm::legacy::FunctionPassManager func_pm(&module);
Expand All @@ -137,12 +155,7 @@ void optimise_module_for_nvptx(codegen::Platform& platform,

// Now, we want to run target-specific (e.g. NVPTX) passes. In LLVM, this
// is done via `addPassesToEmitFile`.
llvm::raw_string_ostream stream(target_asm);
llvm::buffer_ostream pstream(stream);
llvm::legacy::PassManager codegen_pm;

tm->addPassesToEmitFile(codegen_pm, pstream, nullptr, llvm::CGFT_AssemblyFile);
codegen_pm.run(module);
target_asm = get_module_ptx(*tm, module);
}

void initialise_optimisation_passes() {
Expand Down
9 changes: 8 additions & 1 deletion src/codegen/llvm/llvm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,18 @@ void initialise_optimisation_passes();
/// Initialises NVPTX-specific optimisation passes.
void initialise_nvptx_passes();

//// Initializes a CUDA target machine
std::unique_ptr<llvm::TargetMachine> create_CUDA_target_machine(const codegen::Platform& platform,
llvm::Module& module);

/// Generate PTX code given a CUDA target machine and the module
std::string get_module_ptx(llvm::TargetMachine& tm, llvm::Module& module);

/// Replaces calls to LLVM intrinsics with appropriate library calls.
void replace_with_lib_functions(codegen::Platform& platform, llvm::Module& module);

/// Optimises the given LLVM IR module for NVPTX targets.
void optimise_module_for_nvptx(codegen::Platform& platform,
void optimise_module_for_nvptx(const codegen::Platform& platform,
llvm::Module& module,
int opt_level,
std::string& target_asm);
Expand Down
Loading

0 comments on commit be30888

Please sign in to comment.