diff --git a/keopscore/keopscore/binders/nvrtc/Gpu_link_compile.py b/keopscore/keopscore/binders/nvrtc/Gpu_link_compile.py
index 3e230d51..f882544a 100644
--- a/keopscore/keopscore/binders/nvrtc/Gpu_link_compile.py
+++ b/keopscore/keopscore/binders/nvrtc/Gpu_link_compile.py
@@ -1,4 +1,5 @@
 import os
+from os.path import join
 from ctypes import create_string_buffer, CDLL, c_int
 from os import RTLD_LAZY
 import sysconfig
@@ -30,10 +31,11 @@
     compile_options
     + f" -fpermissive -L{cuda_config.get_libcuda_folder()} -L{cuda_config.get_libnvrtc_folder()} -lcuda -lnvrtc"
 )
-# add nvrtc_include
-# add jit_source_file,
-# add cuda_available,
-# add get_build_folder,
+nvrtc_include = " -I" + base_config.get_bindings_source_dir() + " -I" + cuda_config.get_cuda_include_path()
+jit_source_file = join(base_config.get_base_dir_path(), "binders", "nvrtc", "keops_nvrtc.cpp")
+jit_source_header = join(base_config.get_base_dir_path(), "binders", "nvrtc", "keops_nvrtc.h")
+cuda_available = cuda_config.get_use_cuda()
+get_build_folder = base_config.get_build_folder()
 
 from keopscore.utils.misc_utils import KeOps_Error, KeOps_Message, KeOps_OS_Run
 from keopscore.utils.gpu_utils import get_gpu_props, custom_cuda_include_fp16_path
@@ -45,7 +47,7 @@
 
 def jit_compile_dll():
     return os.path.join(
-        get_build_folder(),
+        get_build_folder,
         "nvrtc_jit" + sysconfig.get_config_var("SHLIB_SUFFIX"),
     )
 
diff --git a/keopscore/keopscore/config/cuda.py b/keopscore/keopscore/config/cuda.py
index 209c03f7..62fed421 100644
--- a/keopscore/keopscore/config/cuda.py
+++ b/keopscore/keopscore/config/cuda.py
@@ -75,7 +75,7 @@ def print_specific_gpus(self):
             print("Specific GPUs (CUDA_VISIBLE_DEVICES): Not Set")
 
     def set_libcuda_folder(self):
-        """Check if CUDA libraries are available, and then set libcuda_folder"""
+        """Check if CUDA libraries are available, and then set libcuda_folder"""
         cuda_lib = find_library("cuda")
         nvrtc_lib = find_library("nvrtc")
         if cuda_lib and nvrtc_lib:
@@ -85,7 +85,7 @@ def get_libcuda_folder(self):
         return self.libcuda_folder
 
     def set_libnvrtc_folder(self):
-        """Check if CUDA libraries are available, and then set libnvrtc_folder"""
+        """Check if CUDA libraries are available, and then set libnvrtc_folder"""
         cuda_lib = find_library("cuda")
         nvrtc_lib = find_library("nvrtc")
         if cuda_lib and nvrtc_lib:
@@ -148,14 +148,14 @@ def get_cuda_include_path(self):
                 include_path = Path(path) / "include"
                 if (include_path / "cuda.h").is_file() and (include_path / "nvrtc.h").is_file():
                     self.cuda_include_path = str(include_path)
-                    return
+                    return self.cuda_include_path
         # Check if CUDA is installed via conda
         conda_prefix = os.getenv("CONDA_PREFIX")
         if conda_prefix:
             include_path = Path(conda_prefix) / "include"
             if (include_path / "cuda.h").is_file() and (include_path / "nvrtc.h").is_file():
                 self.cuda_include_path = str(include_path)
-                return
+                return self.cuda_include_path
         # Check standard locations
         cuda_version_str = self.get_cuda_version(out_type="string")
         possible_paths = [
@@ -167,14 +167,14 @@ def get_cuda_include_path(self):
             include_path = base_path / "include"
             if (include_path / "cuda.h").is_file() and (include_path / "nvrtc.h").is_file():
                 self.cuda_include_path = str(include_path)
-                return
+                return self.cuda_include_path
         # Use get_include_file_abspath to locate headers
         cuda_h_path = self.get_include_file_abspath("cuda.h")
         nvrtc_h_path = self.get_include_file_abspath("nvrtc.h")
         if cuda_h_path and nvrtc_h_path:
             if os.path.dirname(cuda_h_path) == os.path.dirname(nvrtc_h_path):
                 self.cuda_include_path = os.path.dirname(cuda_h_path)
-                return
+                return self.cuda_include_path
         # If not found, issue a warning
         KeOps_Warning(
             "CUDA include path not found. Please set the CUDA_PATH or CUDA_HOME environment variable."
diff --git a/keopscore/keopscore/config/openmp.py b/keopscore/keopscore/config/openmp.py
index 83dff76f..61af9b8e 100644
--- a/keopscore/keopscore/config/openmp.py
+++ b/keopscore/keopscore/config/openmp.py
@@ -4,7 +4,8 @@
 import subprocess
 import platform
 from ctypes.util import find_library
-from base_config import ConfigNew
+from base_config import ConfigNew
+
 from keopscore.utils.misc_utils import KeOps_Warning
 
 class OpenMPConfig(ConfigNew):