From 548786af58d3e549d2a0ed0ce27122a463c949e0 Mon Sep 17 00:00:00 2001 From: Bhuminjay Soni Date: Fri, 3 Jan 2025 11:35:34 +0530 Subject: [PATCH] [JAX FE]: add support for jax.lax.logistic (#28240) **Overview:** This PR fixes #26576. **Testing:** - Tested the Updated code - Verified that other functionalities remain unaffected ![Screenshot from 2025-01-01 13-11-04](https://github.com/user-attachments/assets/5acfabc2-dded-4c65-b408-d4174fa3c025) **Dependencies:** - No dependencies on other pull requests **CC:** - @rkazants --------- Signed-off-by: 11happy Co-authored-by: Roman Kazantsev --- src/frontends/jax/src/op_table.cpp | 2 ++ tests/layer_tests/jax_tests/test_logistic.py | 37 ++++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100644 tests/layer_tests/jax_tests/test_logistic.py diff --git a/src/frontends/jax/src/op_table.cpp b/src/frontends/jax/src/op_table.cpp index 9c492dfa3e119d..6ae0e6adc7c469 100644 --- a/src/frontends/jax/src/op_table.cpp +++ b/src/frontends/jax/src/op_table.cpp @@ -19,6 +19,7 @@ #include "openvino/op/not_equal.hpp" #include "openvino/op/reduce_max.hpp" #include "openvino/op/reduce_sum.hpp" +#include "openvino/op/sigmoid.hpp" #include "openvino/op/sqrt.hpp" #include "openvino/op/subtract.hpp" #include "openvino/op/tanh.hpp" @@ -94,6 +95,7 @@ const std::map get_supported_ops_jaxpr() { {"rsqrt", op::translate_rsqrt}, {"reshape", op::translate_reshape}, {"select_n", op::translate_select_n}, + {"logistic", op::translate_1to1_match_1_input}, {"slice", op::translate_slice}, {"square", op::translate_square}, {"sqrt", op::translate_1to1_match_1_input}, diff --git a/tests/layer_tests/jax_tests/test_logistic.py b/tests/layer_tests/jax_tests/test_logistic.py new file mode 100644 index 00000000000000..19e5f0a81eba30 --- /dev/null +++ b/tests/layer_tests/jax_tests/test_logistic.py @@ -0,0 +1,37 @@ +# Copyright (C) 2018-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import jax +import numpy as np +import pytest +from jax import numpy as jnp + +from jax_layer_test_class import JaxLayerTest + +rng = np.random.default_rng(5402) + + +class TestLogistic(JaxLayerTest): + def _prepare_input(self): + + input = jnp.array(np.random.uniform(-1000, 1000, self.input_shape).astype(self.input_type)) + return [input] + + def create_model(self, input_shape, input_type): + self.input_shape = input_shape + self.input_type = input_type + + def jax_logistic(input): + return jax.lax.logistic(input) + + return jax_logistic, None, 'logistic' + + @pytest.mark.parametrize("input_shape", [[2], [3, 4], [5,6,7]]) + @pytest.mark.parametrize("input_type", [np.float32, np.float64]) + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.precommit_jax_fe + def test_logistic(self, ie_device, precision, ir_version, input_shape, input_type): + self._test(*self.create_model(input_shape, input_type), + ie_device, precision, + ir_version)