Merge pull request #232 from balanprasanth:automation
PiperOrigin-RevId: 698536260
chunnienc committed Nov 21, 2024
2 parents 2e6e6b2 + b9022c3 commit e16cd9d
Showing 11 changed files with 271 additions and 10 deletions.
21 changes: 21 additions & 0 deletions .github/workflows/auto-assignment.yml
@@ -0,0 +1,21 @@
name: auto-assignment
on:
  issues:
    types:
      - opened

permissions:
  contents: read
  issues: write
  pull-requests: write

jobs:
  welcome:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/github-script@v7
        with:
          script: |
            const script = require('./.github/workflows/scripts/auto-assignment.js')
            script({github, context})
2 changes: 2 additions & 0 deletions .github/workflows/mark_stale.yml
@@ -30,6 +30,8 @@ jobs:
stale-issue-label: "status:stale"
close-issue-reason: completed
any-of-labels: "status:awaiting user response,status:more data needed"
# List of labels to remove when issues/PRs become unstale.
labels-to-remove-when-unstale: 'status:awaiting user response,status:stale'
stale-issue-message: >
  Marking this issue as stale since it has been open for 7 days with no activity.
  This issue will be closed if no further activity occurs.
40 changes: 40 additions & 0 deletions .github/workflows/scripts/auto-assignment.js
@@ -0,0 +1,40 @@
/**
 * Automatically assign issues and PRs to users in the `assigneesList`
 * on a rotating basis.
 *
 * @param {!object} github GitHub object that can call GitHub APIs using its
 *     built-in library functions.
 * @param {!object} context The context object, which contains issue and PR
 *     details.
 */

module.exports = async ({github, context}) => {
  let issueNumber;
  let assigneesList;
  // Is this an issue? If so, assign the issue number. Otherwise, assign the PR
  // number.
  if (context.payload.issue) {
    assigneesList = ['pkgoogle', 'gaikwadrahul8'];  // for issues
    issueNumber = context.payload.issue.number;
  } else {
    assigneesList = [];  // for PRs
    issueNumber = context.payload.number;
  }
  console.log('assignee list', assigneesList);
  console.log('entered auto assignment for this issue: ', issueNumber);
  if (!assigneesList.length) {
    console.log('No assignees found for this repo.');
    return;
  }
  let noOfAssignees = assigneesList.length;
  let selection = issueNumber % noOfAssignees;
  let assigneeForIssue = assigneesList[selection];

  console.log(
      'issue Number = ', issueNumber + ' , assigning to: ', assigneeForIssue);
  return github.rest.issues.addAssignees({
    issue_number: context.issue.number,
    owner: context.repo.owner,
    repo: context.repo.repo,
    assignees: [assigneeForIssue],
  });
};
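As a quick illustration of the rotation above (a minimal sketch, not taken from the commit; assignee names come from the script), each issue number maps onto the assignee list modulo its length:

# Illustrative sketch of the round-robin pick used by auto-assignment.js.
assignees = ['pkgoogle', 'gaikwadrahul8']
for issue_number in (230, 231, 232):
  # 230 -> index 0, 231 -> index 1, 232 -> index 0, and so on.
  print(issue_number, '->', assignees[issue_number % len(assignees)])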
107 changes: 107 additions & 0 deletions ai_edge_torch/generative/test/test_custom_dus.py
@@ -0,0 +1,107 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A suite of tests to validate the Dynamic Update Slice Custom Op."""

from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
import torch
from torch import nn

from absl.testing import absltest as googletest, parameterized


def updated_slice_matches(buffer, update, index):
  indexer = [slice(i, i + d) for i, d in zip(index, update.shape)]
  buf = buffer[indexer]
  return torch.allclose(buf, update)


def intT(x):
  return torch.tensor(x).int()


class DUSMod(nn.Module):

  def forward(self, buffer, update, index):
    out = dynamic_update_slice(buffer, update, index)
    out = out * 2
    return out


@googletest.skip('Enable this when odml_torch is default b/373387583')
class TestCustomDUS(parameterized.TestCase):

  @parameterized.named_parameters(
      (
          'DUS_whole_buffer',
          torch.randn(1, 1280, 4, 64),
          torch.randn([1, 1024, 4, 64]),
          [intT(0), intT(0), intT(0), intT(0)],
      ),
      (
          'DUS_kv_example',
          torch.randn(2, 1280, 4, 64),
          torch.randn([2, 1024, 4, 64]),
          [intT(0), intT(0), intT(0), intT(0)],
      ),
      (
          'DUS_3d',
          torch.randn(2, 256, 4, 64),
          torch.randn([2, 256, 2, 64]),
          [intT(0), intT(0), intT(2), intT(0)],
      ),
      (
          'DUS_3d_v2',
          torch.randn(2, 256, 4, 64),
          torch.randn([2, 256, 3, 64]),
          [intT(0), intT(0), intT(1), intT(0)],
      ),
      (
          'DUS_3d_v3',
          torch.randn(6, 8, 32),
          torch.randn([6, 3, 32]),
          [intT(0), intT(5), intT(0)],
      ),
      (
          'DUS_2d',
          torch.randn(8, 32),
          torch.randn([8, 12]),
          [intT(0), intT(20)],
      ),
  )
  def test_opcheck_dynamic_update_slice(self, buffer, update, indices):
    torch.library.opcheck(dynamic_update_slice, (buffer, update, indices))
    out = dynamic_update_slice(buffer, update, indices)
    self.assertTrue(updated_slice_matches(out, update, indices))

  def test_exported_program(self):
    buffer = torch.randn(1, 1280, 4, 64)
    update = torch.randn([1, 1024, 4, 64])
    index = [intT(0), intT(0), intT(0), intT(0)]
    dm = DUSMod()
    ep = torch.export.export(dm, (buffer, update, index))
    dus_in_exported_program = False
    for node in ep.graph.nodes:
      if node.op == 'call_function':
        if node.target.__name__.startswith('dynamic_update_slice'):
          dus_in_exported_program = True
          break

    self.assertTrue(dus_in_exported_program)


if __name__ == '__main__':
  googletest.main()
56 changes: 56 additions & 0 deletions ai_edge_torch/generative/utilities/dynamic_update_slice.py
@@ -0,0 +1,56 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Utilities defining the dynamic_update_slice custom op and its lowering.
from dataclasses import dataclass
import glob
import os
from typing import Sequence
from ai_edge_torch.odml_torch import lowerings
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo as stablehlo
import torch


# Use torch.library.custom_op to define a new custom operator.
# TODO: Update impl for multiple non-trivial start_indices
@torch.library.custom_op("ai_edge_torch::dynamic_update_slice", mutates_args=())
def dynamic_update_slice(
    in_tensor: torch.Tensor,
    update: torch.Tensor,
    start_indices: Sequence[torch.Tensor],
) -> torch.Tensor:
  compare_size = torch.tensor(in_tensor.size()) == torch.tensor(update.size())
  mismatch = torch.nonzero(~compare_size, as_tuple=False)
  dim = mismatch[0].item() if len(mismatch) > 0 else 0
  start = start_indices[dim].item()
  end = start + update.shape[dim]
  indices = torch.arange(start, end).to(torch.long)
  return in_tensor.index_copy(dim, indices, update)


# Use register_fake to add a ``FakeTensor`` kernel for the operator
@dynamic_update_slice.register_fake
def _(in_tensor, update, start_indices):
  return in_tensor.clone().detach()


@lowerings.lower(torch.ops.ai_edge_torch.dynamic_update_slice)
def _dynamic_update_slice_lower(
    lctx,
    in_tensor: ir.Value,
    update: ir.Value,
    start_indices: Sequence[ir.Value],
):
  return stablehlo.dynamic_update_slice(in_tensor, update, start_indices)
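A minimal usage sketch of the eager op defined above (mirroring the cases in test_custom_dus.py; the import path simply follows the new module's location):

import torch
from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice

buffer = torch.zeros(8, 32)
update = torch.ones(8, 12)
# Write `update` into `buffer` starting at column 20; dim 1 is the only
# dimension where the shapes differ, so start_indices[1] supplies the offset.
out = dynamic_update_slice(
    buffer, update, [torch.tensor(0).int(), torch.tensor(20).int()]
)
assert torch.allclose(out[:, 20:], update)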
1 change: 1 addition & 0 deletions ai_edge_torch/lowertools/odml_torch_utils.py
@@ -185,6 +185,7 @@ def merged_bundle_to_tfl_model(
converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
converter._experimental_enable_composite_direct_lowering = True
converter._experimental_enable_dynamic_update_slice = True
converter.model_origin_framework = "PYTORCH"

conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
36 changes: 34 additions & 2 deletions ai_edge_torch/odml_torch/jax_bridge/_wrap.py
@@ -24,6 +24,7 @@
import jax
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import hlo as stablehlo
import torch.utils._pytree as pytree

# Jax double (64bit) precision is required to generate StableHLO mlir with
@@ -143,8 +144,39 @@ def wrapped(lctx, *args, **kwargs):
    ir_inputs = []

    results = func.CallOp(cloned_func, ir_inputs).results

    if lctx.node is None:
      return results[0] if len(results) == 1 else results

    out_avals = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val")

    if out_avals is None:
      return results[0] if len(results) == 1 else results

    def sanitize_result_elty(result, aval):
      # The JAX implementation may not respect the aten op's output dtype. For
      # example, JAX may apply slightly different dtype upcast rules, which
      # leads to a different result dtype from the bridged lowering than from
      # the torch op's output. Here we add an additional `stablehlo.convert`
      # op when the dtypes do not match, to ensure the lowering's result dtype
      # is always the same as the torch op's output dtype.
      if aval is None:
        return result

      target_elty = export_utils.torch_dtype_to_ir_element_type(
          lctx.ir_context, aval.dtype
      )
      if result.type.element_type == target_elty:
        return result
      return stablehlo.convert(
          ir.RankedTensorType.get(result.type.shape, target_elty), result
      )

    if len(results) == 1:
      return results[0]
    return results
      return sanitize_result_elty(results[0], out_avals)
    return [
        sanitize_result_elty(result, aval)
        for result, aval in zip(results, out_avals)
    ]

  return wrapped
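The dtype clean-up above can be pictured with ordinary tensors (a toy sketch only; the real code emits a stablehlo.convert on MLIR values rather than casting torch tensors):

import torch

def sanitize_dtype(result: torch.Tensor, expected: torch.dtype) -> torch.Tensor:
  # Keep the result as-is when dtypes already agree; otherwise cast back to
  # the dtype the torch op declared.
  return result if result.dtype == expected else result.to(expected)

# e.g. a bridged lowering upcast to int64 while the aten op promises int32.
out = sanitize_dtype(torch.arange(4, dtype=torch.int64), torch.int32)
assert out.dtype == torch.int32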
8 changes: 4 additions & 4 deletions ai_edge_torch/odml_torch/lowerings/_basic.py
@@ -223,18 +223,18 @@ def _aten_floor(lctx, x: ir.Value, *, out=None) -> ir.Value:
def _aten_cat(lctx: LoweringContext, tensors, dim=0):
  assert tensors
  non_empty_tensors = [t for t in tensors if np.prod(t.type.shape) != 0]
  out_meta = lctx.node.meta["tensor_meta"]
  out_aval = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val")
  if not non_empty_tensors:
    return utils.splat(
        0,
        export_utils.torch_dtype_to_ir_element_type(
            lctx.ir_context, out_meta.dtype
            lctx.ir_context, out_aval.dtype
        ),
        out_meta.shape,
        out_aval.shape,
    )

  if dim < 0:
    dim = dim + len(out_meta.shape)
    dim = dim + len(out_aval.shape)
  dim = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), dim)

  return stablehlo.concatenate(non_empty_tensors, dim)
1 change: 0 additions & 1 deletion ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
@@ -171,7 +171,6 @@ def lower_by_torch_xla2(op):
lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
lower_by_torch_xla2(torch.ops.aten.native_group_norm)
lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
lower_by_torch_xla2(torch.ops.aten.ne)
lower_by_torch_xla2(torch.ops.aten.neg)
1 change: 1 addition & 0 deletions ai_edge_torch/odml_torch/lowerings/registry.py
@@ -61,6 +61,7 @@ def register(self, op, lowering):
torch.ops.aten._adaptive_avg_pool2d,
torch.ops.aten._adaptive_avg_pool3d,
torch.ops.aten.grid_sampler_2d,
torch.ops.aten.native_group_norm,
torch.ops.aten.native_dropout,
torch.ops.aten.reflection_pad1d,
torch.ops.aten.reflection_pad2d,
8 changes: 5 additions & 3 deletions ai_edge_torch/odml_torch/test/test_core_aten_ops.py
@@ -65,7 +65,7 @@ def forward(self, *export_args):
def rnd(dtype, shape, min_v=None, max_v=None):
"""Shortcut for creating a random torch tensor."""
if dtype in (torch.int32, torch.int64, torch.bool):
min_v = min_v if min_v else 0
min_v = min_v if min_v else 1
max_v = max_v if max_v else 10
return torch.randint(min_v, max_v, shape).to(dtype)
else:
@@ -137,7 +137,7 @@ def _run_export_and_compare(

@parameterized.named_parameters(
# fmt: off
# pyformat: disable
("aten_abs_0", torch.ops.aten.abs, (rnd(torch.float32, (10, 10)),), dict()),
("aten_acos_0", torch.ops.aten.acos, (rnd(torch.float32, (10, 10)),), dict()),
("aten_acosh_0", torch.ops.aten.acosh, (rnd(torch.float32, (10, 10)),), dict()),
@@ -230,7 +230,8 @@ def _run_export_and_compare(
("aten_div_Scalar_0", torch.ops.aten.div.Scalar, (rnd(torch.float32, (10, 10)), 0.5,), dict()),
("aten_div_Scalar_mode_0", torch.ops.aten.div.Scalar_mode, (rnd(torch.float32, (10, 10)), 0.123,), {"rounding_mode": "trunc"}),
("aten_div_Tensor_0", torch.ops.aten.div.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
("aten_div_Tensor_mode_0", torch.ops.aten.div.Tensor_mode, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), {"rounding_mode": "trunc"}),
("aten_div_Tensor_mode_trunc_0", torch.ops.aten.div.Tensor_mode, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), {"rounding_mode": "trunc"}),
("aten_div_Tensor_mode_trunc_1", torch.ops.aten.div.Tensor_mode, (rnd(torch.int32, (10, 10)), rnd(torch.int32, (10, 10)),), {"rounding_mode": "trunc"}),
("aten_embedding_0", torch.ops.aten.embedding, (rnd(torch.float32, (10, 10)), rnd(torch.int64, (10,)),), dict()),
("aten_eq_Scalar_2", torch.ops.aten.eq.Scalar, (rnd(torch.float32, (10, 10)), 1,), dict()),
("aten_eq_Tensor_0", torch.ops.aten.eq.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
@@ -301,6 +302,7 @@ def _run_export_and_compare(
("aten__native_batch_norm_legit_no_training_0", torch.ops.aten._native_batch_norm_legit_no_training, (rnd(torch.float32, (10, 10)), None, None, rnd(torch.float32, (10,)), rnd(torch.float32, (10,)), 1.0, 1.0,), dict()),
# ("aten_native_dropout_0", torch.ops.aten.native_dropout, (rnd(torch.float32, (10, 10)), 1.0, True,), dict()),
("aten_native_group_norm_0", torch.ops.aten.native_group_norm, (rnd(torch.float32, (1, 3, 2, 10)), None, None, 1, 3, 20, 1, 0.0,), dict()),
("aten_native_group_norm_1", torch.ops.aten.native_group_norm, (rnd(torch.float32, (1, 3, 2, 10)), rnd(torch.float32, (3,)), rnd(torch.float32, (3,)), 1, 3, 20, 1, 0.0,), dict()),
("aten_native_layer_norm_0", torch.ops.aten.native_layer_norm, (rnd(torch.float32, (1, 3, 2, 10)), [1, 3, 2, 10], None, None, 0.0,), dict()),
("aten_native_layer_norm_1", torch.ops.aten.native_layer_norm, (rnd(torch.float32, (1, 3, 2, 10)), [3, 2, 10], None, None, 0.0,), dict()),
("aten_native_layer_norm_2", torch.ops.aten.native_layer_norm, (rnd(torch.float32, (2, 3, 2, 10)), [2, 10], None, None, 0.0,), dict()),
