diff --git a/ai_edge_torch/odml_torch/lowerings/__init__.py b/ai_edge_torch/odml_torch/lowerings/__init__.py index 84b50593..0d232d52 100644 --- a/ai_edge_torch/odml_torch/lowerings/__init__.py +++ b/ai_edge_torch/odml_torch/lowerings/__init__.py @@ -21,6 +21,6 @@ from . import context from . import registry from . import utils -from .registry import decompositions +from .decomp import decompositions from .registry import lookup from .registry import lower diff --git a/ai_edge_torch/odml_torch/lowerings/decomp.py b/ai_edge_torch/odml_torch/lowerings/decomp.py new file mode 100644 index 00000000..5dabb293 --- /dev/null +++ b/ai_edge_torch/odml_torch/lowerings/decomp.py @@ -0,0 +1,59 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Torch export decompositions to run before lowering.""" + +import functools + +import torch + + +@functools.cache +def decompositions(): + # Base: Core ATen decompositions + decompositions = torch._decomp.core_aten_decompositions() + + decompositions.update( + torch._decomp.get_decompositions([ + torch.ops.aten.upsample_nearest2d, + torch.ops.aten._native_batch_norm_legit.no_stats, + torch.ops.aten._native_batch_norm_legit_functional, + torch.ops.aten._adaptive_avg_pool2d, + torch.ops.aten._adaptive_avg_pool3d, + torch.ops.aten.grid_sampler_2d, + torch.ops.aten.native_group_norm, + torch.ops.aten.native_dropout, + torch.ops.aten.reflection_pad1d, + torch.ops.aten.reflection_pad2d, + torch.ops.aten.reflection_pad3d, + torch.ops.aten.replication_pad1d, + torch.ops.aten.replication_pad2d, + torch.ops.aten.replication_pad3d, + torch.ops.aten.addmm, + ]) + ) + + torch._decomp.remove_decompositions( + decompositions, + [torch.ops.aten.roll], + ) + + # Override _safe_softmax decompositions with regular softmax. + # _safe_softmax introduces additional check-select ops to guard extreme + # input values to softmax, which could make the converted model inefficient + # on-device. + if hasattr(torch.ops.aten, "_safe_softmax"): + decompositions[torch.ops.aten._safe_softmax.default] = torch.softmax + + return decompositions diff --git a/ai_edge_torch/odml_torch/lowerings/registry.py b/ai_edge_torch/odml_torch/lowerings/registry.py index 2ab6c1bd..bf77eeac 100644 --- a/ai_edge_torch/odml_torch/lowerings/registry.py +++ b/ai_edge_torch/odml_torch/lowerings/registry.py @@ -26,7 +26,6 @@ class LoweringRegistry: def __init__(self): self.registered_ops = {} - self.decompositions = {} def lookup(self, op_or_name): candidate = self._get_lowering(op_or_name) @@ -52,33 +51,6 @@ def register(self, op, lowering): global_registry = LoweringRegistry() -global_registry.decompositions.update(torch._decomp.core_aten_decompositions()) -global_registry.decompositions.update( - torch._decomp.get_decompositions([ - torch.ops.aten.upsample_nearest2d, - torch.ops.aten._native_batch_norm_legit.no_stats, - torch.ops.aten._native_batch_norm_legit_functional, - torch.ops.aten._adaptive_avg_pool2d, - torch.ops.aten._adaptive_avg_pool3d, - torch.ops.aten.grid_sampler_2d, - torch.ops.aten.native_group_norm, - torch.ops.aten.native_dropout, - torch.ops.aten.reflection_pad1d, - torch.ops.aten.reflection_pad2d, - torch.ops.aten.reflection_pad3d, - torch.ops.aten.replication_pad1d, - torch.ops.aten.replication_pad2d, - torch.ops.aten.replication_pad3d, - torch.ops.aten.addmm, - ]) -) - -torch._decomp.remove_decompositions( - global_registry.decompositions, - [ - torch.ops.aten.roll, - ], -) def lookup(op): @@ -91,7 +63,3 @@ def inner(lowering: Callable[[context.LoweringContext, ...], Any]): return lowering return inner - - -def decompositions(): - return global_registry.decompositions