Skip to content

Commit

Permalink
fixing import
Browse files Browse the repository at this point in the history
  • Loading branch information
Nizben committed Nov 7, 2024
1 parent 854d8e4 commit 45d1191
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
12 changes: 7 additions & 5 deletions keopscore/keopscore/binders/nvrtc/Gpu_link_compile.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from os.path import join
from ctypes import create_string_buffer, CDLL, c_int
from os import RTLD_LAZY
import sysconfig
Expand Down Expand Up @@ -30,10 +31,11 @@
compile_options
+ f" -fpermissive -L{cuda_config.get_libcuda_folder()} -L{cuda_config.get_libnvrtc_folder()} -lcuda -lnvrtc"
)
# add nvrtc_include
# add jit_source_file,
# add cuda_available,
# add get_build_folder,
nvrtc_include = " -I" + base_config.get_bindings_source_dir() + " -I" + cuda_config.get_cuda_include_path()
jit_source_file = join(base_config.get_base_dir_path(), "binders", "nvrtc", "keops_nvrtc.cpp")
jit_source_header = join(base_config.get_base_dir_path(), "binders", "nvrtc", "keops_nvrtc.h")
cuda_available = cuda_config.get_use_cuda()
get_build_folder = base_config.get_build_folder()

from keopscore.utils.misc_utils import KeOps_Error, KeOps_Message, KeOps_OS_Run
from keopscore.utils.gpu_utils import get_gpu_props, custom_cuda_include_fp16_path
Expand All @@ -45,7 +47,7 @@

def jit_compile_dll():
return os.path.join(
get_build_folder(),
get_build_folder,
"nvrtc_jit" + sysconfig.get_config_var("SHLIB_SUFFIX"),
)

Expand Down
12 changes: 6 additions & 6 deletions keopscore/keopscore/config/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def print_specific_gpus(self):
print("Specific GPUs (CUDA_VISIBLE_DEVICES): Not Set")

def set_libcuda_folder(self):
"""Check if CUDA libraries are available, and then set libcuda_folder"""
"""Check if CUDA libraries are available, and then set libcuda_folder"""
cuda_lib = find_library("cuda")
nvrtc_lib = find_library("nvrtc")
if cuda_lib and nvrtc_lib:
Expand All @@ -85,7 +85,7 @@ def get_libcuda_folder(self):
return self.libcuda_folder

def set_libnvrtc_folder(self):
"""Check if CUDA libraries are available, and then set libnvrtc_folder"""
"""Check if CUDA libraries are available, and then set libnvrtc_folder"""
cuda_lib = find_library("cuda")
nvrtc_lib = find_library("nvrtc")
if cuda_lib and nvrtc_lib:
Expand Down Expand Up @@ -148,14 +148,14 @@ def get_cuda_include_path(self):
include_path = Path(path) / "include"
if (include_path / "cuda.h").is_file() and (include_path / "nvrtc.h").is_file():
self.cuda_include_path = str(include_path)
return
return self.cuda_include_path
# Check if CUDA is installed via conda
conda_prefix = os.getenv("CONDA_PREFIX")
if conda_prefix:
include_path = Path(conda_prefix) / "include"
if (include_path / "cuda.h").is_file() and (include_path / "nvrtc.h").is_file():
self.cuda_include_path = str(include_path)
return
return self.cuda_include_path
# Check standard locations
cuda_version_str = self.get_cuda_version(out_type="string")
possible_paths = [
Expand All @@ -167,14 +167,14 @@ def get_cuda_include_path(self):
include_path = base_path / "include"
if (include_path / "cuda.h").is_file() and (include_path / "nvrtc.h").is_file():
self.cuda_include_path = str(include_path)
return
return self.cuda_include_path
# Use get_include_file_abspath to locate headers
cuda_h_path = self.get_include_file_abspath("cuda.h")
nvrtc_h_path = self.get_include_file_abspath("nvrtc.h")
if cuda_h_path and nvrtc_h_path:
if os.path.dirname(cuda_h_path) == os.path.dirname(nvrtc_h_path):
self.cuda_include_path = os.path.dirname(cuda_h_path)
return
return self.cuda_include_path
# If not found, issue a warning
KeOps_Warning(
"CUDA include path not found. Please set the CUDA_PATH or CUDA_HOME environment variable."
Expand Down
3 changes: 2 additions & 1 deletion keopscore/keopscore/config/openmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import subprocess
import platform
from ctypes.util import find_library
from base_config import ConfigNew
from base_config import ConfigNew

from keopscore.utils.misc_utils import KeOps_Warning

class OpenMPConfig(ConfigNew):
Expand Down

0 comments on commit 45d1191

Please sign in to comment.