From acd19e00635885cdff2c1d8e6ac576f6bb1838dd Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Tue, 12 Nov 2024 14:53:11 -0800 Subject: [PATCH] Wean ao off of PYBIND [part 1] --- packaging/post_build_script.sh | 24 ++++++++++++------------ setup.py | 3 +++ torchao/__init__.py | 6 ++++-- torchao/csrc/init.cpp | 3 --- 4 files changed, 19 insertions(+), 17 deletions(-) delete mode 100644 torchao/csrc/init.cpp diff --git a/packaging/post_build_script.sh b/packaging/post_build_script.sh index 611931f7d3..70e8d83392 100644 --- a/packaging/post_build_script.sh +++ b/packaging/post_build_script.sh @@ -7,15 +7,14 @@ set -eux -WHEEL_NAME=$(ls dist/) +# Prepare manywheel, only for CUDA. +# The wheel is a pure python wheel for other platforms. +if [[ "$CU_VERSION" == cu* ]]; then + WHEEL_NAME=$(ls dist/) -pushd dist -# Prepare manywheel -manylinux_plat=manylinux2014_x86_64 -if [[ "$CU_VERSION" == "xpu" ]]; then - manylinux_plat=manylinux_2_28_x86_64 -fi -auditwheel repair --plat "$manylinux_plat" -w . \ + pushd dist + manylinux_plat=manylinux2014_x86_64 + auditwheel repair --plat "$manylinux_plat" -w . \ --exclude libtorch.so \ --exclude libtorch_python.so \ --exclude libtorch_cuda.so \ @@ -26,10 +25,11 @@ auditwheel repair --plat "$manylinux_plat" -w . \ --exclude libcudart.so.11.0 \ "${WHEEL_NAME}" -ls -lah . -# Clean up the linux_x86_64 wheel -rm "${WHEEL_NAME}" -popd + ls -lah . + # Clean up the linux_x86_64 wheel + rm "${WHEEL_NAME}" + popd +fi MANYWHEEL_NAME=$(ls dist/) # Try to install the new wheel diff --git a/setup.py b/setup.py index 229e18eec6..b7334631a8 100644 --- a/setup.py +++ b/setup.py @@ -110,6 +110,9 @@ def get_extensions(): if use_cuda: sources += cuda_sources + if len(sources) == 0: + return None + ext_modules = [ extension( "torchao._C", diff --git a/torchao/__init__.py b/torchao/__init__.py index 0d1230ba93..b910af3d7e 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -22,10 +22,12 @@ ) if not _IS_FBCODE: try: - from . import _C + from pathlib import Path + so_files = list(Path(__file__).parent.glob("_C*.so")) + assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}" + torch.ops.load_library(so_files[0]) from . import ops except: - _C = None logging.info("Skipping import of cpp extensions") from torchao.quantization import ( diff --git a/torchao/csrc/init.cpp b/torchao/csrc/init.cpp deleted file mode 100644 index cb2ec42a45..0000000000 --- a/torchao/csrc/init.cpp +++ /dev/null @@ -1,3 +0,0 @@ -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}