diff --git a/ai_edge_torch/__init__.py b/ai_edge_torch/__init__.py index 15385734..99580428 100644 --- a/ai_edge_torch/__init__.py +++ b/ai_edge_torch/__init__.py @@ -15,6 +15,7 @@ from .convert.converter import convert from .convert.converter import signature +from .convert.to_channel_last_io import to_channel_last_io from .model import Model diff --git a/ai_edge_torch/convert/test/test_to_channel_last_io.py b/ai_edge_torch/convert/test/test_to_channel_last_io.py new file mode 100644 index 00000000..755e7c21 --- /dev/null +++ b/ai_edge_torch/convert/test/test_to_channel_last_io.py @@ -0,0 +1,96 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import unittest + +import torch + +import ai_edge_torch + + +class Identity(torch.nn.Module): + + def forward(self, x): + return x + + +class TestToChannelLastIO(unittest.TestCase): + """Tests to_channel_last_io API and module wrapper.""" + + def test_no_transformations(self): + x = torch.rand(1, 3, 10, 10) + y = ai_edge_torch.to_channel_last_io(Identity())(x) + self.assertEqual(y.shape, (1, 3, 10, 10)) + + def test_args(self): + x = torch.rand(1, 10, 10, 3) + y = ai_edge_torch.to_channel_last_io(Identity(), args=[0])(x) + self.assertEqual(y.shape, (1, 3, 10, 10)) + + def test_outputs(self): + x = torch.rand(1, 3, 10, 10) + y = ai_edge_torch.to_channel_last_io(Identity(), outputs=[0])(x) + self.assertEqual(y.shape, (1, 10, 10, 3)) + + def test_args_outputs(self): + x = torch.rand(1, 10, 10, 3) + y = ai_edge_torch.to_channel_last_io(Identity(), args=[0], outputs=[0])(x) + self.assertEqual(y.shape, (1, 10, 10, 3)) + + def test_args_5d(self): + x = torch.rand(1, 10, 10, 10, 3) + y = ai_edge_torch.to_channel_last_io(Identity(), args=[0])(x) + self.assertEqual(y.shape, (1, 3, 10, 10, 10)) + + def test_outputs_5d(self): + x = torch.rand(1, 3, 10, 10, 10) + y = ai_edge_torch.to_channel_last_io(Identity(), outputs=[0])(x) + self.assertEqual(y.shape, (1, 10, 10, 10, 3)) + + def test_chained_wrappers(self): + x = torch.rand(1, 10, 10, 3) + + m = Identity() + m = ai_edge_torch.to_channel_last_io(m, args=[0]) + m = ai_edge_torch.to_channel_last_io(m, outputs=[0]) + + y = m(x) + self.assertEqual(y.shape, (1, 10, 10, 3)) + + def test_list_args(self): + class Add(torch.nn.Module): + + def forward(self, x, y): + return x + y + + x = (torch.rand(1, 10, 10, 3), torch.rand(1, 10, 10, 3)) + y = ai_edge_torch.to_channel_last_io(Add(), args=[0, 1])(*x) + self.assertEqual(y.shape, (1, 3, 10, 10)) + + def test_list_outputs(self): + class TwoIdentity(torch.nn.Module): + + def forward(self, x): + return x, x + + x = torch.rand(1, 3, 10, 10) + y = ai_edge_torch.to_channel_last_io(TwoIdentity(), outputs=[0])(x) + self.assertIsInstance(y, tuple) + self.assertEqual(y[0].shape, (1, 10, 10, 3)) + self.assertEqual(y[1].shape, (1, 3, 10, 10)) + + +if __name__ == "__main__": + unittest.main() diff --git a/ai_edge_torch/convert/to_channel_last_io.py b/ai_edge_torch/convert/to_channel_last_io.py new file mode 100644 index 00000000..1280b633 --- /dev/null +++ b/ai_edge_torch/convert/to_channel_last_io.py @@ -0,0 +1,85 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Optional + +import torch +from torch import nn + + +class ChannelLastIOWrapper(nn.Module): + + def __init__(self, wrapped, *, args=None, outputs=None): + super().__init__() + self.wrapped = wrapped + self._args = args or [] + self._outputs = outputs or [] + + def _to_channel_last(self, x): + if not torch.is_tensor(x): + raise ValueError("Input must be a torch tensor") + if x.ndim < 3: + raise ValueError("Input must be a tensor with rank >= 3 in layout (N, C, ...)") + dims = [0, *range(2, x.ndim), 1] + return torch.permute(x, dims) + + def _to_channel_first(self, x): + if not torch.is_tensor(x): + raise ValueError("Input must be a torch tensor.") + if x.ndim < 3: + raise ValueError("Input must be a tensor with rank >= 3 in layout (N, ..., C)") + dims = [0, x.ndim - 1, *range(1, x.ndim - 1)] + return torch.permute(x, dims) + + def forward(self, *args, **kwargs): + args = list(args) + for i in self._args: + args[i] = self._to_channel_first(args[i]) + + outputs = self.wrapped(*args, **kwargs) + + if not isinstance(outputs, (list, tuple)): + outputs_is_list = False + output_list = [outputs] + else: + outputs_is_list = True + output_list = list(outputs) + + for i in self._outputs: + output_list[i] = self._to_channel_last(output_list[i]) + + if not outputs_is_list: + return output_list[0] + else: + return type(outputs)(output_list) + + +def to_channel_last_io( + module: nn.Module, + args: Optional[list[int]] = None, + outputs: Optional[list[int]] = None, +): + """Wraps the module with channel first to channel last layout transformations. + + Args: + args (list[int]): Transform args with indices in the list from channel first + (N, C, ...) to channel last (N, ..., C). + outputs (list[int]): Transform outputs with indices in the list from channel + first (N, C, ...) to channel last (N, ..., C). + Returns: + The wrapped nn.Module with additional layout transposes after inputs and/or before + outputs. + """ + return ChannelLastIOWrapper(module, args=args, outputs=outputs)