[1/2] Wean off of PYBIND in favor of torch.ops.load_library #1276

Merged 1 commit on Nov 13, 2024.
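The core mechanism of this PR: the compiled library no longer needs to be a pybind11 Python extension module (torchao._C) that Python imports; it only needs TORCH_LIBRARY operator registrations, and torch.ops.load_library opens it as an opaque shared object. A minimal sketch of the new loading path, assuming an illustrative path and a placeholder op name (neither is taken from the PR):

import torch
from pathlib import Path

# Before: the .so had to expose a pybind11 module so "from torchao import _C"
# worked. After: any .so that registers ops via TORCH_LIBRARY can be loaded.
so_files = list(Path("torchao").glob("_C*.so"))  # illustrative location
if so_files:
    torch.ops.load_library(so_files[0])
    # Registered ops then appear under torch.ops.<namespace>, e.g.
    # torch.ops.torchao.some_op(...)  # "some_op" is a placeholder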
24 changes: 12 additions & 12 deletions packaging/post_build_script.sh
@@ -7,15 +7,14 @@

 set -eux
 
-WHEEL_NAME=$(ls dist/)
+# Prepare manywheel, only for CUDA.
+# The wheel is a pure python wheel for other platforms.
+if [[ "$CU_VERSION" == cu* ]]; then
[Member] is this also true for rocm?

[Contributor, PR author] No, but it looks like the current .cu files are not built for rocm, based on the use_cuda here: https://github.com/pytorch/ao/blob/main/setup.py#L64

[Member] Ok! Heads up to @petrex, who has been working on adding custom hip kernel support.

[Collaborator] @janeyx99 This #1201 would build hip kernels.
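For context on the use_cuda check referenced above, a hedged paraphrase of the gating pattern in the linked setup.py (the exact code there may differ): .cu sources are compiled only when a CUDA toolchain is present, and there is no equivalent ROCm/HIP branch, which is what this thread flags.

import torch
from torch.utils.cpp_extension import CUDA_HOME

# Sketch of the gate at setup.py#L64: build CUDA sources only when both a
# CUDA device and the CUDA toolkit are available at build time.
use_cuda = torch.cuda.is_available() and CUDA_HOME is not None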

+    WHEEL_NAME=$(ls dist/)
 
-pushd dist
-# Prepare manywheel
-manylinux_plat=manylinux2014_x86_64
-if [[ "$CU_VERSION" == "xpu" ]]; then
-    manylinux_plat=manylinux_2_28_x86_64
-fi
-auditwheel repair --plat "$manylinux_plat" -w . \
+    pushd dist
+    manylinux_plat=manylinux2014_x86_64
+    auditwheel repair --plat "$manylinux_plat" -w . \
     --exclude libtorch.so \
     --exclude libtorch_python.so \
     --exclude libtorch_cuda.so \
@@ -26,10 +25,11 @@ auditwheel repair --plat "$manylinux_plat" -w . \
     --exclude libcudart.so.11.0 \
     "${WHEEL_NAME}"

-ls -lah .
-# Clean up the linux_x86_64 wheel
-rm "${WHEEL_NAME}"
-popd
+    ls -lah .
+    # Clean up the linux_x86_64 wheel
+    rm "${WHEEL_NAME}"
+    popd
+fi

 MANYWHEEL_NAME=$(ls dist/)
 # Try to install the new wheel
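Net effect of the script change: only CUDA builds go through auditwheel repair (which rewrites the platform tag to manylinux and vendors any non-excluded shared libraries); all other builds ship the unmodified pure-Python wheel. A small, hypothetical check that tells the two apart by wheel filename tags:

from pathlib import Path

# A pure-Python wheel ends in "-none-any.whl"; a repaired CUDA wheel carries
# a platform tag such as "manylinux2014_x86_64" instead.
for wheel in Path("dist").glob("*.whl"):
    kind = "pure python" if wheel.name.endswith("none-any.whl") else "platform specific"
    print(f"{wheel.name}: {kind}")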
3 changes: 3 additions & 0 deletions setup.py
@@ -110,6 +110,9 @@ def get_extensions():
     if use_cuda:
         sources += cuda_sources
 
+    if len(sources) == 0:
+        return None
+
     ext_modules = [
         extension(
             "torchao._C",
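Returning None when there is nothing to compile is what makes the pure-Python wheel possible: setuptools treats a falsy ext_modules as "no extensions to build". A minimal sketch of the pattern; only get_extensions, sources, and "torchao._C" come from the diff, the rest is illustrative:

from setuptools import setup

def get_extensions():
    sources = []  # the real file gathers C++ (and, when use_cuda, .cu) sources
    if len(sources) == 0:
        return None  # nothing to compile -> setuptools builds a pure-Python wheel
    ext_modules = []  # the real file appends a "torchao._C" extension here
    return ext_modules

setup(
    name="torchao",  # illustrative; the real arguments live in the repo's setup.py
    ext_modules=get_extensions(),
)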
6 changes: 4 additions & 2 deletions torchao/__init__.py
@@ -22,10 +22,12 @@
 )
 if not _IS_FBCODE:
     try:
-        from . import _C
+        from pathlib import Path
+        so_files = list(Path(__file__).parent.glob("_C*.so"))
+        assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
+        torch.ops.load_library(so_files[0])
         from . import ops
     except:
-        _C = None
         logging.info("Skipping import of cpp extensions")

 from torchao.quantization import (
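Once load_library succeeds, the C++ operators are reachable through the torch.ops namespace with no pybind module in between, and the following "from . import ops" attaches the Python-side wrappers. A hedged sketch of what a caller sees (the op name is a placeholder, not a real torchao op):

import torch
import torchao  # __init__.py globs _C*.so and calls torch.ops.load_library

# Ops registered in C++ under TORCH_LIBRARY(torchao, ...) surface here:
torchao_ops = torch.ops.torchao
# e.g. torchao_ops.some_op(tensor)  # placeholder; real op names vary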
3 changes: 0 additions & 3 deletions torchao/csrc/init.cpp

This file was deleted. (Its pybind-era module initialization is superseded by loading the shared library with torch.ops.load_library, so no Python binding stub remains.)
