Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JAX]: Support jax.lax.select_n operation for JAX #28025

Merged
merged 14 commits into from
Dec 31, 2024
Merged
48 changes: 48 additions & 0 deletions src/frontends/jax/src/op/select_n.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/jax/node_context.hpp";
11happy marked this conversation as resolved.
Show resolved Hide resolved
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/gather_elements.hpp"

#include "utils.hpp";

using namespace ov::op;

namespace ov {
namespace frontend {
namespace jax {
namespace op {

OutputVector translate_select_n(const NodeContext& context) {
num_inputs_check(context, 2);
auto num_inputs = static_cast<int>(context.get_input_size());
Output<Node> which = context.get_input(0);
if (which.get_element_type() == element::boolean) {
which = std::make_shared<v0::Convert>(which, element::i32);
}
OutputVector cases_vector(num_inputs - 1);
for(int ind = 1; ind < num_inputs; ++ind) {
cases_vector[ind - 1] = context.get_input(ind);
}

Output<Node> cases = std::make_shared<v0::Concat>(cases_vector, 0);
auto which_shape = which.get_shape();
rkazants marked this conversation as resolved.
Show resolved Hide resolved
std::vector<int64_t> cases_reshape_shape = {num_inputs-1,which_shape[0]};
std::vector<int64_t> which_reshape_shape = {1,which_shape[0]};

cases = std::make_shared<v1::Reshape>(cases, ov::op::v0::Constant::create(element::i64, Shape{2}, cases_reshape_shape), false);
which = std::make_shared<v1::Reshape>(which, ov::op::v0::Constant::create(element::i64, Shape{2}, which_reshape_shape), false);
Output<Node> result = std::make_shared<v6::GatherElements>(cases, which, 0);
return {result};

};

} // namespace op
} // namespace jax
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/jax/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ OP_CONVERTER(translate_reduce_window_max);
OP_CONVERTER(translate_reduce_window_sum);
OP_CONVERTER(translate_reshape);
OP_CONVERTER(translate_rsqrt);
OP_CONVERTER(translate_select_n);
OP_CONVERTER(translate_slice);
OP_CONVERTER(translate_squeeze);
OP_CONVERTER(translate_transpose);
Expand Down Expand Up @@ -91,6 +92,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
{"transpose", op::translate_transpose},
{"rsqrt", op::translate_rsqrt},
{"reshape", op::translate_reshape},
{"select_n", op::translate_select_n},
{"slice", op::translate_slice},
{"sqrt", op::translate_1to1_match_1_input<v0::Sqrt>},
{"squeeze", op::translate_squeeze},
Expand Down
45 changes: 45 additions & 0 deletions tests/layer_tests/jax_tests/test_select_n.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import jax
11happy marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np
import pytest
from jax import numpy as jnp

from jax_layer_test_class import JaxLayerTest

rng = np.random.default_rng(5402)


class TestSelectN(JaxLayerTest):
def _prepare_input(self):
self.case_num = 2 if (self.input_type == np.bool or self.input_type == bool) else self.case_num
cases = []
which = rng.uniform(0,self.case_num, self.input_shape).astype(self.input_type)
which = np.array(which)
for i in range(self.case_num):
cases.append(jnp.array(rng.uniform(i*10, (i+1)*10, self.input_shape).astype(self.input_type)))
rkazants marked this conversation as resolved.
Show resolved Hide resolved
cases = np.array(cases)
return (which, cases)

def create_model(self, input_shape, input_type, case_num):
self.input_shape = input_shape
self.input_type = input_type
self.case_num = case_num

def jax_select_n(which, cases):
return jax.lax.select_n(which, *cases)

return jax_select_n, None, 'select_n'


@pytest.mark.parametrize("input_shape", [1,2,3,4,5,6,7,8,9,10])
rkazants marked this conversation as resolved.
Show resolved Hide resolved
rkazants marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize("input_type", [np.int32, np.int64, bool])
rkazants marked this conversation as resolved.
Show resolved Hide resolved
rkazants marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize("case_num", [1,2,3,4,5,6,7,8,9,10])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let use case_num starting from 2. Value 1 is impractical.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_jax_fe
def test_select_n(self, ie_device, precision, ir_version, input_shape, input_type, case_num):
self._test(*self.create_model(input_shape, input_type, case_num),
ie_device, precision,
ir_version)
Loading