diff --git a/setup.py b/setup.py index 7f4bbd668d..8232caa254 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ import subprocess from datetime import datetime -from setuptools import find_packages, setup +from setuptools import Extension, find_packages, setup current_date = datetime.now().strftime("%Y%m%d") @@ -41,6 +41,14 @@ def read_version(file_path="version.txt"): use_cpp = os.getenv("USE_CPP") +import platform + +build_torchao_experimental = ( + use_cpp == "1" + and platform.machine().startswith("arm64") + and platform.system() == "Darwin" +) + version_prefix = read_version() # Version is version.dev year month date if using nightlies and version if not version = ( @@ -49,6 +57,11 @@ def read_version(file_path="version.txt"): else version_prefix ) + +def use_debug_mode(): + return os.getenv("DEBUG", "0") == "1" + + import torch from torch.utils.cpp_extension import ( CUDA_HOME, @@ -59,8 +72,61 @@ def read_version(file_path="version.txt"): ) +# BuildExtension is a subclass of from setuptools.command.build_ext.build_ext +class TorchAOBuildExt(BuildExtension): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def build_extensions(self): + cmake_extensions = [ + ext for ext in self.extensions if isinstance(ext, CMakeExtension) + ] + other_extensions = [ + ext for ext in self.extensions if not isinstance(ext, CMakeExtension) + ] + for ext in cmake_extensions: + self.build_cmake(ext) + + # Use BuildExtension to build other extensions + self.extensions = other_extensions + super().build_extensions() + + self.extensions = other_extensions + cmake_extensions + + def build_cmake(self, ext): + extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + + build_type = "Debug" if use_debug_mode() else "Release" + + from distutils.sysconfig import get_python_lib + + torch_dir = get_python_lib() + "/torch/share/cmake/Torch" + + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + + subprocess.check_call( + [ + "cmake", + ext.sourcedir, + "-DCMAKE_BUILD_TYPE=" + build_type, + "-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF", + "-DTorch_DIR=" + torch_dir, + "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, + ], + cwd=self.build_temp, + ) + subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp) + + +class CMakeExtension(Extension): + def __init__(self, name, sourcedir=""): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + + def get_extensions(): - debug_mode = os.getenv("DEBUG", "0") == "1" + debug_mode = use_debug_mode() if debug_mode: print("Compiling in debug mode") @@ -129,18 +195,25 @@ def get_extensions(): if use_cuda: sources += cuda_sources - if len(sources) == 0: - return None + ext_modules = [] + if len(sources) > 0: + ext_modules.append( + extension( + "torchao._C", + sources, + py_limited_api=True, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ) + ) - ext_modules = [ - extension( - "torchao._C", - sources, - py_limited_api=True, - extra_compile_args=extra_compile_args, - extra_link_args=extra_link_args, + if build_torchao_experimental: + ext_modules.append( + CMakeExtension( + "torchao.experimental", + sourcedir="torchao/experimental", + ) ) - ] return ext_modules @@ -159,6 +232,6 @@ def get_extensions(): long_description=open("README.md").read(), long_description_content_type="text/markdown", url="https://github.com/pytorch/ao", - cmdclass={"build_ext": BuildExtension}, + cmdclass={"build_ext": TorchAOBuildExt}, options={"bdist_wheel": {"py_limited_api": "cp39"}}, ) diff --git a/torchao/__init__.py b/torchao/__init__.py index 3e00bf6c58..c6048d4328 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -32,6 +32,18 @@ assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}" torch.ops.load_library(so_files[0]) from . import ops + + # The following library contains CPU kernels from torchao/experimental + # They are built automatically by ao/setup.py if on an ARM machine. + # They can also be built outside of the torchao install process by + # running the script `torchao/experimental/build_torchao_ops.sh ` + # For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md + experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*")) + if len(experimental_lib) > 0: + assert ( + len(experimental_lib) == 1 + ), f"Expected at most one libtorchao_ops_aten.* file, found {len(experimental_lib)}" + torch.ops.load_library(experimental_lib[0]) except: logging.debug("Skipping import of cpp extensions")