Skip to content

Commit

Permalink
fix slice_scatter lowering
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675699074
  • Loading branch information
chunnienc authored and copybara-github committed Sep 17, 2024
1 parent 96f6b84 commit 47e20da
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
18 changes: 14 additions & 4 deletions ai_edge_torch/odml_torch/lowerings/_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,25 +212,35 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
# - https://github.com/pytorch/pytorch/blob/18f9331e5deb4c02ae5c206e133a9b4add49bd97/aten/src/ATen/native/TensorShape.cpp#L4002
@lower(torch.ops.aten.slice_scatter)
def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
start = start or 0
end = end or self.type.shape[dim]
start = start if start is not None else 0
end = end if end is not None else self.type.shape[dim]

start, end = np.clip(
[start, end], -self.type.shape[dim], self.type.shape[dim]
)

if start < 0:
start = self.type.shape[dim] + start
if end < 0:
end = self.type.shape[dim] + end

end = start + step * math.ceil((end - start) / step) - (step - 1)
if end <= start or np.prod(src.type.shape) == 0:
return self

end = start + step * math.ceil((end - start) / step) - (step - 1)
padding_low = start
padding_high = self.type.shape[dim] - end
interior_padding = step - 1

rank = len(self.type.shape)
src = stablehlo.pad(
src,
utils.splat(0, src.type.element_type, []),
edge_padding_low=[padding_low if i == dim else 0 for i in range(rank)],
edge_padding_high=[padding_high if i == dim else 0 for i in range(rank)],
interior_padding=[step - 1 if i == dim else 0 for i in range(rank)],
interior_padding=[
interior_padding if i == dim else 0 for i in range(rank)
],
)
pred = np.ones(self.type.shape, dtype=np.bool_)
pred[*[
Expand Down
1 change: 1 addition & 0 deletions ai_edge_torch/odml_torch/test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def _run_export_and_compare(
("aten_slice_scatter_3", torch.ops.aten.slice_scatter, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (3, 10)), 0, 2, 8, 2), dict()),
("aten_slice_scatter_4", torch.ops.aten.slice_scatter, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (2, 10)), 0, 2, 7, 3), dict()),
("aten_slice_scatter_5", torch.ops.aten.slice_scatter, (rnd(torch.float32, (0, 10)), rnd(torch.float32, (0, 3)), 1, 0, -7), dict()),
("aten_slice_scatter_6", torch.ops.aten.slice_scatter, (rnd(torch.float32, (8, 3, 3)), rnd(torch.float32, (0, 3, 3)), 0, -8, 0), dict()),
("aten_slice_Tensor_0", torch.ops.aten.slice.Tensor, (rnd(torch.float32, (10, 10)), 1,), dict()),
("aten__softmax_0", torch.ops.aten._softmax, (rnd(torch.float32, (10, 10)), 1, False,), dict()),
("aten_split_copy_Tensor_0", torch.ops.aten.split_copy.Tensor, (rnd(torch.float32, (10, 10)), 2,), dict()),
Expand Down

0 comments on commit 47e20da

Please sign in to comment.