Skip to content

Commit

Permalink
Add channel last IO transformation API (#66)
Browse files Browse the repository at this point in the history
* init

* fix

* Update to_channel_last_io.py

* Update test_to_channel_last_io.py

* Fix error message
  • Loading branch information
chunnienc authored Jun 27, 2024
1 parent af276a6 commit b5c7314
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 0 deletions.
1 change: 1 addition & 0 deletions ai_edge_torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from .convert.converter import convert
from .convert.converter import signature
from .convert.to_channel_last_io import to_channel_last_io
from .model import Model


Expand Down
96 changes: 96 additions & 0 deletions ai_edge_torch/convert/test/test_to_channel_last_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import unittest

import torch

import ai_edge_torch


class Identity(torch.nn.Module):

def forward(self, x):
return x


class TestToChannelLastIO(unittest.TestCase):
"""Tests to_channel_last_io API and module wrapper."""

def test_no_transformations(self):
x = torch.rand(1, 3, 10, 10)
y = ai_edge_torch.to_channel_last_io(Identity())(x)
self.assertEqual(y.shape, (1, 3, 10, 10))

def test_args(self):
x = torch.rand(1, 10, 10, 3)
y = ai_edge_torch.to_channel_last_io(Identity(), args=[0])(x)
self.assertEqual(y.shape, (1, 3, 10, 10))

def test_outputs(self):
x = torch.rand(1, 3, 10, 10)
y = ai_edge_torch.to_channel_last_io(Identity(), outputs=[0])(x)
self.assertEqual(y.shape, (1, 10, 10, 3))

def test_args_outputs(self):
x = torch.rand(1, 10, 10, 3)
y = ai_edge_torch.to_channel_last_io(Identity(), args=[0], outputs=[0])(x)
self.assertEqual(y.shape, (1, 10, 10, 3))

def test_args_5d(self):
x = torch.rand(1, 10, 10, 10, 3)
y = ai_edge_torch.to_channel_last_io(Identity(), args=[0])(x)
self.assertEqual(y.shape, (1, 3, 10, 10, 10))

def test_outputs_5d(self):
x = torch.rand(1, 3, 10, 10, 10)
y = ai_edge_torch.to_channel_last_io(Identity(), outputs=[0])(x)
self.assertEqual(y.shape, (1, 10, 10, 10, 3))

def test_chained_wrappers(self):
x = torch.rand(1, 10, 10, 3)

m = Identity()
m = ai_edge_torch.to_channel_last_io(m, args=[0])
m = ai_edge_torch.to_channel_last_io(m, outputs=[0])

y = m(x)
self.assertEqual(y.shape, (1, 10, 10, 3))

def test_list_args(self):
class Add(torch.nn.Module):

def forward(self, x, y):
return x + y

x = (torch.rand(1, 10, 10, 3), torch.rand(1, 10, 10, 3))
y = ai_edge_torch.to_channel_last_io(Add(), args=[0, 1])(*x)
self.assertEqual(y.shape, (1, 3, 10, 10))

def test_list_outputs(self):
class TwoIdentity(torch.nn.Module):

def forward(self, x):
return x, x

x = torch.rand(1, 3, 10, 10)
y = ai_edge_torch.to_channel_last_io(TwoIdentity(), outputs=[0])(x)
self.assertIsInstance(y, tuple)
self.assertEqual(y[0].shape, (1, 10, 10, 3))
self.assertEqual(y[1].shape, (1, 3, 10, 10))


if __name__ == "__main__":
unittest.main()
85 changes: 85 additions & 0 deletions ai_edge_torch/convert/to_channel_last_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Optional

import torch
from torch import nn


class ChannelLastIOWrapper(nn.Module):

def __init__(self, wrapped, *, args=None, outputs=None):
super().__init__()
self.wrapped = wrapped
self._args = args or []
self._outputs = outputs or []

def _to_channel_last(self, x):
if not torch.is_tensor(x):
raise ValueError("Input must be a torch tensor")
if x.ndim < 3:
raise ValueError("Input must be a tensor with rank >= 3 in layout (N, C, ...)")
dims = [0, *range(2, x.ndim), 1]
return torch.permute(x, dims)

def _to_channel_first(self, x):
if not torch.is_tensor(x):
raise ValueError("Input must be a torch tensor.")
if x.ndim < 3:
raise ValueError("Input must be a tensor with rank >= 3 in layout (N, ..., C)")
dims = [0, x.ndim - 1, *range(1, x.ndim - 1)]
return torch.permute(x, dims)

def forward(self, *args, **kwargs):
args = list(args)
for i in self._args:
args[i] = self._to_channel_first(args[i])

outputs = self.wrapped(*args, **kwargs)

if not isinstance(outputs, (list, tuple)):
outputs_is_list = False
output_list = [outputs]
else:
outputs_is_list = True
output_list = list(outputs)

for i in self._outputs:
output_list[i] = self._to_channel_last(output_list[i])

if not outputs_is_list:
return output_list[0]
else:
return type(outputs)(output_list)


def to_channel_last_io(
module: nn.Module,
args: Optional[list[int]] = None,
outputs: Optional[list[int]] = None,
):
"""Wraps the module with channel first to channel last layout transformations.
Args:
args (list[int]): Transform args with indices in the list from channel first
(N, C, ...) to channel last (N, ..., C).
outputs (list[int]): Transform outputs with indices in the list from channel
first (N, C, ...) to channel last (N, ..., C).
Returns:
The wrapped nn.Module with additional layout transposes after inputs and/or before
outputs.
"""
return ChannelLastIOWrapper(module, args=args, outputs=outputs)

0 comments on commit b5c7314

Please sign in to comment.