From 79c15227ba17c9a49a7be5bd54b4e5d162fd2a47 Mon Sep 17 00:00:00 2001 From: Kristijan Mitrovic Date: Mon, 23 Dec 2024 13:15:08 +0000 Subject: [PATCH] Try this out --- tests/jax/graphs/test_example_graph.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/jax/graphs/test_example_graph.py b/tests/jax/graphs/test_example_graph.py index b76d897..0ec523f 100644 --- a/tests/jax/graphs/test_example_graph.py +++ b/tests/jax/graphs/test_example_graph.py @@ -4,7 +4,7 @@ import jax import pytest -from infra import run_graph_test_with_random_inputs +from infra import ComparisonConfig, run_graph_test_with_random_inputs from jax import numpy as jnp @@ -24,4 +24,9 @@ def example_graph(x: jax.Array, y: jax.Array) -> jax.Array: ], ) def test_example_graph(x_shape: tuple, y_shape: tuple): - run_graph_test_with_random_inputs(example_graph, [x_shape, y_shape]) + comparison_config = ComparisonConfig() + comparison_config.atol.disable() + + run_graph_test_with_random_inputs( + example_graph, [x_shape, y_shape], comparison_config + )