Skip to content

Commit

Permalink
torchao setup.py with cmake (#1490)
Browse files Browse the repository at this point in the history
Summary:

Initial draft of using cmake in torchao's build process.

Install torchao with:

```
USE_CPP=1 pip install .
```

If on an arm64 machine, it builds the dynamic library for torchao at site-packages/torchao/libtorchao_ops_aten.dylib.  On import of torchao, if this library is found it is loaded.

Reviewed By: kimishpatel, drisspg

Differential Revision: D67777662
  • Loading branch information
metascroy authored and facebook-github-bot committed Jan 10, 2025
1 parent 982141b commit 16c85e8
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 8 deletions.
76 changes: 68 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import subprocess
from datetime import datetime

from setuptools import find_packages, setup
from setuptools import find_packages, setup, Extension

current_date = datetime.now().strftime("%Y%m%d")

Expand Down Expand Up @@ -41,6 +41,13 @@ def read_version(file_path="version.txt"):

use_cpp = os.getenv("USE_CPP")

import platform
build_torchao_experimental = (
use_cpp == "1" and
platform.machine().startswith("arm64") and
platform.system() == "Darwin"
)

version_prefix = read_version()
# Version is version.dev year month date if using nightlies and version if not
version = (
Expand All @@ -49,6 +56,9 @@ def read_version(file_path="version.txt"):
else version_prefix
)

def use_debug_mode():
return os.getenv('DEBUG', '0') == '1'

import torch
from torch.utils.cpp_extension import (
CUDA_HOME,
Expand All @@ -59,8 +69,51 @@ def read_version(file_path="version.txt"):
)


# BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
class TorchAOBuildExt(BuildExtension):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def build_extensions(self):
cmake_extensions = [ext for ext in self.extensions if isinstance(ext, CMakeExtension)]
other_extensions = [ext for ext in self.extensions if not isinstance(ext, CMakeExtension)]
for ext in cmake_extensions:
self.build_cmake(ext)

# Use BuildExtension to build other extensions
self.extensions = other_extensions
super().build_extensions()

self.extensions = other_extensions + cmake_extensions

def build_cmake(self, ext):
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))

build_type = "Debug" if use_debug_mode() else "Release"

from distutils.sysconfig import get_python_lib
torch_dir = get_python_lib() + "/torch/share/cmake/Torch"

if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)

subprocess.check_call(
['cmake', ext.sourcedir, '-DCMAKE_BUILD_TYPE=' + build_type, '-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF', '-DTorch_DIR=' + torch_dir, '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir],
cwd=self.build_temp
)
subprocess.check_call(
['cmake', '--build', '.'],
cwd=self.build_temp
)

class CMakeExtension(Extension):
def __init__(self, name, sourcedir=''):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)


def get_extensions():
debug_mode = os.getenv("DEBUG", "0") == "1"
debug_mode = use_debug_mode()
if debug_mode:
print("Compiling in debug mode")

Expand Down Expand Up @@ -129,18 +182,25 @@ def get_extensions():
if use_cuda:
sources += cuda_sources

if len(sources) == 0:
return None

ext_modules = [
ext_modules = []
if len(sources) > 0:
ext_modules.append(
extension(
"torchao._C",
sources,
py_limited_api=True,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
]
)

if build_torchao_experimental:
ext_modules.append(
CMakeExtension(
"torchao.experimental",
sourcedir="torchao/experimental",
)
)

return ext_modules

Expand All @@ -159,6 +219,6 @@ def get_extensions():
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
url="https://github.com/pytorch/ao",
cmdclass={"build_ext": BuildExtension},
cmdclass={"build_ext": TorchAOBuildExt},
options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
10 changes: 10 additions & 0 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
torch.ops.load_library(so_files[0])
from . import ops

# The following library contains CPU kernels from torchao/experimental
# They are built automatically by ao/setup.py if on an ARM machine.
# They can also be built outside of the torchao install process by
# running the script `torchao/experimental/build_torchao_ops.sh <aten|executorch>`
# For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md
experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*"))
if len(experimental_lib) > 0:
assert len(experimental_lib) == 1, f"Expected at most one libtorchao_ops_aten.* file, found {len(experimental_lib)}"
torch.ops.load_library(experimental_lib[0])
except:
logging.debug("Skipping import of cpp extensions")

Expand Down

0 comments on commit 16c85e8

Please sign in to comment.