Skip to content

Commit

Permalink
fix hf distilbert conversion
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713052586
  • Loading branch information
chunnienc authored and copybara-github committed Jan 7, 2025
1 parent b91c2e9 commit b7b5eb2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
10 changes: 6 additions & 4 deletions ai_edge_torch/odml_torch/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,9 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
def in_i32(x: int):
return -2147483648 <= x <= 2147483647

def to_int32(x: torch.Tensor):
return torch.ops.aten._to_copy.default(x, dtype=torch.int32)

def rewrite_arange(node: torch.fx.Node):
tensor_meta = node.meta.get("tensor_meta", None)
if not tensor_meta:
Expand All @@ -249,7 +252,7 @@ def rewrite_arange(node: torch.fx.Node):
if not (in_i32(start) and in_i32(end)):
return
op = node.target
node.target = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)
node.target = lambda *args, **kwargs: to_int32(op(*args, **kwargs))

graph_module = exported_program.graph_module
for node in graph_module.graph.nodes:
Expand Down Expand Up @@ -305,9 +308,8 @@ def exported_program_to_mlir(

_convert_i64_to_i32(exported_program)

exported_program = _torch_future.safe_run_decompositions(
exported_program, lowerings.decompositions()
)
# No decompositions but just retracing/cananicalization.
exported_program = _torch_future.safe_run_decompositions(exported_program, {})

# Passes below mutate the exported program to a state not executable by torch.
# Do not call run_decompositions after applying the passes.
Expand Down
4 changes: 4 additions & 0 deletions ai_edge_torch/odml_torch/lowerings/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def decompositions():
],
)

# Override noop aten op decompositions for faster run_decompositions.
decompositions[torch.ops.aten.alias.default] = lambda x: x
decompositions[torch.ops.aten.detach.default] = lambda x: x

# Override _safe_softmax decompositions with regular softmax.
# _safe_softmax introduces additional check-select ops to guard extreme
# input values to softmax, which could make the converted model inefficient
Expand Down

0 comments on commit b7b5eb2

Please sign in to comment.