diff --git a/setup.py b/setup.py index 229e18eec6..9a926ced7a 100644 --- a/setup.py +++ b/setup.py @@ -114,6 +114,7 @@ def get_extensions(): extension( "torchao._C", sources, + py_limited_api=True, extra_compile_args=extra_compile_args, extra_link_args=extra_link_args, ) @@ -136,4 +137,7 @@ def get_extensions(): long_description_content_type="text/markdown", url="https://github.com/pytorch-labs/ao", cmdclass={"build_ext": BuildExtension}, + options={"bdist_wheel": { + "py_limited_api": "cp38" + }}, ) diff --git a/torchao/__init__.py b/torchao/__init__.py index 0d1230ba93..2fe25b8ea8 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -22,10 +22,14 @@ ) if not _IS_FBCODE: try: - from . import _C + from importlib.util import find_spec + from pathlib import Path + spec = find_spec("torchao") + assert spec is not None, "torchao python module spec is unexpectedly None" + SO_PATH = Path(spec.origin).parent / "_C.abi3.so" + torch.ops.load_library(SO_PATH) from . import ops except: - _C = None logging.info("Skipping import of cpp extensions") from torchao.quantization import ( diff --git a/torchao/csrc/init.cpp b/torchao/csrc/init.cpp deleted file mode 100644 index cb2ec42a45..0000000000 --- a/torchao/csrc/init.cpp +++ /dev/null @@ -1,3 +0,0 @@ -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}