[Feat] Add torch.compile support #1791

Open: wants to merge 7 commits into base: main
77 changes: 72 additions & 5 deletions docs/source/using_doctr/using_model_export.rst
@@ -31,7 +31,11 @@ Advantages:
         .. code:: python3

             import torch
-            predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True).cuda().half()
+            predictor = ocr_predictor(
+                reco_arch="crnn_mobilenet_v3_small",
+                det_arch="linknet_resnet34",
+                pretrained=True
+            ).cuda().half()
             res = predictor(doc)

     .. tab:: TensorFlow
@@ -41,8 +45,63 @@ Advantages:
             import tensorflow as tf
             from tensorflow.keras import mixed_precision
             mixed_precision.set_global_policy('mixed_float16')
-            predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True)
-
+            predictor = ocr_predictor(
+                reco_arch="crnn_mobilenet_v3_small",
+                det_arch="linknet_resnet34",
+                pretrained=True
+            )
+
+
+Compiling your models (PyTorch only)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+**NOTE:**
+
+- This feature is only available if you use PyTorch as the backend.
+- The recognition architecture `master` is not yet supported for model compilation.
+- Only the default (`inductor`) backend is officially supported, but you can also try other backends and configurations depending on your hardware and requirements.
+
+Compiling your PyTorch models with `torch.compile` captures the model as a graph and hands it to a compiler backend that can generate optimized code.
+This can make inference faster and reduce memory overhead during execution.
+
+Further information can be found in the `PyTorch documentation <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_.
+
+.. code::
+
+    import torch
+    from doctr.models import (
+        ocr_predictor,
+        vitstr_small,
+        fast_base,
+        mobilenet_v3_small_crop_orientation,
+        mobilenet_v3_small_page_orientation,
+        crop_orientation_predictor,
+        page_orientation_predictor
+    )
+
+    # Compile the models
+    detection_model = torch.compile(
+        fast_base(pretrained=True).eval()
+    )
+    recognition_model = torch.compile(
+        vitstr_small(pretrained=True).eval()
+    )
+    crop_orientation_model = torch.compile(
+        mobilenet_v3_small_crop_orientation(pretrained=True).eval()
+    )
+    page_orientation_model = torch.compile(
+        mobilenet_v3_small_page_orientation(pretrained=True).eval()
+    )
+
+    # `ocr_predictor` is imported directly above, so it is called without a `models.` prefix
+    predictor = ocr_predictor(
+        detection_model, recognition_model, assume_straight_pages=False
+    )
+    # NOTE: Only required for non-straight pages (`assume_straight_pages=False`) and non-disabled orientation classification
+    # Set the orientation predictors
+    predictor.crop_orientation_predictor = crop_orientation_predictor(crop_orientation_model)
+    predictor.page_orientation_predictor = page_orientation_predictor(page_orientation_model)
+
+    compiled_out = predictor(doc)

 Export to ONNX
 ^^^^^^^^^^^^^^
@@ -64,7 +123,11 @@ It defines a common format for representing models, including the network struct
             input_shape = (3, 32, 128)
             model = vitstr_small(pretrained=True, exportable=True)
             dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
-            model_path = export_model_to_onnx(model, model_name="vitstr.onnx, dummy_input=dummy_input)
+            model_path = export_model_to_onnx(
+                model,
+                model_name="vitstr.onnx",
+                dummy_input=dummy_input
+            )

     .. tab:: TensorFlow

@@ -78,7 +141,11 @@ It defines a common format for representing models, including the network struct
             input_shape = (32, 128, 3)
             model = vitstr_small(pretrained=True, exportable=True)
             dummy_input = [tf.TensorSpec([batch_size, *input_shape], tf.float32, name="input")]
-            model_path, output = export_model_to_onnx(model, model_name="vitstr.onnx", dummy_input=dummy_input)
+            model_path, output = export_model_to_onnx(
+                model,
+                model_name="vitstr.onnx",
+                dummy_input=dummy_input
+            )


 Using your ONNX exported model
12 changes: 6 additions & 6 deletions docs/source/using_doctr/using_models.rst
@@ -298,7 +298,7 @@ For instance, this snippet instantiates an end-to-end ocr_predictor working with

 .. code:: python3

-    from doctr.model import ocr_predictor
+    from doctr.models import ocr_predictor
     model = ocr_predictor('linknet_resnet18', pretrained=True, assume_straight_pages=False, preserve_aspect_ratio=True)


@@ -309,7 +309,7 @@ Additionally, you can change the batch size of the underlying detection and reco

 .. code:: python3

-    from doctr.model import ocr_predictor
+    from doctr.models import ocr_predictor
     model = ocr_predictor(pretrained=True, det_bs=4, reco_bs=1024)

 To modify the output structure you can pass the following arguments to the predictor which will be handled by the underlying `DocumentBuilder`:
@@ -322,7 +322,7 @@ For example to disable the automatic grouping of lines into blocks:

 .. code:: python3

-    from doctr.model import ocr_predictor
+    from doctr.models import ocr_predictor
     model = ocr_predictor(pretrained=True, resolve_blocks=False)


@@ -477,7 +477,7 @@ This will only have an effect with `assume_straight_pages=False` and/or `straigh

 .. code:: python3

-    from doctr.model import ocr_predictor
+    from doctr.models import ocr_predictor
     model = ocr_predictor(pretrained=True, assume_straight_pages=False, disable_page_orientation=True)


@@ -489,15 +489,15 @@ This will only have an effect with `assume_straight_pages=False` and/or `straigh

 .. code:: python3

-    from doctr.model import ocr_predictor
+    from doctr.models import ocr_predictor
     model = ocr_predictor(pretrained=True, assume_straight_pages=False, disable_crop_orientation=True)


 * Add a hook to the `ocr_predictor` to manipulate the location predictions before the crops are passed to the recognition model.

 .. code:: python3

-    from doctr.model import ocr_predictor
+    from doctr.models import ocr_predictor

     class CustomHook:
         def __call__(self, loc_preds):
11 changes: 9 additions & 2 deletions doctr/models/classification/zoo.py
@@ -5,7 +5,7 @@

 from typing import Any

-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available

 from .. import classification
 from ..preprocessor import PreProcessor
@@ -48,7 +48,14 @@ def _orientation_predictor(
         # Load directly classifier from backbone
         _model = classification.__dict__[arch](pretrained=pretrained)
     else:
-        if not isinstance(arch, classification.MobileNetV3):
+        allowed_archs = [classification.MobileNetV3]
+        if is_torch_available():
+            # Adding the type for torch compiled models to the allowed architectures
+            from doctr.models.utils import _CompiledModule
+
+            allowed_archs.append(_CompiledModule)
+
+        if not isinstance(arch, tuple(allowed_archs)):
             raise ValueError(f"unknown architecture: {type(arch)}")
         _model = arch
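
The allow-list check in this hunk (and in the detection and recognition `zoo.py` files further down) can be sketched without docTR or PyTorch installed. Everything below is illustrative: `MobileNetV3`, `CompiledWrapper`, and `check_arch` are hypothetical stand-ins, not docTR names, and the `torch_available` flag plays the role of `is_torch_available()`.

```python
class MobileNetV3:
    """Hypothetical stand-in for classification.MobileNetV3."""


class CompiledWrapper:
    """Hypothetical stand-in for the compiled-module type (_CompiledModule)."""


def check_arch(arch: object, torch_available: bool) -> object:
    """Return `arch` if its type is allowed, otherwise raise ValueError."""
    allowed_archs: list[type] = [MobileNetV3]
    if torch_available:
        # The compiled type is only added when the optional backend exists.
        allowed_archs.append(CompiledWrapper)

    # isinstance accepts a type or a tuple of types, not a list.
    if not isinstance(arch, tuple(allowed_archs)):
        raise ValueError(f"unknown architecture: {type(arch)}")
    return arch


plain_model = MobileNetV3()
compiled_model = CompiledWrapper()

accepted_plain = check_arch(plain_model, torch_available=False)
accepted_compiled = check_arch(compiled_model, torch_available=True)
```

Building the list first and converting it with `tuple(...)` keeps the optional type out of the check entirely when the backend is missing, so the import of the compiled type never has to succeed on a TensorFlow-only install.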

13 changes: 9 additions & 4 deletions doctr/models/detection/differentiable_binarization/pytorch.py
@@ -205,11 +205,16 @@ def forward(
             out["out_map"] = prob_map

         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
+                return [
+                    dict(zip(self.class_names, preds))
+                    for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
+                ]
+
             # Post-process boxes (keep only text predictions)
-            out["preds"] = [
-                dict(zip(self.class_names, preds))
-                for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
-            ]
+            out["preds"] = _postprocess(prob_map)

         if target is not None:
             thresh_map = self.thresh_head(feat_concat)
13 changes: 9 additions & 4 deletions doctr/models/detection/fast/pytorch.py
@@ -196,11 +196,16 @@ def forward(
             out["out_map"] = prob_map

         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]

Review comment from the PR author (Contributor): The post-processing can't be compiled, so the model itself runs fully compiled and all other parts are kept as is.

+            def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
+                return [
+                    dict(zip(self.class_names, preds))
+                    for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
+                ]
+
             # Post-process boxes (keep only text predictions)
-            out["preds"] = [
-                dict(zip(self.class_names, preds))
-                for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
-            ]
+            out["preds"] = _postprocess(prob_map)

         if target is not None:
             loss = self.compute_loss(logits, target)
15 changes: 10 additions & 5 deletions doctr/models/detection/linknet/pytorch.py
@@ -183,11 +183,16 @@ def forward(
             out["out_map"] = prob_map

         if target is None or return_preds:
-            # Post-process boxes
-            out["preds"] = [
-                dict(zip(self.class_names, preds))
-                for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
-            ]
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
+                return [
+                    dict(zip(self.class_names, preds))
+                    for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
+                ]
+
+            # Post-process boxes (keep only text predictions)
+            out["preds"] = _postprocess(prob_map)

         if target is not None:
             loss = self.compute_loss(logits, target)
9 changes: 8 additions & 1 deletion doctr/models/detection/zoo.py
@@ -56,7 +56,14 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
         if isinstance(_model, detection.FAST):
             _model = reparameterize(_model)
     else:
-        if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
+        allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST]
+        if is_torch_available():
+            # Adding the type for torch compiled models to the allowed architectures
+            from doctr.models.utils import _CompiledModule
+
+            allowed_archs.append(_CompiledModule)
+
+        if not isinstance(arch, tuple(allowed_archs)):
             raise ValueError(f"unknown architecture: {type(arch)}")

         _model = arch
7 changes: 6 additions & 1 deletion doctr/models/recognition/crnn/pytorch.py
@@ -213,8 +213,13 @@ def forward(
             out["out_map"] = logits

         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(logits)
+
             # Post-process boxes
-            out["preds"] = self.postprocessor(logits)
+            out["preds"] = _postprocess(logits)

         if target is not None:
             out["loss"] = self.compute_loss(logits, target)
8 changes: 7 additions & 1 deletion doctr/models/recognition/master/pytorch.py
@@ -209,7 +209,13 @@ def forward(
             out["out_map"] = logits

         if return_preds:
-            out["preds"] = self.postprocessor(logits)
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(logits)
+
+            # Post-process boxes
+            out["preds"] = _postprocess(logits)

         return out

7 changes: 6 additions & 1 deletion doctr/models/recognition/parseq/pytorch.py
@@ -372,8 +372,13 @@ def forward(
             out["out_map"] = logits

         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(logits)
+
             # Post-process boxes
-            out["preds"] = self.postprocessor(logits)
+            out["preds"] = _postprocess(logits)

         if target is not None:
             out["loss"] = loss
7 changes: 6 additions & 1 deletion doctr/models/recognition/sar/pytorch.py
@@ -262,8 +262,13 @@ def forward(
             out["out_map"] = decoded_features

         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(decoded_features)
+
             # Post-process boxes
-            out["preds"] = self.postprocessor(decoded_features)
+            out["preds"] = _postprocess(decoded_features)

         if target is not None:
             out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
4 changes: 2 additions & 2 deletions doctr/models/recognition/utils.py
@@ -22,7 +22,7 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str:
         A merged character sequence.

     Example::
-        >>> from doctr.model.recognition.utils import merge_sequences
+        >>> from doctr.models.recognition.utils import merge_strings
         >>> merge_strings('abcd', 'cdefgh', 1.4)
         'abcdefgh'
         >>> merge_strings('abcdi', 'cdefgh', 1.4)
@@ -70,7 +70,7 @@ def merge_multi_strings(seq_list: list[str], dil_factor: float) -> str:
         A merged character sequence

     Example::
-        >>> from doctr.model.recognition.utils import merge_multi_sequences
+        >>> from doctr.models.recognition.utils import merge_multi_strings
         >>> merge_multi_strings(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4)
         'abcdefghijkl'
     """
7 changes: 6 additions & 1 deletion doctr/models/recognition/vitstr/pytorch.py
@@ -107,8 +107,13 @@ def forward(
             out["out_map"] = decoded_features

         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(decoded_features)
+
             # Post-process boxes
-            out["preds"] = self.postprocessor(decoded_features)
+            out["preds"] = _postprocess(decoded_features)

         if target is not None:
             out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
13 changes: 9 additions & 4 deletions doctr/models/recognition/zoo.py
@@ -5,7 +5,7 @@

 from typing import Any

-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 from doctr.models.preprocessor import PreProcessor

 from .. import recognition
@@ -35,9 +35,14 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
             pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
         )
     else:
-        if not isinstance(
-            arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq)
-        ):
+        allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq]
+        if is_torch_available():
+            # Adding the type for torch compiled models to the allowed architectures
+            from doctr.models.utils import _CompiledModule
+
+            allowed_archs.append(_CompiledModule)
+
+        if not isinstance(arch, tuple(allowed_archs)):
             raise ValueError(f"unknown architecture: {type(arch)}")
         _model = arch

4 changes: 4 additions & 0 deletions doctr/models/utils/pytorch.py
@@ -18,8 +18,12 @@
     "export_model_to_onnx",
     "_copy_tensor",
     "_bf16_to_float32",
+    "_CompiledModule",
 ]

+# torch compiled model type
+_CompiledModule = torch._dynamo.eval_frame.OptimizedModule
+

 def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
     return x.clone().detach()
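
The `_CompiledModule` alias exists because `torch.compile` wraps an `nn.Module` in a different type, so the original `isinstance` checks in the `zoo.py` files would reject a compiled model. A minimal sketch of that behavior, assuming PyTorch 2.x: `TinyNet` is a hypothetical stand-in, and `_orig_mod` is a private attribute of the wrapper that may change between releases.

```python
import torch
import torch._dynamo
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a docTR architecture class."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


model = TinyNet()
# backend="eager" avoids needing a compiler toolchain; nothing is compiled until the first call.
compiled = torch.compile(model, backend="eager")

# The wrapper is an nn.Module, but not an instance of the original class.
is_original_type = isinstance(compiled, TinyNet)
is_compiled_type = isinstance(compiled, torch._dynamo.eval_frame.OptimizedModule)

# The original module is kept on the wrapper (private attribute, an implementation detail).
wraps_original = compiled._orig_mod is model
```

Because the wrapper type says nothing about the wrapped architecture, a check against `_CompiledModule` accepts any compiled module; a stricter check would have to inspect the wrapped module as well.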