Skip to content

Commit

Permalink
torchao setup.py with cmake
Browse files Browse the repository at this point in the history
Differential Revision: D67777662

Pull Request resolved: #1490
  • Loading branch information
metascroy authored Jan 10, 2025
1 parent cedadc7 commit 9c2635b
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 13 deletions.
99 changes: 86 additions & 13 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import subprocess
from datetime import datetime

from setuptools import find_packages, setup
from setuptools import Extension, find_packages, setup

current_date = datetime.now().strftime("%Y%m%d")

Expand Down Expand Up @@ -41,6 +41,14 @@ def read_version(file_path="version.txt"):

use_cpp = os.getenv("USE_CPP")

import platform

build_torchao_experimental = (
use_cpp == "1"
and platform.machine().startswith("arm64")
and platform.system() == "Darwin"
)

version_prefix = read_version()
# Version is version.dev year month date if using nightlies and version if not
version = (
Expand All @@ -49,6 +57,11 @@ def read_version(file_path="version.txt"):
else version_prefix
)


def use_debug_mode():
return os.getenv("DEBUG", "0") == "1"


import torch
from torch.utils.cpp_extension import (
CUDA_HOME,
Expand All @@ -59,8 +72,61 @@ def read_version(file_path="version.txt"):
)


# BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
class TorchAOBuildExt(BuildExtension):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def build_extensions(self):
cmake_extensions = [
ext for ext in self.extensions if isinstance(ext, CMakeExtension)
]
other_extensions = [
ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
]
for ext in cmake_extensions:
self.build_cmake(ext)

# Use BuildExtension to build other extensions
self.extensions = other_extensions
super().build_extensions()

self.extensions = other_extensions + cmake_extensions

def build_cmake(self, ext):
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))

build_type = "Debug" if use_debug_mode() else "Release"

from distutils.sysconfig import get_python_lib

torch_dir = get_python_lib() + "/torch/share/cmake/Torch"

if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)

subprocess.check_call(
[
"cmake",
ext.sourcedir,
"-DCMAKE_BUILD_TYPE=" + build_type,
"-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF",
"-DTorch_DIR=" + torch_dir,
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
],
cwd=self.build_temp,
)
subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp)


class CMakeExtension(Extension):
def __init__(self, name, sourcedir=""):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)


def get_extensions():
debug_mode = os.getenv("DEBUG", "0") == "1"
debug_mode = use_debug_mode()
if debug_mode:
print("Compiling in debug mode")

Expand Down Expand Up @@ -129,18 +195,25 @@ def get_extensions():
if use_cuda:
sources += cuda_sources

if len(sources) == 0:
return None
ext_modules = []
if len(sources) > 0:
ext_modules.append(
extension(
"torchao._C",
sources,
py_limited_api=True,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
)

ext_modules = [
extension(
"torchao._C",
sources,
py_limited_api=True,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
if build_torchao_experimental:
ext_modules.append(
CMakeExtension(
"torchao.experimental",
sourcedir="torchao/experimental",
)
)
]

return ext_modules

Expand All @@ -159,6 +232,6 @@ def get_extensions():
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
url="https://github.com/pytorch/ao",
cmdclass={"build_ext": BuildExtension},
cmdclass={"build_ext": TorchAOBuildExt},
options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
12 changes: 12 additions & 0 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
torch.ops.load_library(so_files[0])
from . import ops

# The following library contains CPU kernels from torchao/experimental
# They are built automatically by ao/setup.py if on an ARM machine.
# They can also be built outside of the torchao install process by
# running the script `torchao/experimental/build_torchao_ops.sh <aten|executorch>`
# For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md
experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*"))
if len(experimental_lib) > 0:
assert (
len(experimental_lib) == 1
), f"Expected at most one libtorchao_ops_aten.* file, found {len(experimental_lib)}"
torch.ops.load_library(experimental_lib[0])
except:
logging.debug("Skipping import of cpp extensions")

Expand Down

0 comments on commit 9c2635b

Please sign in to comment.