Skip to content

Commit

Permalink
Improve _safe_softmax lowering
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 704786941
  • Loading branch information
chunnienc authored and copybara-github committed Dec 10, 2024
1 parent cf0e73f commit 93d5756
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 33 deletions.
2 changes: 1 addition & 1 deletion ai_edge_torch/odml_torch/lowerings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@
from . import context
from . import registry
from . import utils
from .registry import decompositions
from .decomp import decompositions
from .registry import lookup
from .registry import lower
59 changes: 59 additions & 0 deletions ai_edge_torch/odml_torch/lowerings/decomp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Torch export decompositions that needs to be run before lowering."""

import functools

import torch


@functools.cache
def decompositions():
# Base: Core ATen decompositions
decompositions = torch._decomp.core_aten_decompositions()

decompositions.update(
torch._decomp.get_decompositions([
torch.ops.aten.upsample_nearest2d,
torch.ops.aten._native_batch_norm_legit.no_stats,
torch.ops.aten._native_batch_norm_legit_functional,
torch.ops.aten._adaptive_avg_pool2d,
torch.ops.aten._adaptive_avg_pool3d,
torch.ops.aten.grid_sampler_2d,
torch.ops.aten.native_group_norm,
torch.ops.aten.native_dropout,
torch.ops.aten.reflection_pad1d,
torch.ops.aten.reflection_pad2d,
torch.ops.aten.reflection_pad3d,
torch.ops.aten.replication_pad1d,
torch.ops.aten.replication_pad2d,
torch.ops.aten.replication_pad3d,
torch.ops.aten.addmm,
])
)

torch._decomp.remove_decompositions(
decompositions,
[torch.ops.aten.roll],
)

# Override _safe_softmax decompositions with regular softmax.
# _safe_softmax introduces additional check-select ops to guard extreme
# input values to softmax, which could make the converted model inefficient
# on-device.
if hasattr(torch.ops.aten, "_safe_softmax"):
decompositions[torch.ops.aten._safe_softmax.default] = torch.softmax

return decompositions
32 changes: 0 additions & 32 deletions ai_edge_torch/odml_torch/lowerings/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class LoweringRegistry:

def __init__(self):
self.registered_ops = {}
self.decompositions = {}

def lookup(self, op_or_name):
candidate = self._get_lowering(op_or_name)
Expand All @@ -52,33 +51,6 @@ def register(self, op, lowering):


global_registry = LoweringRegistry()
global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
global_registry.decompositions.update(
torch._decomp.get_decompositions([
torch.ops.aten.upsample_nearest2d,
torch.ops.aten._native_batch_norm_legit.no_stats,
torch.ops.aten._native_batch_norm_legit_functional,
torch.ops.aten._adaptive_avg_pool2d,
torch.ops.aten._adaptive_avg_pool3d,
torch.ops.aten.grid_sampler_2d,
torch.ops.aten.native_group_norm,
torch.ops.aten.native_dropout,
torch.ops.aten.reflection_pad1d,
torch.ops.aten.reflection_pad2d,
torch.ops.aten.reflection_pad3d,
torch.ops.aten.replication_pad1d,
torch.ops.aten.replication_pad2d,
torch.ops.aten.replication_pad3d,
torch.ops.aten.addmm,
])
)

torch._decomp.remove_decompositions(
global_registry.decompositions,
[
torch.ops.aten.roll,
],
)


def lookup(op):
Expand All @@ -91,7 +63,3 @@ def inner(lowering: Callable[[context.LoweringContext, ...], Any]):
return lowering

return inner


def decompositions():
return global_registry.decompositions

0 comments on commit 93d5756

Please sign in to comment.