Skip to content

Commit

Permalink
Update outdated custom ops tutorials to point to the new landing page (
Browse files Browse the repository at this point in the history
…#2953)

* Update outdated custom ops tutorials to point to the new landing page
* Also turns on verification for the python custom ops tutorials.
* Update intermediate_source/torch_export_tutorial.py

---------

Co-authored-by: Svetlana Karslioglu <[email protected]>
  • Loading branch information
zou3519 and svekars authored Jul 22, 2024
1 parent 0a5b58b commit 0dee5c9
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 20 deletions.
1 change: 0 additions & 1 deletion .jenkins/validate_tutorials_built.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
"intermediate_source/fx_conv_bn_fuser",
"intermediate_source/_torch_export_nightly_tutorial", # does not work on release
"advanced_source/super_resolution_with_onnxruntime",
"advanced_source/python_custom_ops", # https://github.com/pytorch/pytorch/issues/127443
"advanced_source/usb_semisup_learn", # fails with CUDA OOM error, should try on a different worker
"prototype_source/fx_graph_mode_ptq_dynamic",
"prototype_source/vmap_recipe",
Expand Down
2 changes: 1 addition & 1 deletion advanced_source/cpp_custom_ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -417,4 +417,4 @@ Conclusion
In this tutorial, we went over the recommended approach to integrating Custom C++
and CUDA operators with PyTorch. The ``TORCH_LIBRARY/torch.library`` APIs are fairly
low-level. For more information about how to use the API, see
`The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_.
`The Custom Operators Manual <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html#the-custom-operators-manual>`_.
6 changes: 5 additions & 1 deletion advanced_source/cpp_extension.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ Custom C++ and CUDA Extensions
==============================
**Author**: `Peter Goldsborough <https://www.goldsborough.me/>`_

.. warning::

This tutorial is deprecated as of PyTorch 2.4. Please see :ref:`custom-ops-landing-page`
for the newest up-to-date guides on extending PyTorch with Custom C++/CUDA Extensions.

PyTorch provides a plethora of operations related to neural networks, arbitrary
tensor algebra, data wrangling and other purposes. However, you may still find
Expand Down Expand Up @@ -225,7 +229,7 @@ Instead of:
Currently open issue for nvcc bug `here
<https://github.com/pytorch/pytorch/issues/69460>`_.
Complete workaround code example `here
<https://github.com/facebookresearch/pytorch3d/commit/cb170ac024a949f1f9614ffe6af1c38d972f7d48>`_.
<https://github.com/facebookresearch/pytorch3d/commit/cb170ac024a949f1f9614ffe6af1c38d972f7d48>`_.

Forward Pass
************
Expand Down
13 changes: 7 additions & 6 deletions advanced_source/custom_ops_landing_page.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
.. _custom-ops-landing-page:

PyTorch Custom Operators Landing Page
=====================================
PyTorch Custom Operators
===========================

PyTorch offers a large library of operators that work on Tensors (e.g. ``torch.add``,
``torch.sum``, etc). However, you may wish to bring a new custom operation to PyTorch
Expand All @@ -10,26 +10,27 @@ In order to do so, you must register the custom operation with PyTorch via the P
`torch.library docs <https://pytorch.org/docs/stable/library.html>`_ or C++ ``TORCH_LIBRARY``
APIs.

TL;DR
-----


Authoring a custom operator from Python
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Please see :ref:`python-custom-ops-tutorial`.

You may wish to author a custom operator from Python (as opposed to C++) if:

- you have a Python function you want PyTorch to treat as an opaque callable, especially with
respect to ``torch.compile`` and ``torch.export``.
respect to ``torch.compile`` and ``torch.export``.
- you have some Python bindings to C++/CUDA kernels and want those to compose with PyTorch
subsystems (like ``torch.compile`` or ``torch.autograd``)
subsystems (like ``torch.compile`` or ``torch.autograd``)

Integrating custom C++ and/or CUDA code with PyTorch
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Please see :ref:`cpp-custom-ops-tutorial`.

You may wish to author a custom operator from C++ (as opposed to Python) if:

- you have custom C++ and/or CUDA code.
- you plan to use this code with ``AOTInductor`` to do Python-less inference.

Expand Down
5 changes: 5 additions & 0 deletions advanced_source/dispatcher.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Registering a Dispatched Operator in C++
========================================

.. warning::

This tutorial is deprecated as of PyTorch 2.4. Please see :ref:`custom-ops-landing-page`
for the newest up-to-date guides on extending PyTorch with Custom Operators.

The dispatcher is an internal component of PyTorch which is responsible for
figuring out what code should actually get run when you call a function like
``torch::add``. This can be nontrivial, because PyTorch operations need
Expand Down
2 changes: 1 addition & 1 deletion advanced_source/python_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,5 +260,5 @@ def f(x):
# For more detailed information, see:
#
# - `the torch.library documentation <https://pytorch.org/docs/stable/library.html>`_
# - `the Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_
# - `the Custom Operators Manual <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html#the-custom-operators-manual>`_
#
5 changes: 5 additions & 0 deletions advanced_source/torch_script_custom_ops.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Extending TorchScript with Custom C++ Operators
===============================================

.. warning::

This tutorial is deprecated as of PyTorch 2.4. Please see :ref:`custom-ops-landing-page`
for the newest up-to-date guides on PyTorch Custom Operators.

The PyTorch 1.0 release introduced a new programming model to PyTorch called
`TorchScript <https://pytorch.org/docs/master/jit.html>`_. TorchScript is a
subset of the Python programming language which can be parsed, compiled and
Expand Down
14 changes: 4 additions & 10 deletions intermediate_source/torch_export_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,25 +544,19 @@ def suggested_fixes():
#
# Currently, the steps to register a custom op for use by ``torch.export`` are:
#
# - Define the custom op using ``torch.library`` (`reference <https://pytorch.org/docs/main/library.html>`__)
# - Define the custom op using ``torch.library`` (`reference <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html>`__)
# as with any other custom op

from torch.library import Library, impl, impl_abstract

m = Library("my_custom_library", "DEF")

m.define("custom_op(Tensor input) -> Tensor")

@impl(m, "custom_op", "CompositeExplicitAutograd")
def custom_op(x):
@torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
def custom_op(input: torch.Tensor) -> torch.Tensor:
print("custom_op called!")
return torch.relu(x)

######################################################################
# - Define a ``"Meta"`` implementation of the custom op that returns an empty
# tensor with the same shape as the expected output

@impl_abstract("my_custom_library::custom_op")
@custom_op.register_fake
def custom_op_meta(x):
return torch.empty_like(x)

Expand Down

0 comments on commit 0dee5c9

Please sign in to comment.