Skip to content

Commit

Permalink
Try this out
Browse files Browse the repository at this point in the history
  • Loading branch information
kmitrovicTT committed Dec 23, 2024
1 parent 340df89 commit 79c1522
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tests/jax/graphs/test_example_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import jax
import pytest
from infra import run_graph_test_with_random_inputs
from infra import ComparisonConfig, run_graph_test_with_random_inputs
from jax import numpy as jnp


Expand All @@ -24,4 +24,9 @@ def example_graph(x: jax.Array, y: jax.Array) -> jax.Array:
],
)
def test_example_graph(x_shape: tuple, y_shape: tuple):
run_graph_test_with_random_inputs(example_graph, [x_shape, y_shape])
comparison_config = ComparisonConfig()
comparison_config.atol.disable()

run_graph_test_with_random_inputs(

Check failure on line 30 in tests/jax/graphs/test_example_graph.py

View workflow job for this annotation

GitHub Actions / TT-XLA Tests

test_example_graph.test_example_graph[x_shape0-y_shape0]

AssertionError: PCC comparison failed. Calculated pcc=-0.0005407017888501287
Raw output
device_output = Array([[1.0001237, 1.0001053, 1.0001491],
       [1.0001608, 1.0001174, 1.0001533],
       [1.0000875, 1.0000967, 1.0001339]], dtype=float32)
golden_output = Array([[1.6487212, 1.6487212, 1.6487212],
       [1.6487212, 1.6487212, 1.6487212],
       [1.6487212, 1.6487212, 1.6487212]], dtype=float32)
pcc_config = PccConfig(enabled=True, required_pcc=0.99)

    @run_on_cpu
    def compare_pcc(
        device_output: Tensor, golden_output: Tensor, pcc_config: PccConfig
    ) -> None:
        assert isinstance(device_output, jax.Array) and isinstance(
            golden_output, jax.Array
        ), f"Currently only jax.Array is supported"
    
        # If tensors are really close, pcc will be nan. Handle that before calculating pcc.
        try:
>           compare_allclose(
                device_output, golden_output, AllcloseConfig(rtol=1e-2, atol=1e-2)
            )

tests/infra/comparison.py:108: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/infra/device_runner.py:147: in wrapper
    return DeviceRunner.run_on_cpu(workload)
tests/infra/device_runner.py:26: in run_on_cpu
    return DeviceRunner._run_on_device(DeviceType.CPU, workload)
tests/infra/device_runner.py:72: in _run_on_device
    return device_workload.execute()
tests/infra/utils.py:27: in execute
    return self.executable(*self.args, **self.kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

device_output = Array([[1.0001237, 1.0001053, 1.0001491],
       [1.0001608, 1.0001174, 1.0001533],
       [1.0000875, 1.0000967, 1.0001339]], dtype=float32)
golden_output = Array([[1.6487212, 1.6487212, 1.6487212],
       [1.6487212, 1.6487212, 1.6487212],
       [1.6487212, 1.6487212, 1.6487212]], dtype=float32)
allclose_config = AllcloseConfig(enabled=True, rtol=0.01, atol=0.01)

    @run_on_cpu
    def compare_allclose(
        device_output: Tensor, golden_output: Tensor, allclose_config: AllcloseConfig
    ) -> None:
        assert isinstance(device_output, jax.Array) and isinstance(
            golden_output, jax.Array
        ), f"Currently only jax.Array is supported"
    
        allclose = jnp.allclose(
            device_output,
            golden_output,
            rtol=allclose_config.rtol,
            atol=allclose_config.atol,
        )
    
>       assert allclose, f"Allclose comparison failed."
E       AssertionError: Allclose comparison failed.

tests/infra/comparison.py:135: AssertionError

During handling of the above exception, another exception occurred:

x_shape = (3, 3), y_shape = (3, 3)

    @pytest.mark.parametrize(
        ["x_shape", "y_shape"],
        [
            [(3, 3), (3, 3)],
            [(32, 32), (32, 32)],
            [(64, 64), (64, 64)],
        ],
    )
    def test_example_graph(x_shape: tuple, y_shape: tuple):
        comparison_config = ComparisonConfig()
        comparison_config.atol.disable()
    
>       run_graph_test_with_random_inputs(
            example_graph, [x_shape, y_shape], comparison_config
        )

tests/jax/graphs/test_example_graph.py:30: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/infra/graph_tester.py:48: in run_graph_test_with_random_inputs
    tester.test_with_random_inputs(graph, input_shapes)
tests/infra/op_tester.py:45: in test_with_random_inputs
    self.test(workload)
tests/infra/op_tester.py:35: in test
    self._compare(tt_res, cpu_res)
tests/infra/base_tester.py:57: in _compare
    compare_pcc(device_output, golden_output, self._comparison_config.pcc)
tests/infra/device_runner.py:147: in wrapper
    return DeviceRunner.run_on_cpu(workload)
tests/infra/device_runner.py:26: in run_on_cpu
    return DeviceRunner._run_on_device(DeviceType.CPU, workload)
tests/infra/device_runner.py:72: in _run_on_device
    return device_workload.execute()
tests/infra/utils.py:27: in execute
    return self.executable(*self.args, **self.kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

device_output = Array([[1.0001237, 1.0001053, 1.0001491],
       [1.0001608, 1.0001174, 1.0001533],
       [1.0000875, 1.0000967, 1.0001339]], dtype=float32)
golden_output = Array([[1.6487212, 1.6487212, 1.6487212],
       [1.6487212, 1.6487212, 1.6487212],
       [1.6487212, 1.6487212, 1.6487212]], dtype=float32)
pcc_config = PccConfig(enabled=True, required_pcc=0.99)

    @run_on_cpu
    def compare_pcc(
        device_output: Tensor, golden_output: Tensor, pcc_config: PccConfig
    ) -> None:
        assert isinstance(device_output, jax.Array) and isinstance(
            golden_output, jax.Array
        ), f"Currently only jax.Array is supported"
    
        # If tensors are really close, pcc will be nan. Handle that before calculating pcc.
        try:
            compare_allclose(
                device_output, golden_output, AllcloseConfig(rtol=1e-2, atol=1e-2)
            )
        except AssertionError:
            pcc = jnp.corrcoef(device_output.flatten(), golden_output.flatten())
            pcc = jnp.min(pcc)
    
            assert (
>               pcc >= pcc_config.required_pcc
            ), f"PCC comparison failed. Calculated pcc={pcc}"
E           AssertionError: PCC comparison failed. Calculated pcc=-0.0005407017888501287

tests/infra/comparison.py:116: AssertionError

Check failure on line 30 in tests/jax/graphs/test_example_graph.py

View workflow job for this annotation

GitHub Actions / TT-XLA Tests

test_example_graph.test_example_graph[x_shape2-y_shape2]

AssertionError: PCC comparison failed. Calculated pcc=nan
Raw output
device_output = Array([[1.6390116, 1.6390126, 1.6390126, ..., 1.0001031, 1.0001142,
        1.0001161],
       [1.6390125, 1.6390126, ...    1.0001127],
       [1.0000914, 1.000103 , 1.0001153, ..., 1.0001247, 1.0001589,
        1.0000964]], dtype=float32)
golden_output = Array([[1.6487212, 1.6487212, 1.6487212, ..., 1.6487212, 1.6487212,
        1.6487212],
       [1.6487212, 1.6487212, ...    1.6487212],
       [1.6487212, 1.6487212, 1.6487212, ..., 1.6487212, 1.6487212,
        1.6487212]], dtype=float32)
pcc_config = PccConfig(enabled=True, required_pcc=0.99)

    @run_on_cpu
    def compare_pcc(
        device_output: Tensor, golden_output: Tensor, pcc_config: PccConfig
    ) -> None:
        assert isinstance(device_output, jax.Array) and isinstance(
            golden_output, jax.Array
        ), f"Currently only jax.Array is supported"
    
        # If tensors are really close, pcc will be nan. Handle that before calculating pcc.
        try:
>           compare_allclose(
                device_output, golden_output, AllcloseConfig(rtol=1e-2, atol=1e-2)
            )

tests/infra/comparison.py:108: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/infra/device_runner.py:147: in wrapper
    return DeviceRunner.run_on_cpu(workload)
tests/infra/device_runner.py:26: in run_on_cpu
    return DeviceRunner._run_on_device(DeviceType.CPU, workload)
tests/infra/device_runner.py:72: in _run_on_device
    return device_workload.execute()
tests/infra/utils.py:27: in execute
    return self.executable(*self.args, **self.kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

device_output = Array([[1.6390116, 1.6390126, 1.6390126, ..., 1.0001031, 1.0001142,
        1.0001161],
       [1.6390125, 1.6390126, ...    1.0001127],
       [1.0000914, 1.000103 , 1.0001153, ..., 1.0001247, 1.0001589,
        1.0000964]], dtype=float32)
golden_output = Array([[1.6487212, 1.6487212, 1.6487212, ..., 1.6487212, 1.6487212,
        1.6487212],
       [1.6487212, 1.6487212, ...    1.6487212],
       [1.6487212, 1.6487212, 1.6487212, ..., 1.6487212, 1.6487212,
        1.6487212]], dtype=float32)
allclose_config = AllcloseConfig(enabled=True, rtol=0.01, atol=0.01)

    @run_on_cpu
    def compare_allclose(
        device_output: Tensor, golden_output: Tensor, allclose_config: AllcloseConfig
    ) -> None:
        assert isinstance(device_output, jax.Array) and isinstance(
            golden_output, jax.Array
        ), f"Currently only jax.Array is supported"
    
        allclose = jnp.allclose(
            device_output,
            golden_output,
            rtol=allclose_config.rtol,
            atol=allclose_config.atol,
        )
    
>       assert allclose, f"Allclose comparison failed."
E       AssertionError: Allclose comparison failed.

tests/infra/comparison.py:135: AssertionError

During handling of the above exception, another exception occurred:

x_shape = (64, 64), y_shape = (64, 64)

    @pytest.mark.parametrize(
        ["x_shape", "y_shape"],
        [
            [(3, 3), (3, 3)],
            [(32, 32), (32, 32)],
            [(64, 64), (64, 64)],
        ],
    )
    def test_example_graph(x_shape: tuple, y_shape: tuple):
        comparison_config = ComparisonConfig()
        comparison_config.atol.disable()
    
>       run_graph_test_with_random_inputs(
            example_graph, [x_shape, y_shape], comparison_config
        )

tests/jax/graphs/test_example_graph.py:30: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/infra/graph_tester.py:48: in run_graph_test_with_random_inputs
    tester.test_with_random_inputs(graph, input_shapes)
tests/infra/op_tester.py:45: in test_with_random_inputs
    self.test(workload)
tests/infra/op_tester.py:35: in test
    self._compare(tt_res, cpu_res)
tests/infra/base_tester.py:57: in _compare
    compare_pcc(device_output, golden_output, self._comparison_config.pcc)
tests/infra/device_runner.py:147: in wrapper
    return DeviceRunner.run_on_cpu(workload)
tests/infra/device_runner.py:26: in run_on_cpu
    return DeviceRunner._run_on_device(DeviceType.CPU, workload)
tests/infra/device_runner.py:72: in _run_on_device
    return device_workload.execute()
tests/infra/utils.py:27: in execute
    return self.executable(*self.args, **self.kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

device_output = Array([[1.6390116, 1.6390126, 1.6390126, ..., 1.0001031, 1.0001142,
        1.0001161],
       [1.6390125, 1.6390126, ...    1.0001127],
       [1.0000914, 1.000103 , 1.0001153, ..., 1.0001247, 1.0001589,
        1.0000964]], dtype=float32)
golden_output = Array([[1.6487212, 1.6487212, 1.6487212, ..., 1.6487212, 1.6487212,
        1.6487212],
       [1.6487212, 1.6487212, ...    1.6487212],
       [1.6487212, 1.6487212, 1.6487212, ..., 1.6487212, 1.6487212,
        1.6487212]], dtype=float32)
pcc_config = PccConfig(enabled=True, required_pcc=0.99)

    @run_on_cpu
    def compare_pcc(
        device_output: Tensor, golden_output: Tensor, pcc_config: PccConfig
    ) -> None:
        assert isinstance(device_output, jax.Array) and isinstance(
            golden_output, jax.Array
        ), f"Currently only jax.Array is supported"
    
        # If tensors are really close, pcc will be nan. Handle that before calculating pcc.
        try:
            compare_allclose(
                device_output, golden_output, AllcloseConfig(rtol=1e-2, atol=1e-2)
            )
        except AssertionError:
            pcc = jnp.corrcoef(device_output.flatten(), golden_output.flatten())
            pcc = jnp.min(pcc)
    
            assert (
>               pcc >= pcc_config.required_pcc
            ), f"PCC comparison failed. Calculated pcc={pcc}"
E           AssertionError: PCC comparison failed. Calculated pcc=nan

tests/infra/comparison.py:116: AssertionError
example_graph, [x_shape, y_shape], comparison_config
)

0 comments on commit 79c1522

Please sign in to comment.