Skip to content

Commit

Permalink
torchao setup.py with cmake
Browse files Browse the repository at this point in the history
Summary:
Initial draft of using cmake in torchao's build process.

Install torchao with:

```
USE_CPP=1 pip install .
```

If on an arm64 machine, it builds the dynamic library for torchao at site-packages/torchao/libtorchao_ops_aten.dylib.  On import of torchao, if this library is found it is loaded.

Differential Revision: D67777662
  • Loading branch information
metascroy authored and facebook-github-bot committed Jan 3, 2025
1 parent c59bce5 commit e4f77da
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 7 deletions.
65 changes: 58 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import datetime
import subprocess

from setuptools import find_packages, setup
from setuptools import find_packages, setup, Extension

current_date = datetime.now().strftime("%Y%m%d")

Expand All @@ -35,6 +35,12 @@ def read_version(file_path="version.txt"):

use_cpp = os.getenv('USE_CPP')

import platform
build_torchao_experimental = (
use_cpp == "1" and
platform.machine().startswith("arm64")
)

version_prefix = read_version()
# Version is version.dev year month date if using nightlies and version if not
version = f"{version_prefix}.dev{current_date}" if os.environ.get("TORCHAO_NIGHTLY") else version_prefix
Expand All @@ -50,6 +56,44 @@ def read_version(file_path="version.txt"):
)


# BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
class TorchAOBuildExt(BuildExtension):
def build_extensions(self):
cmake_extensions = [ext for ext in self.extensions if isinstance(ext, CMakeExtension)]
other_extensions = [ext for ext in self.extensions if not isinstance(ext, CMakeExtension)]

for ext in cmake_extensions:
self.build_cmake(ext)
for ext in other_extensions:
self.build_other(ext)

def build_cmake(self, ext):
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))

from distutils.sysconfig import get_python_lib
torch_dir = get_python_lib() + "/torch/share/cmake/Torch"

if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)

subprocess.check_call(
['cmake', ext.sourcedir, '-DCMAKE_BUILD_TYPE=Release', '-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF', '-DTorch_DIR=' + torch_dir, '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir],
cwd=self.build_temp
)
subprocess.check_call(
['cmake', '--build', '.'],
cwd=self.build_temp
)

def build_other(self, ext):
super().build_extension(ext)

class CMakeExtension(Extension):
def __init__(self, name, sourcedir=''):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)


def get_extensions():
debug_mode = os.getenv('DEBUG', '0') == '1'
if debug_mode:
Expand Down Expand Up @@ -103,18 +147,25 @@ def get_extensions():
if use_cuda:
sources += cuda_sources

if len(sources) == 0:
return None

ext_modules = [
ext_modules = []
if len(sources) > 0:
ext_modules.append(
extension(
"torchao._C",
sources,
py_limited_api=True,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
]
)

if build_torchao_experimental:
ext_modules.append(
CMakeExtension(
"torchao.experimental",
sourcedir="torchao/experimental",
)
)

return ext_modules

Expand All @@ -133,7 +184,7 @@ def get_extensions():
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
url="https://github.com/pytorch/ao",
cmdclass={"build_ext": BuildExtension},
cmdclass={"build_ext": TorchAOBuildExt},
options={"bdist_wheel": {
"py_limited_api": "cp39"
}},
Expand Down
5 changes: 5 additions & 0 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
torch.ops.load_library(so_files[0])
from . import ops

experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*"))
if len(experimental_lib) > 0:
assert len(so_files) == 1, f"Expected at most one libtorchao_ops_aten.* file, found {len(experimental_lib)}"
torch.ops.load_library(experimental_lib[0])
except:
logging.debug("Skipping import of cpp extensions")

Expand Down

0 comments on commit e4f77da

Please sign in to comment.