louder warning + docs for custom cuda extensions (pytorch#186)
* louder warning for missing cudatoolkit

* docs for custom ops
msaroufim authored Apr 29, 2024
1 parent 8f67d29 commit bc567dd
Showing 3 changed files with 39 additions and 4 deletions.
README.md — 7 changes: 4 additions & 3 deletions
@@ -21,7 +21,7 @@ pip install torchao
```Shell
git clone https://github.com/pytorch-labs/ao
cd ao
-pip install -e .
+python setup.py develop
```

## Key Features
@@ -35,17 +35,18 @@ The library provides
* High level `autoquant` API and kernel auto tuner targeting SOTA performance across varying model shapes on consumer/enterprise GPUs.
3. [Sparsity algorithms](./torchao/sparsity) such as Wanda that help improve accuracy of sparse networks
4. Integration with other PyTorch native libraries like [torchtune](https://github.com/pytorch/torchtune) and [ExecuTorch](https://github.com/pytorch/executorch)
+5. [Custom C++/CUDA Extension support](./torchao/csrc/)


## Our Goals
torchao embodies PyTorch’s design philosophy ([details](https://pytorch.org/docs/stable/community/design.html)), especially "usability over everything else". Our vision for this repository is the following:

* Composability: Native solutions for optimization techniques that compose with both `torch.compile` and `FSDP`
  * For example, new dtype support for QLoRA
* Interoperability: Work with the rest of the PyTorch ecosystem such as torchtune, gpt-fast and ExecuTorch
* Transparent Benchmarks: Regularly run performance benchmarks of our APIs across a suite of Torchbench models and hardware backends
* Heterogeneous Hardware: Efficient kernels that can run on CPU/GPU-based servers (w/ `torch.compile`) and mobile backends (w/ ExecuTorch).
* Infrastructure Support: Release packaging solution for kernels and a CI/CD setup that runs these kernels on different backends.

## Interoperability with PyTorch Libraries

setup.py — 7 changes: 6 additions & 1 deletion
@@ -37,7 +37,12 @@ def get_extensions():
    if debug_mode:
        print("Compiling in debug mode")

-    # TODO: And cudatoolkit is available
+    if not torch.cuda.is_available():
+        print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions")
+    if CUDA_HOME is None and torch.cuda.is_available():
+        print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
+        print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit")

    use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
    extension = CUDAExtension if use_cuda else CppExtension
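
To see which path the build will take on your machine, you can mirror the same conditions from a Python prompt. A minimal sketch, assuming only that `CUDA_HOME` is re-exported by `torch.utils.cpp_extension` as used above:

```python
# Sanity check mirroring setup.py's decision: CUDA extensions are
# compiled only when a GPU and the CUDA toolkit are both present.
import torch
from torch.utils.cpp_extension import CUDA_HOME

print("torch.cuda.is_available():", torch.cuda.is_available())
print("CUDA_HOME:", CUDA_HOME)
print("Will build CUDA extensions:", torch.cuda.is_available() and CUDA_HOME is not None)
```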

torchao/csrc/README.md — 29 changes: 29 additions & 0 deletions (new file)
@@ -0,0 +1,29 @@
# Custom C++/CUDA Extensions

This folder is an example of how to integrate your own custom kernels into ao such that
1. They work on as many devices and operating systems as possible
2. They compose with `torch.compile()` without graph breaks

The goal is that you can focus on just writing your custom CUDA or C++ kernel and we can package it up so it's available via `torchao.ops.your_custom_kernel`.


## How to add your own kernel in ao

We've integrated a test kernel which implements a non-maximum suppression (NMS) op that you can use as a template for your own kernels.

1. Install the cudatoolkit: https://anaconda.org/conda-forge/cudatoolkit
2. In `csrc/cuda`, author your custom kernel and register it with a `TORCH_LIBRARY_IMPL` so that it is exposed as `torchao::your_custom_kernel`
3. In `csrc/`, author a `cpp` stub that includes a `TORCH_LIBRARY_FRAGMENT`, which places your custom kernel in the `torchao.ops` namespace and exposes a public function with the right arguments
4. In `torchao/ops.py`, expose the Python API which your end users will leverage
5. Write a new test in `test/test_ops.py` which, most importantly, needs to pass `opcheck()`; this ensures that your custom kernel composes out of the box with `torch.compile()`. A sketch of steps 4 and 5 follows this list.
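
The Python side of this pattern looks roughly like the following. This is a minimal sketch, not the repo's actual code: `your_custom_kernel` and its signature are hypothetical placeholders, and `opcheck` is assumed to be importable from PyTorch's internal optests module.

```python
# Sketch of steps 4 and 5; the op name and signature are placeholders.
import torch
from torch.testing._internal.optests import opcheck

# torchao/ops.py (step 4): thin wrapper over the C++/CUDA op that was
# registered in csrc/ via TORCH_LIBRARY_FRAGMENT / TORCH_LIBRARY_IMPL.
def your_custom_kernel(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.torchao.your_custom_kernel(x)

# test/test_ops.py (step 5): opcheck() verifies the registration is
# well-formed so the op composes with torch.compile() without graph breaks.
def test_your_custom_kernel():
    x = torch.randn(32, device="cuda")
    opcheck(torch.ops.torchao.your_custom_kernel, (x,))
```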

And that's it! Once CI passes and your code is merged, you'll be able to point people to `torchao.ops.your_custom_kernel`. If you're working on an interesting kernel and would like someone else to handle the release and package management, please feel free to open an issue.

If you'd like to learn more, please check out [torch.library](https://pytorch.org/docs/main/library.html).

## Required dependencies

The important dependencies are already taken care of in our CI, so feel free to test in CI directly.

1. cudatoolkit, so you can build your own custom extensions locally. We highly recommend installing it from https://anaconda.org/conda-forge/cudatoolkit
2. manylinux with CUDA support. In your own GitHub Actions workflows you can integrate this support using `uses: pytorch/test-infra/.github/workflows/linux_job.yml@main`
