Skip to content

Commit

Permalink
Add direct lowering for aten.floor
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 691840280
  • Loading branch information
junjiang-lab authored and copybara-github committed Oct 31, 2024
1 parent 27438c0 commit 284fe62
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 4 deletions.
4 changes: 1 addition & 3 deletions ai_edge_torch/generative/layers/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
zero_centered_gamma=config.zero_centered,
)
elif config.type == cfg.NormalizationType.LAYER_NORM:
return normalization.LayerNorm(
dim, config.epsilon, config.enable_hlfb
)
return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
elif config.type == cfg.NormalizationType.GROUP_NORM:
return normalization.GroupNorm(
config.group_num, dim, config.epsilon, config.enable_hlfb
Expand Down
7 changes: 7 additions & 0 deletions ai_edge_torch/odml_torch/lowerings/_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,13 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
return stablehlo.divide(x, y)


# https://pytorch.org/docs/stable/generated/torch.floor.html
# https://openxla.org/stablehlo/spec#floor
@lower(torch.ops.aten.floor)
def _aten_floor(lctx, x: ir.Value, *, out=None) -> ir.Value:
  """Lowers `aten.floor` directly to the StableHLO `floor` op.

  Registered via the `lower` decorator so the converter dispatches
  `torch.ops.aten.floor` here instead of the torch-xla2 fallback path.

  Args:
    lctx: The lowering context (unused by this op).
    x: The input tensor value; element-wise floor is computed on it.
    out: Present only to mirror the aten.floor schema; ignored here —
      the result is returned, not written into `out`.

  Returns:
    An ir.Value holding stablehlo.floor(x).
  """
  return stablehlo.floor(x)


# Schema:
# - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
# start=None, SymInt? end=None, SymInt step=1) -> Tensor
Expand Down
1 change: 0 additions & 1 deletion ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def lower_by_torch_xla2(op):
lower_by_torch_xla2(torch.ops.aten.expm1)
lower_by_torch_xla2(torch.ops.aten.fill)
lower_by_torch_xla2(torch.ops.aten.flip)
lower_by_torch_xla2(torch.ops.aten.floor)
lower_by_torch_xla2(torch.ops.aten.fmod)
lower_by_torch_xla2(torch.ops.aten.full)
lower_by_torch_xla2(torch.ops.aten.full_like)
Expand Down
3 changes: 3 additions & 0 deletions ai_edge_torch/odml_torch/test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ def _run_export_and_compare(
# ("aten_fill_Scalar_0", torch.ops.aten.fill.Scalar, (rnd(torch.float32, (10, 10)), 0.123,), dict()),
("aten_flip_0", torch.ops.aten.flip, (rnd(torch.float32, (10, 10)), [0, 1],), dict()),
("aten_floor_0", torch.ops.aten.floor, (rnd(torch.float32, (10, 10)),), dict()),
("aten_floor_1", torch.ops.aten.floor, (rnd(torch.float32, (10, 10), -10, 0),), dict()),
("aten_floor_2", torch.ops.aten.floor, (rnd(torch.float32, (10, 10), 0, 10),), dict()),
("aten_floor_3", torch.ops.aten.floor, (rnd(torch.float32, (10, 10), -100, 100),), dict()),
("aten_fmod_Scalar_0", torch.ops.aten.fmod.Scalar, (rnd(torch.float32, (10, 10)), 0.123,), dict()),
("aten_fmod_Tensor_0", torch.ops.aten.fmod.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
("aten_full_0", torch.ops.aten.full, ((5, 5), 0.123,), dict()),
Expand Down

0 comments on commit 284fe62

Please sign in to comment.