Skip to content

Commit

Permalink
[JAX FE]: add support for jax.lax.logistic (openvinotoolkit#28240)
Browse files Browse the repository at this point in the history
**Overview:**
This PR fixes openvinotoolkit#26576.

**Testing:**
- Tested the Updated code
- Verified that other functionalities remain unaffected

![Screenshot from 2025-01-01
13-11-04](https://github.com/user-attachments/assets/5acfabc2-dded-4c65-b408-d4174fa3c025)

**Dependencies:**
- No dependencies on other pull requests

**CC:**
- @rkazants

---------

Signed-off-by: 11happy <[email protected]>
Co-authored-by: Roman Kazantsev <[email protected]>
  • Loading branch information
11happy and rkazants authored Jan 3, 2025
1 parent 2e24dfa commit 548786a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/frontends/jax/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "openvino/op/not_equal.hpp"
#include "openvino/op/reduce_max.hpp"
#include "openvino/op/reduce_sum.hpp"
#include "openvino/op/sigmoid.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/tanh.hpp"
Expand Down Expand Up @@ -94,6 +95,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
{"rsqrt", op::translate_rsqrt},
{"reshape", op::translate_reshape},
{"select_n", op::translate_select_n},
{"logistic", op::translate_1to1_match_1_input<v0::Sigmoid>},
{"slice", op::translate_slice},
{"square", op::translate_square},
{"sqrt", op::translate_1to1_match_1_input<v0::Sqrt>},
Expand Down
37 changes: 37 additions & 0 deletions tests/layer_tests/jax_tests/test_logistic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import jax
import numpy as np
import pytest
from jax import numpy as jnp

from jax_layer_test_class import JaxLayerTest

rng = np.random.default_rng(5402)


class TestLogistic(JaxLayerTest):
def _prepare_input(self):

input = jnp.array(np.random.uniform(-1000, 1000, self.input_shape).astype(self.input_type))
return [input]

def create_model(self, input_shape, input_type):
self.input_shape = input_shape
self.input_type = input_type

def jax_logistic(input):
return jax.lax.logistic(input)

return jax_logistic, None, 'logistic'

@pytest.mark.parametrize("input_shape", [[2], [3, 4], [5,6,7]])
@pytest.mark.parametrize("input_type", [np.float32, np.float64])
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_jax_fe
def test_logistic(self, ie_device, precision, ir_version, input_shape, input_type):
self._test(*self.create_model(input_shape, input_type),
ie_device, precision,
ir_version)

0 comments on commit 548786a

Please sign in to comment.