Skip to content

Commit

Permalink
fixing pykeops imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Nizben committed Nov 18, 2024
1 parent 4ae543c commit dcd7638
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 27 deletions.
4 changes: 3 additions & 1 deletion keopscore/keopscore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,6 @@

# Retrieve the current build folder
build_folder = config.get_build_folder()
# from keopscore.config.config import show_gpu_config, show_cuda_status
from keopscore.config import cuda_config
show_gpu_config = cuda_config.print_all()
show_cuda_status = cuda_config.get_use_cuda()
4 changes: 0 additions & 4 deletions keopscore/keopscore/config/Platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,3 @@ def print_all(self):
print(f"{var} is not set")


if __name__ == "__main__":
# Create an instance of DetectPlatform and print all platform related information
platform_detector = DetectPlatform()
platform_detector.print_all()
5 changes: 1 addition & 4 deletions keopscore/keopscore/config/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,4 @@ def print_all(self):
print(f"{var} is not set")


if __name__ == "__main__":
# Create an instance of CUDAConfig and print all CUDA-related information
cuda_config = CUDAConfig()
cuda_config.print_all()

4 changes: 0 additions & 4 deletions keopscore/keopscore/config/openmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,3 @@ def print_all(self):
print(f"{var} is not set")


if __name__ == "__main__":
# Create an instance of OpenMPConfig and print all OpenMP-related information
openmp_config = OpenMPConfig()
openmp_config.print_all()
16 changes: 8 additions & 8 deletions pykeops/pykeops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@

import keopscore
import keopscore.config
import keopscore.config.config
from keopscore.config.base_config import ConfigNew
from keopscore.config import config, cuda_config

from . import config as pykeopsconfig

from keopscore import show_cuda_status

keops_get_build_folder = ConfigNew.get_default_build_folder_name()
keops_get_build_folder = pykeopsconfig.get_build_folder
from .config import pykeops_nvrtc_name
from .config import numpy_found, torch_found


def set_verbose(val):
Expand All @@ -39,8 +39,8 @@ def set_verbose(val):

default_device_id = 0 # default Gpu device number

if keopscore.config.config.use_cuda:
if not os.path.exists(pykeopsconfig.pykeops_nvrtc_name(type="target")):
if cuda_config.get_use_cuda():
if not os.path.exists(pykeops_nvrtc_name(type="target")):
from .common.keops_io.LoadKeOps_nvrtc import compile_jit_binary

compile_jit_binary()
Expand Down Expand Up @@ -74,10 +74,10 @@ def get_build_folder():
return keops_get_build_folder()


if pykeopsconfig.numpy_found:
if numpy_found:
from .numpy.test_install import test_numpy_bindings

if pykeopsconfig.torch_found:
if torch_found:
from .torch.test_install import test_torch_bindings

# next line is to ensure that cache file for formulas is loaded at import
Expand Down
4 changes: 2 additions & 2 deletions pykeops/pykeops/common/keops_io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import keopscore.config
from keopscore.config import config, cuda_config

if keopscore.config.config.use_cuda:
if cuda_config._use_cuda:
from . import LoadKeOps_nvrtc, LoadKeOps_cpp

keops_binder = {
Expand Down
7 changes: 3 additions & 4 deletions pykeops/pykeops/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
numpy_found = importlib.util.find_spec("numpy") is not None
torch_found = importlib.util.find_spec("torch") is not None

from keopscore.config.cuda import CUDAconfig
from keopscore.config.base_config import ConfigNew
from keopscore.config import cuda_config, config

gpu_available = CUDAconfig.get_use_cuda()
get_build_folder = ConfigNew.get_build_folder()
gpu_available = cuda_config.get_use_cuda()
get_build_folder = config.get_build_folder


def pykeops_nvrtc_name(type="src"):
Expand Down

0 comments on commit dcd7638

Please sign in to comment.