Skip to content

Commit

Permalink
Enable composite for avg_pool2d in ceil_mode (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
majiddadashi authored May 30, 2024
1 parent 8ad1966 commit 5f60e64
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 3 deletions.
7 changes: 5 additions & 2 deletions ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,11 @@ def is_valid_padding(padding: list[int]):
# Only wrap in a composite when the underlying converter can handle it.
# TODO: We should be able to remove this if the converter can inline composites when it cannot handle them.

# We don't cover any cases where ceil_mode is True or divisor_override is set.
if full_kwargs["ceil_mode"] or full_kwargs["divisor_override"] is not None:
# We don't cover any cases where the divisor_override is set.
if full_kwargs["divisor_override"] is not None:
return op(*args, **kwargs)

if full_kwargs["ceil_mode"] and not full_kwargs["count_include_pad"]:
return op(*args, **kwargs)

# We also can not cover a case where count_include_pad is False but the padding is custom.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,21 @@ def test_avg_pool2d_op(self):
)
self.assertTrue(stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1)

def test_avg_pool2d_ceil_mode(self):
  """Verifies avg_pool2d with ceil_mode=True (and count_include_pad=True)
  is lowered to a single `aten.avg_pool2d.default` StableHLO composite."""
  stablehlo = _export_to_stablehlo_with_composite(
      lambda x: torch.nn.functional.avg_pool2d(
          x,
          kernel_size=[3, 3],
          stride=[1, 1],
          padding=[1, 1],
          ceil_mode=True,
          count_include_pad=True,
          divisor_override=None,
      ),
      (torch.rand(1, 3, 6, 6),),
  )
  # BUG FIX: assertTrue(count, 1) treated `1` as the failure message, so the
  # assertion passed for ANY nonzero count. assertEqual actually checks that
  # exactly one composite op was emitted.
  self.assertEqual(
      stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1
  )


# Run the test suite when this module is executed as a script.
if __name__ == '__main__':
  unittest.main()
3 changes: 3 additions & 0 deletions ai_edge_torch/convert/test/test_convert_composites.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def test_convert_hardswish(self):

@parameterized.parameterized.expand(
[
# input_size, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
# no padding, stride = 1
([1, 3, 6, 6], [3, 3], [1, 1], [0, 0], False, True, None),
# add stride
Expand All @@ -67,6 +68,8 @@ def test_convert_hardswish(self):
([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, False, None),
# ceil_mode = True
([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], True, True, None),
# ceil_mode = True, stride=[3, 3]
([1, 3, 6, 6], [3, 3], [3, 3], [1, 1], True, True, None),
# set divisor_override
([1, 3, 6, 6], [3, 3], [1, 1], 0, False, True, 6),
# padding set to one number
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ numpy
tabulate
safetensors
--pre
tf-nightly==2.17.0.dev20240509
tf-nightly==2.17.0.dev20240529
-f https://download.pytorch.org/whl/nightly/torch_nightly.html
torch==2.4.0.dev20240429+cpu
-f https://download.pytorch.org/whl/nightly/torch_nightly.html
Expand Down

0 comments on commit 5f60e64

Please sign in to comment.