Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torchao setup.py with cmake #1490

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 86 additions & 13 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import subprocess
from datetime import datetime

from setuptools import find_packages, setup
from setuptools import Extension, find_packages, setup

current_date = datetime.now().strftime("%Y%m%d")

Expand Down Expand Up @@ -41,6 +41,14 @@ def read_version(file_path="version.txt"):

use_cpp = os.getenv("USE_CPP")

import platform

build_torchao_experimental = (
use_cpp == "1"
and platform.machine().startswith("arm64")
and platform.system() == "Darwin"
)

version_prefix = read_version()
# Version is version.dev year month date if using nightlies and version if not
version = (
Expand All @@ -49,6 +57,11 @@ def read_version(file_path="version.txt"):
else version_prefix
)


def use_debug_mode():
return os.getenv("DEBUG", "0") == "1"


import torch
from torch.utils.cpp_extension import (
CUDA_HOME,
Expand All @@ -59,8 +72,61 @@ def read_version(file_path="version.txt"):
)


# BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
class TorchAOBuildExt(BuildExtension):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def build_extensions(self):
cmake_extensions = [
ext for ext in self.extensions if isinstance(ext, CMakeExtension)
]
other_extensions = [
ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
]
for ext in cmake_extensions:
self.build_cmake(ext)

# Use BuildExtension to build other extensions
self.extensions = other_extensions
super().build_extensions()

self.extensions = other_extensions + cmake_extensions

def build_cmake(self, ext):
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))

build_type = "Debug" if use_debug_mode() else "Release"

from distutils.sysconfig import get_python_lib

torch_dir = get_python_lib() + "/torch/share/cmake/Torch"

if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)

subprocess.check_call(
[
"cmake",
ext.sourcedir,
"-DCMAKE_BUILD_TYPE=" + build_type,
"-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF",
"-DTorch_DIR=" + torch_dir,
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
],
cwd=self.build_temp,
)
subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp)


class CMakeExtension(Extension):
def __init__(self, name, sourcedir=""):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)


def get_extensions():
debug_mode = os.getenv("DEBUG", "0") == "1"
debug_mode = use_debug_mode()
if debug_mode:
print("Compiling in debug mode")

Expand Down Expand Up @@ -129,18 +195,25 @@ def get_extensions():
if use_cuda:
sources += cuda_sources

if len(sources) == 0:
return None
ext_modules = []
if len(sources) > 0:
ext_modules.append(
extension(
"torchao._C",
sources,
py_limited_api=True,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
)

ext_modules = [
extension(
"torchao._C",
sources,
py_limited_api=True,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
if build_torchao_experimental:
ext_modules.append(
CMakeExtension(
"torchao.experimental",
sourcedir="torchao/experimental",
)
)
]

return ext_modules

Expand All @@ -159,6 +232,6 @@ def get_extensions():
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
url="https://github.com/pytorch/ao",
cmdclass={"build_ext": BuildExtension},
cmdclass={"build_ext": TorchAOBuildExt},
options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
12 changes: 12 additions & 0 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
torch.ops.load_library(so_files[0])
from . import ops

# The following library contains CPU kernels from torchao/experimental
# They are built automatically by ao/setup.py if on an ARM machine.
# They can also be built outside of the torchao install process by
# running the script `torchao/experimental/build_torchao_ops.sh <aten|executorch>`
# For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md
experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*"))
if len(experimental_lib) > 0:
assert (
len(experimental_lib) == 1
), f"Expected at most one libtorchao_ops_aten.* file, found {len(experimental_lib)}"
torch.ops.load_library(experimental_lib[0])
except:
logging.debug("Skipping import of cpp extensions")

Expand Down
Loading