# TRT support for MAISI (#8153)
### Description
Added `trt_compile()` support for lists and tuples in the arguments of
`forward()`, which is needed for MAISI.
Grouping of returned results is not supported yet; MAISI works with an
explicit workaround that unrolls the return values.
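
For illustration, a minimal sketch of the new calling pattern, adapted from the `test_lists` unit test added in this PR (the `ListAdd` module and the argument values come from that test; treat this as a sketch, not a guaranteed API reference):

```
import torch

from monai.networks import trt_compile


class ListAdd(torch.nn.Module):
    # toy module from this PR's test_lists: forward() takes a list plus two tensors
    def forward(self, x: list[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: float = 0.1):
        y1, z1 = y.clone(), z + y
        for xi in x:
            y1 = y1 + xi + bs
        return x.copy(), [y1, z1], y1 + z1


model = ListAdd().cuda()
# output_lists mirrors forward()'s return structure: a variable-length list (-1),
# a two-element list (2), and a bare tensor ([])
trt_compile(model, "/tmp/test_lists", args={"output_lists": [[-1], [2], []], "export_args": {"dynamo": False}})
x, y, z = (torch.randn(1, 16, device="cuda") for _ in range(3))
out = model([x, y, z], y, z)  # the first call triggers the lazy TRT build
```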

### Notes
To successfully export MAISI, either the latest Torch nightly is needed, or
the following patch has to be applied to a 24.09-based container:

```
--- /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.bak	2024-10-09 01:38:04.920316673 +0000
+++ /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.py	2024-10-09 01:38:25.228053951 +0000
@@ -148,7 +148,6 @@
         is_causal and symbolic_helper._is_none(attn_mask)
     ), "is_causal and attn_mask cannot be set at the same time"

-    scale = symbolic_helper._maybe_get_const(scale, "f")
     if symbolic_helper._is_none(scale):
         scale = _attention_scale(g, query)
```
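
For a 24.09-based container, the same fix can be applied without editing the file by hand: this PR ships the change in reverse as `monai/torch.patch`, and the updated Dockerfile applies it with `patch -R -d /usr/local/lib/python3.10/dist-packages/torch/onnx/ < ./monai/torch.patch`.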

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).

---------

Signed-off-by: Boris Fomitchev <[email protected]>
Signed-off-by: Yiheng Wang <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: binliunls <[email protected]>
Authored by 6 people on Nov 14, 2024 (parent 941e739, commit 746a97a).
Showing 8 changed files with 215 additions and 98 deletions.
#### `Dockerfile` (6 additions, 0 deletions)

```
@@ -41,6 +41,10 @@ RUN cp /tmp/requirements.txt /tmp/req.bak \
 COPY LICENSE CHANGELOG.md CODE_OF_CONDUCT.md CONTRIBUTING.md README.md versioneer.py setup.py setup.cfg runtests.sh MANIFEST.in ./
 COPY tests ./tests
 COPY monai ./monai
+
+# TODO: remove this line and torch.patch for 24.11
+RUN patch -R -d /usr/local/lib/python3.10/dist-packages/torch/onnx/ < ./monai/torch.patch
+
 RUN BUILD_MONAI=1 FORCE_CUDA=1 python setup.py develop \
   && rm -rf build __pycache__

@@ -57,4 +61,6 @@ RUN apt-get update \
 # append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations
 ENV PATH=${PATH}:/opt/tools
+ENV POLYGRAPHY_AUTOINSTALL_DEPS=1
+

 WORKDIR /opt/monai
```
#### `monai/networks/nets/vista3d.py` (0 additions, 1 deletion)

```
@@ -641,7 +641,6 @@ def forward(self, src: torch.Tensor, class_vector: torch.Tensor):
         # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension.
         masks_embedding = class_embedding.squeeze() @ src.view(b, c, h * w * d)
         masks_embedding = masks_embedding.view(b, -1, h, w, d).transpose(0, 1)
-
         return masks_embedding, class_embedding
```
#### `monai/networks/trt_compiler.py` (158 additions, 54 deletions)

Large diff; not rendered on the commit page by default.
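
The "unrolling" workaround mentioned in the description can be pictured as follows; this is a hypothetical sketch, not MAISI code (`Wrapped` and `net` are made-up names):

```
import torch


class Wrapped(torch.nn.Module):
    """Hypothetical wrapper: flattens a nested return value (here a tensor plus a
    list of tensors) into a flat tuple, so each TRT engine output binds to one tensor."""

    def __init__(self, net: torch.nn.Module):
        super().__init__()
        self.net = net

    def forward(self, x: torch.Tensor):
        out, aux = self.net(x)  # aux: list[torch.Tensor]
        return (out, *aux)  # unrolled return
```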

#### `monai/networks/utils.py` (8 additions, 3 deletions)

```
@@ -632,7 +632,6 @@ def convert_to_onnx(
     use_trace: bool = True,
     do_constant_folding: bool = True,
     constant_size_threshold: int = 16 * 1024 * 1024 * 1024,
-    dynamo=False,
     **kwargs,
 ):
     """
@@ -673,6 +672,9 @@ def convert_to_onnx(
         # let torch.onnx.export to trace the model.
         mode_to_export = model
         torch_versioned_kwargs = kwargs
+        if "dynamo" in kwargs and kwargs["dynamo"] and verify:
+            torch_versioned_kwargs["verify"] = verify
+            verify = False
     else:
         if not pytorch_after(1, 10):
             if "example_outputs" not in kwargs:
@@ -695,13 +697,13 @@
             f = temp_file.name
     else:
         f = filename
-
+    print(f"torch_versioned_kwargs={torch_versioned_kwargs}")
     torch.onnx.export(
         mode_to_export,
         onnx_inputs,
         f=f,
         input_names=input_names,
-        output_names=output_names,
+        output_names=output_names or None,
         dynamic_axes=dynamic_axes,
         opset_version=opset_version,
         do_constant_folding=do_constant_folding,
@@ -715,6 +717,9 @@
         fold_constants(onnx_model, size_threshold=constant_size_threshold)

     if verify:
+        if isinstance(inputs, dict):
+            inputs = list(inputs.values())
+
         if device is None:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
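
With the `dynamo` parameter removed from the signature, it now travels to `torch.onnx.export()` through `**kwargs`; when both `dynamo` and `verify` are set, verification is delegated to the exporter itself. A hedged usage sketch (the toy `net` and shapes are made up, a recent torch that accepts the `dynamo` keyword is assumed, and keyword names follow the diff above):

```
import torch

from monai.networks.utils import convert_to_onnx

net = torch.nn.Linear(16, 4).eval()
onnx_model = convert_to_onnx(
    net,
    inputs=[torch.randn(1, 16)],
    input_names=["x"],
    output_names=["y"],
    verify=True,
    dynamo=False,  # forwarded to torch.onnx.export() via **kwargs after this PR
)
```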
#### `monai/torch.patch` (new file, 9 additions)

```
--- /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.py	2024-10-31 06:09:21.139938791 +0000
+++ /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.bak	2024-10-31 06:01:50.207462739 +0000
@@ -150,6 +150,7 @@
     ), "is_causal and attn_mask cannot be set at the same time"
     assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"

+    scale = symbolic_helper._maybe_get_const(scale, "f")
     if symbolic_helper._is_none(scale):
         scale = _attention_scale(g, query)
```
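
Note the reversed header order above (`---` names the `.py`, `+++` the `.bak`): applying this file with `patch -R`, as the Dockerfile does, removes the `scale = symbolic_helper._maybe_get_const(scale, "f")` line, which is exactly the hand edit shown in the Notes section.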
#### `monai/utils/module.py` (3 additions, 3 deletions)

```
@@ -649,7 +649,7 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s
         current_ver_string: if None, the current system GPU CUDA compute capability will be used.

     Returns:
-        True if the current system GPU CUDA compute capability is greater than or equal to the specified version.
+        True if the current system GPU CUDA compute capability is greater than the specified version.
     """
     if current_ver_string is None:
         cuda_available = torch.cuda.is_available()
@@ -667,11 +667,11 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s

     ver, has_ver = optional_import("packaging.version", name="parse")
     if has_ver:
-        return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}")  # type: ignore
+        return ver(".".join((f"{major}", f"{minor}"))) < ver(f"{current_ver_string}")  # type: ignore
     parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2)
     while len(parts) < 2:
         parts += ["0"]
     c_major, c_minor = parts[:2]
     c_mn = int(c_major), int(c_minor)
     mn = int(major), int(minor)
-    return c_mn > mn
+    return c_mn >= mn
```
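
The comparison is now strict, as a quick sanity check illustrates (a sketch assuming `packaging` is importable, so the first return path is taken; expected values follow the updated test cases below):

```
from monai.utils.module import compute_capabilities_after

assert compute_capabilities_after(6, 1, "6.1") is False  # an exactly-equal capability no longer counts
assert compute_capabilities_after(6, 0, "8.6") is True   # 8.6 is after 6.0
```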
#### `tests/test_trt_compile.py` (30 additions, 36 deletions)

```
@@ -19,24 +19,32 @@

 from monai.handlers import TrtHandler
 from monai.networks import trt_compile
-from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132
+from monai.networks.nets import cell_sam_wrapper, vista3d132
 from monai.utils import min_version, optional_import
-from tests.utils import (
-    SkipIfAtLeastPyTorchVersion,
-    SkipIfBeforeComputeCapabilityVersion,
-    skip_if_no_cuda,
-    skip_if_quick,
-    skip_if_windows,
-)
+from tests.utils import SkipIfBeforeComputeCapabilityVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows

 trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version)
+torch_tensorrt, torch_trt_imported = optional_import("torch_tensorrt")
 polygraphy, polygraphy_imported = optional_import("polygraphy")
 build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")

 TEST_CASE_1 = ["fp32"]
 TEST_CASE_2 = ["fp16"]


+class ListAdd(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: list[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: float = 0.1):
+        y1 = y.clone()
+        x1 = x.copy()
+        z1 = z + y
+        for xi in x:
+            y1 = y1 + xi + bs
+        return x1, [y1, z1], y1 + z1
+
+
 @skip_if_windows
 @skip_if_no_cuda
 @skip_if_quick
@@ -53,7 +61,7 @@ def tearDown(self):
         if current_device != self.gpu_device:
             torch.cuda.set_device(self.gpu_device)

-    @SkipIfAtLeastPyTorchVersion((2, 4, 1))
+    @unittest.skipUnless(torch_trt_imported, "torch_tensorrt is required")
     def test_handler(self):
         from ignite.engine import Engine

@@ -74,29 +82,19 @@ def test_handler(self):
         net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device="cuda"))
         self.assertIsNotNone(net1._trt_compiler.engine)

-    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
-    def test_unet_value(self, precision):
-        model = UNet(
-            spatial_dims=3,
-            in_channels=1,
-            out_channels=2,
-            channels=(2, 2, 4, 8, 4),
-            strides=(2, 2, 2, 2),
-            num_res_units=2,
-            norm="batch",
-        ).cuda()
+    def test_lists(self):
+        model = ListAdd().cuda()
+
         with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
             model.eval()
-            input_example = torch.randn(2, 1, 96, 96, 96).cuda()
-            output_example = model(input_example)
-            args: dict = {"builder_optimization_level": 1}
-            trt_compile(
-                model,
-                f"{tmpdir}/test_unet_trt_compile",
-                args={"precision": precision, "build_args": args, "dynamic_batchsize": [1, 4, 8]},
-            )
+            args = {"output_lists": [[-1], [2], []], "export_args": {"dynamo": False, "verbose": True}}
+            x = torch.randn(1, 16).to("cuda")
+            y = torch.randn(1, 16).to("cuda")
+            z = torch.randn(1, 16).to("cuda")
+            input_example = ([x, y, z], y.clone(), z.clone())
+            output_example = model(*input_example)
+            trt_compile(model, f"{tmpdir}/test_lists", args=args)
             self.assertIsNone(model._trt_compiler.engine)
-            trt_output = model(input_example)
+            trt_output = model(*input_example)
             # Check that lazy TRT build succeeded
             self.assertIsNotNone(model._trt_compiler.engine)
             torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)
@@ -109,11 +107,7 @@ def test_cell_sam_wrapper_value(self, precision):
             model.eval()
             input_example = torch.randn(1, 3, 128, 128).to("cuda")
             output_example = model(input_example)
-            trt_compile(
-                model,
-                f"{tmpdir}/test_cell_sam_wrapper_trt_compile",
-                args={"precision": precision, "dynamic_batchsize": [1, 1, 1]},
-            )
+            trt_compile(model, f"{tmpdir}/test_cell_sam_wrapper_trt_compile", args={"precision": precision})
             self.assertIsNone(model._trt_compiler.engine)
             trt_output = model(input_example)
             # Check that lazy TRT build succeeded
@@ -130,7 +124,7 @@ def test_vista3d(self, precision):
             model = trt_compile(
                 model,
                 f"{tmpdir}/test_vista3d_trt_compile",
-                args={"precision": precision, "dynamic_batchsize": [1, 1, 1]},
+                args={"precision": precision, "dynamic_batchsize": [1, 2, 4]},
                 submodule=["image_encoder.encoder", "class_head"],
             )
         self.assertIsNotNone(model.image_encoder.encoder._trt_compiler)
```
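
A side note on the `test_vista3d` change: `dynamic_batchsize` follows TensorRT's `[min, opt, max]` optimization-profile convention, so `[1, 2, 4]` builds one engine valid for batch sizes 1 through 4, instead of the fixed single-batch profile `[1, 1, 1]` used before.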
#### `tests/test_version_after.py` (1 addition, 1 deletion)

```
@@ -38,7 +38,7 @@

 TEST_CASES_SM = [
     # (major, minor, sm, expected)
-    (6, 1, "6.1", True),
+    (6, 1, "6.1", False),
     (6, 1, "6.0", False),
     (6, 0, "8.6", True),
     (7, 0, "8", True),
```
