diff --git a/ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py b/ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py
index a253054d..712a87e2 100644
--- a/ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py
+++ b/ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py
@@ -141,8 +141,11 @@ def is_valid_padding(padding: list[int]):
 
     # Only wrap in a composite when the underlying converter can handle it.
     # TODO We should be able to remove this if the converter can inline composites when it can not handle them.
-    # We don't cover any cases where ceil_mode is True or divisor_override is set.
-    if full_kwargs["ceil_mode"] or full_kwargs["divisor_override"] is not None:
+    # We don't cover any cases where the divisor_override is set.
+    if full_kwargs["divisor_override"] is not None:
+      return op(*args, **kwargs)
+
+    if full_kwargs["ceil_mode"] and not full_kwargs["count_include_pad"]:
       return op(*args, **kwargs)
 
     # We also can not cover a case where count_include_pad is False but the padding is custom.
diff --git a/ai_edge_torch/convert/fx_passes/test/test_build_aten_composite_pass.py b/ai_edge_torch/convert/fx_passes/test/test_build_aten_composite_pass.py
index 90adafca..90041f5c 100644
--- a/ai_edge_torch/convert/fx_passes/test/test_build_aten_composite_pass.py
+++ b/ai_edge_torch/convert/fx_passes/test/test_build_aten_composite_pass.py
@@ -95,6 +95,21 @@ def test_avg_pool2d_op(self):
     )
     self.assertTrue(stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1)
 
+  def test_avg_pool2d_ceil_mode(self):
+    stablehlo = _export_to_stablehlo_with_composite(
+        lambda x: torch.nn.functional.avg_pool2d(
+            x,
+            kernel_size=[3, 3],
+            stride=[1, 1],
+            padding=[1, 1],
+            ceil_mode=True,
+            count_include_pad=True,
+            divisor_override=None,
+        ),
+        (torch.rand(1, 3, 6, 6),),
+    )
+    self.assertTrue(stablehlo.count('stablehlo.composite "aten.avg_pool2d.default"'), 1)
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/ai_edge_torch/convert/test/test_convert_composites.py b/ai_edge_torch/convert/test/test_convert_composites.py
index 8a25ec7c..08456d5e 100644
--- a/ai_edge_torch/convert/test/test_convert_composites.py
+++ b/ai_edge_torch/convert/test/test_convert_composites.py
@@ -51,6 +51,7 @@ def test_convert_hardswish(self):
 
   @parameterized.parameterized.expand(
       [
+          # input_size, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
           # no padding, stride = 1
           ([1, 3, 6, 6], [3, 3], [1, 1], [0, 0], False, True, None),
           # add stride
@@ -67,6 +68,8 @@ def test_convert_hardswish(self):
           ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, False, None),
           # ceil_mode = True
           ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], True, True, None),
+          # ceil_mode = True, stride=[3, 3]
+          ([1, 3, 6, 6], [3, 3], [3, 3], [1, 1], True, True, None),
           # set divisor_override
           ([1, 3, 6, 6], [3, 3], [1, 1], 0, False, True, 6),
           # padding set to one number
diff --git a/requirements.txt b/requirements.txt
index d99ffcfb..519fbe6f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ numpy
 tabulate
 safetensors
 --pre
-tf-nightly==2.17.0.dev20240509
+tf-nightly==2.17.0.dev20240529
 -f https://download.pytorch.org/whl/nightly/torch_nightly.html
 torch==2.4.0.dev20240429+cpu
 -f https://download.pytorch.org/whl/nightly/torch_nightly.html
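
Note (not part of the patch): a minimal sketch of the updated gating logic in build_aten_composite_pass.py. The helper name below is hypothetical; it only mirrors the two early returns changed in this diff and omits the separate custom-padding check that still follows in the real pass.

```python
def can_wrap_avg_pool2d_as_composite(
    ceil_mode: bool,
    count_include_pad: bool,
    divisor_override,
) -> bool:
  """Returns True when aten.avg_pool2d can be wrapped in a composite under the new rules."""
  # divisor_override is still unsupported by the converter.
  if divisor_override is not None:
    return False
  # ceil_mode is now supported, but only when padded zeros are counted,
  # i.e. count_include_pad must be True.
  if ceil_mode and not count_include_pad:
    return False
  return True


# Cases matching the new and existing tests:
assert can_wrap_avg_pool2d_as_composite(True, True, None)       # newly supported
assert not can_wrap_avg_pool2d_as_composite(True, False, None)  # still falls back to the plain op
assert not can_wrap_avg_pool2d_as_composite(False, True, 6)     # still falls back to the plain op
```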