From 6f1e3e5b7bf85c527ab430d2dd5f9f74bb01be00 Mon Sep 17 00:00:00 2001
From: Kristijan Mitrovic
Date: Fri, 22 Nov 2024 17:42:04 +0000
Subject: [PATCH] Rerun tests to update their status. Updated every test to
 use parametrize.

---
 tests/TTIR/test_array_add.py         |  38 -----
 tests/TTIR/test_basic_ops.py         | 198 ++++++++++++++++-----------
 tests/TTIR/test_conv2d.py            |  32 +++--
 tests/TTIR/test_maxpool2d.py         |  10 +-
 tests/TTIR/test_mnist.py             |  42 +++---
 tests/TTIR/test_reshape.py           |   8 +-
 tests/TTIR/test_scalar_add.py        |  22 ---
 tests/TTIR/test_simple_regression.py |  19 +--
 8 files changed, 181 insertions(+), 188 deletions(-)
 delete mode 100644 tests/TTIR/test_array_add.py
 delete mode 100644 tests/TTIR/test_scalar_add.py

diff --git a/tests/TTIR/test_array_add.py b/tests/TTIR/test_array_add.py
deleted file mode 100644
index 72d815d..0000000
--- a/tests/TTIR/test_array_add.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
-#
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-import jax
-import jax.numpy as jnp
-
-from infrastructure import verify_module
-
-
-def test_2x2_array_add():
-    def module_add(a, b):
-        return a + b
-
-    verify_module(module_add, [(2, 2), (2, 2)])
-
-
-def test_3x2_array_add():
-    def module_add(a, b):
-        return a + b
-
-    verify_module(module_add, [(3, 2), (3, 2)])
-
-
-@pytest.mark.parametrize("rank", [1, 2, 3, 4, 5, 6])
-def test_module_add(rank):
-    def module_add(a, b):
-        c = a + a
-        d = b + b
-        return c + d
-
-    input_shape = []
-    for i in range(rank):
-        input_shape.insert(0, 32 if i < 2 else 1)
-
-    input_shape = tuple(input_shape)
-    verify_module(module_add, [input_shape, input_shape])
diff --git a/tests/TTIR/test_basic_ops.py b/tests/TTIR/test_basic_ops.py
index 37116f5..9c68b0b 100644
--- a/tests/TTIR/test_basic_ops.py
+++ b/tests/TTIR/test_basic_ops.py
@@ -1,7 +1,6 @@
 # SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
 #
 # SPDX-License-Identifier: Apache-2.0
-#
 
 import pytest
 import jax
@@ -11,31 +10,69 @@ from infrastructure import verify_module
 
 
-def test_abs_op():
+@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
+def test_abs_op(input_shapes):
     def module_abs(a):
         return jnp.abs(a)
 
-    verify_module(module_abs, [(3, 3)])
-    verify_module(module_abs, [(3, 3, 3)])
+    verify_module(module_abs, input_shapes)
+
+
+@pytest.mark.parametrize("input_shapes", [[(2, 2), (2, 2)], [(3, 2), (3, 2)]])
+def test_array_add(input_shapes):
+    def module_add(a, b):
+        return a + b
+
+    verify_module(module_add, input_shapes)
+
+
+@pytest.mark.parametrize("rank", [1, 2, 3, 4, 5, 6])
+def test_module_add(rank):
+    def module_add(a, b):
+        c = a + a
+        d = b + b
+        return c + d
+
+    input_shape = []
+    for i in range(rank):
+        input_shape.insert(0, 32 if i < 2 else 1)
+
+    input_shape = tuple(input_shape)
+    verify_module(module_add, [input_shape, input_shape])
+
+# TODO: this test might be outdated.
+def test_scalar_add():
+    def module_add(a, b):
+        return a + b
 
-# Broadcasted values are incorrect
-@pytest.mark.skip("Broadcasted values are incorrect")
-def test_broadcast_op():
+    a = jnp.float32(5.0)
+    b = jnp.array(6.0)
+    tt_graph = jax.jit(module_add)
+    res = tt_graph(a, b)
+    cpu_graph = jax.jit(module_add, backend="cpu")
+    res_cpu = cpu_graph(a, b)
+    assert jnp.allclose(res, res_cpu)
+
+
+@pytest.mark.parametrize("input_shapes", [[(2, 1)]])
+@pytest.mark.skip(
+    "Broadcasted values are incorrect. "
" + "Fails with: AssertionError: PCC is 0.37796446681022644 which is less than 0.99" +) +def test_broadcast_op(input_shapes): def module_broadcast(a): return jnp.broadcast_to(a, (2, 4)) - verify_module(module_broadcast, [(2, 1)]) + verify_module(module_broadcast, input_shapes) -def test_cbrt_op(): +@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]]) +def test_cbrt_op(input_shapes): def module_cbrt(a): return jax.lax.cbrt(a) - verify_module( - module_cbrt, [(3, 3)], required_atol=2e-2 - ) # ATOL is 0.010040640830993652 - verify_module(module_cbrt, [(3, 3, 3)], required_atol=2e-2) + verify_module(module_cbrt, input_shapes, required_atol=2e-2) def test_concat_op(): @@ -70,11 +107,9 @@ def module_concat_dim_3(x, y): ) # output shape: (32, 32, 32, 96) -# error: 'ttir.constant' op failed to verify that all of {value, result} have same shape -@pytest.mark.skip( - "Index is out of bounds for the rank, should be between 0 and 0 however is 18446744073709551615" -) -def test_constant_op(): +@pytest.mark.parametrize("input_shapes", [[(3, 3)]]) +@pytest.mark.skip("AssertionError: ATOL is 21574.4375 which is greater than 0.01") +def test_constant_op(input_shapes): def module_constant_zeros(a): zeros = jnp.zeros(a.shape) return zeros @@ -83,149 +118,158 @@ def module_constant_ones(a): ones = jnp.ones(a.shape) return ones + verify_module(module_constant_zeros, input_shapes) + verify_module(module_constant_ones, input_shapes) + + +@pytest.mark.parametrize("input_shapes", [[(3, 3)]]) +@pytest.mark.skip("Fails due to: error: failed to legalize operation 'ttir.constant'") +def test_constant_op_multi_dim(input_shapes): def module_constant_multi(a): multi = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32) return multi - verify_module(module_constant_zeros, [(3, 3)]) - verify_module(module_constant_ones, [(3, 3)]) - verify_module(module_constant_multi, [(3, 3)]) + verify_module(module_constant_multi, input_shapes) -def test_convert_op(): +@pytest.mark.parametrize("input_shapes", [[(2, 2)], [(4, 4, 4)]]) +def test_convert_op(input_shapes): def module_convert(a): return jax.lax.convert_element_type(a, jnp.bfloat16) - verify_module(module_convert, [(2, 2)]) - verify_module(module_convert, [(4, 4, 4)]) + verify_module(module_convert, input_shapes) -def test_div_op(): +@pytest.mark.parametrize( + ["input_shapes", "required_atol"], + [([(3, 3), (3, 3)], 0.01), ([(3, 3, 3), (3, 3, 3)], 35e-2)], +) +def test_div_op(input_shapes, required_atol): def module_div(a, b): return a / b - verify_module(module_div, [(3, 3), (3, 3)]) - verify_module(module_div, [(3, 3, 3), (3, 3, 3)], required_atol=35e-2) + verify_module(module_div, input_shapes, required_atol=required_atol) -def test_dot_general_op(): +@pytest.mark.parametrize( + "input_shapes", + [[(2, 1), (1, 2)], [(1, 2), (2, 1)]], +) +def test_dot_general_op(input_shapes): def module_dot_general(a, b): return jnp.dot(a, b) - verify_module(module_dot_general, [(2, 1), (1, 2)]) - verify_module(module_dot_general, [(1, 2), (2, 1)]) + verify_module(module_dot_general, input_shapes) # Exponential generate slightly different values, so using higher ATOL value. 
 # see tt-mlir issue https://github.com/tenstorrent/tt-mlir/issues/1199
-def test_exp_op():
+@pytest.mark.parametrize(
+    ["input_shapes", "required_atol"], [([(3, 3)], 20e-2), ([(3, 3, 3)], 25e-2)]
+)
+def test_exp_op(input_shapes, required_atol):
     def module_exp(a):
         return jnp.exp(a)
 
-    verify_module(module_exp, [(3, 3)], required_atol=20e-2)
-    verify_module(module_exp, [(3, 3, 3)], required_atol=25e-2)
+    verify_module(module_exp, input_shapes, required_atol=required_atol)
 
 
-def test_maximum_op():
+@pytest.mark.parametrize("input_shapes", [[(3, 3), (3, 3)], [(3, 3, 3), (3, 3, 3)]])
+def test_maximum_op(input_shapes):
     def module_maximum(a, b):
         return jnp.maximum(a, b)
 
-    verify_module(module_maximum, [(3, 3), (3, 3)])
-    verify_module(module_maximum, [(3, 3, 3), (3, 3, 3)])
+    verify_module(module_maximum, input_shapes)
 
 
-def test_multiply_op():
+@pytest.mark.parametrize("input_shapes", [[(3, 3), (3, 3)], [(3, 3, 3), (3, 3, 3)]])
+def test_multiply_op(input_shapes):
     def module_multiply(a, b):
         return a * b
 
-    verify_module(module_multiply, [(3, 3), (3, 3)])
-    verify_module(module_multiply, [(3, 3, 3), (3, 3, 3)])
+    verify_module(module_multiply, input_shapes)
 
 
-def test_negate_op():
+@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
+def test_negate_op(input_shapes):
     def module_negate(a):
         return -a
 
-    verify_module(module_negate, [(3, 3)])
-    verify_module(module_negate, [(3, 3, 3)])
+    verify_module(module_negate, input_shapes)
 
 
 # Reduce is failing due to error in constant.
+@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
 @pytest.mark.skip("keepdim=False is not supported")
-def test_reduce_op():
+def test_reduce_op(input_shapes):
     def module_reduce_max(a):
         return jnp.max(a)
 
     def module_reduce_sum(a):
         return jnp.sum(a)
 
-    verify_module(module_reduce_max, [(3, 3)])
-    verify_module(module_reduce_max, [(3, 3, 3)])
+    verify_module(module_reduce_max, input_shapes)
+    verify_module(module_reduce_sum, input_shapes)
 
-    verify_module(module_reduce_sum, [(3, 3)])
-    verify_module(module_reduce_sum, [(3, 3, 3)])
 
-
-def test_rsqrt_op():
+@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
+def test_rsqrt_op(input_shapes):
     def module_rsqrt(a):
         return jax.lax.rsqrt(a)
 
-    verify_module(module_rsqrt, [(3, 3)])
-    verify_module(module_rsqrt, [(3, 3, 3)])
+    verify_module(module_rsqrt, input_shapes)
 
 
 # Needs to have a bigger atol due to inaccuracies in the exp op on tt-metal
 # see tt-mlir issue https://github.com/tenstorrent/tt-mlir/issues/1199
-def test_expm1_op():
+@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
+def test_expm1_op(input_shapes):
     def module_expm1(a):
         return jax.lax.expm1(a)
 
-    verify_module(module_expm1, [(3, 3)], required_atol=20e-2)
-    verify_module(module_expm1, [(3, 3, 3)], required_atol=20e-2)
+    verify_module(module_expm1, input_shapes, required_atol=20e-2)
 
 
-def test_log1p_op():
+@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
+def test_log1p_op(input_shapes):
     def module_log1p(a):
         return jax.lax.log1p(a)
 
-    verify_module(module_log1p, [(3, 3)], required_atol=2e-2)
-    verify_module(module_log1p, [(3, 3, 3)], required_atol=2e-2)
+    verify_module(module_log1p, input_shapes, required_atol=2e-2)
 
 
-def test_sign_op():
+@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
+def test_sign_op(input_shapes):
    def module_sign(a):
        return jax.lax.sign(a)
 
-    verify_module(module_sign, [(3, 3)])
-    verify_module(module_sign, [(3, 3, 3)])
+    verify_module(module_sign, input_shapes)
 
 
-def test_sqrt_op():
+@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
+def test_sqrt_op(input_shapes):
     def module_sqrt(a):
         return jnp.sqrt(a)
 
-    verify_module(module_sqrt, [(3, 3)])
-    verify_module(module_sqrt, [(3, 3, 3)])
+    verify_module(module_sqrt, input_shapes)
 
 
-def test_sub_op():
+@pytest.mark.parametrize("input_shapes", [[(3, 3), (3, 3)], [(3, 3, 3), (3, 3, 3)]])
+def test_sub_op(input_shapes):
     def module_sub(a, b):
         return a - b
 
-    verify_module(module_sub, [(3, 3), (3, 3)])
-    verify_module(module_sub, [(3, 3, 3), (3, 3, 3)])
+    verify_module(module_sub, input_shapes)
 
 
-def test_transpose_op_2d():
+@pytest.mark.parametrize("input_shapes", [[(3, 3)], [(3, 3, 3)]])
+def test_transpose_op(input_shapes):
     def module_transpose(a):
         return jnp.transpose(a)
 
-    verify_module(module_transpose, [(3, 3)])
+    verify_module(module_transpose, input_shapes)
 
 
-@pytest.mark.skip(
-    "Scalars currently not working due to issue https://github.com/tenstorrent/tt-xla/issues/73"
-)
 def test_scalar_type():
     def module_scalar_type(a):
         return a.shape[0]
@@ -233,15 +277,6 @@ def module_scalar_type(a):
     verify_module(module_scalar_type, [(3, 3)])
 
 
-# Transpose op failing for higher ranks/dimensions.
-@pytest.mark.skip("Transpose op failing for higher ranks/dimensions.")
-def test_transpose_op_3d():
-    def module_transpose(a):
-        return jnp.transpose(a)
-
-    verify_module(module_transpose, [(3, 3, 3)])
-
-
 dim0_cases = []
 for begin in numpy.arange(10).tolist():
     for end in numpy.arange(90, 100).tolist():
@@ -264,9 +299,8 @@
 
 
 @pytest.mark.parametrize(
-    "begin, end, dim", [*dim2_cases, *dim3_cases, *dim0_cases, *dim1_cases]
+    ["begin", "end", "dim"], [*dim2_cases, *dim3_cases, *dim0_cases, *dim1_cases]
 )
-@pytest.mark.skip("Requires tt-metal uplift.")
 def test_slice(begin, end, dim):
     def module_slice(a):
         if dim == 0:
diff --git a/tests/TTIR/test_conv2d.py b/tests/TTIR/test_conv2d.py
index a03704a..12dbbc3 100644
--- a/tests/TTIR/test_conv2d.py
+++ b/tests/TTIR/test_conv2d.py
@@ -10,8 +10,19 @@
 
 
 @pytest.mark.parametrize(
-    "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, padding",
-    (
+    [
+        "batch_size",
+        "output_channels",
+        "input_channels",
+        "input_height",
+        "input_width",
+        "filter_height",
+        "filter_width",
+        "stride_h",
+        "stride_w",
+        "padding",
+    ],
+    [
         # RESNET
         (1, 64, 3, 224, 224, 7, 7, 2, 2, 3),
         (1, 256, 64, 56, 56, 1, 1, 1, 1, 0),
@@ -23,23 +34,26 @@
         (1, 128, 128, 28, 28, 3, 3, 1, 1, 1),
         (1, 512, 128, 28, 28, 1, 1, 1, 1, 0),
         (1, 128, 512, 28, 28, 1, 1, 1, 1, 0),
-        # (1, 1024, 512, 28, 28, 1, 1, 2, 2, 0), Requires block sharding
+        (1, 1024, 512, 28, 28, 1, 1, 2, 2, 0),
         (1, 256, 512, 28, 28, 1, 1, 2, 2, 0),
         (1, 256, 256, 14, 14, 3, 3, 1, 1, 1),
         (1, 1024, 256, 14, 14, 1, 1, 1, 1, 0),
         (1, 256, 1024, 14, 14, 1, 1, 1, 1, 0),
-        # (1, 2048, 1024, 14, 14, 1, 1, 2, 2, 0), Requires block sharding
-        # (1, 512, 1024, 14, 14, 1, 1, 2, 2, 0), Requires block sharding
-        # (1, 512, 512, 7, 7, 3, 3, 1, 1, 1), Requires block sharding
-        # (1, 2048, 512, 7, 7, 1, 1, 1, 1, 0), Requires block sharding
-        # (1, 512, 2048, 7, 7, 1, 1, 1, 1, 0), Requires block sharding
+        (1, 2048, 1024, 14, 14, 1, 1, 2, 2, 0),
+        (1, 512, 1024, 14, 14, 1, 1, 2, 2, 0),
+        (1, 512, 512, 7, 7, 3, 3, 1, 1, 1),
+        (1, 2048, 512, 7, 7, 1, 1, 1, 1, 0),
+        pytest.param(
+            *(1, 512, 2048, 7, 7, 1, 1, 1, 1, 0),
+            marks=pytest.mark.skip(reason="PCC is 0.8828 which is less than 0.95"),
+        ),
         # MISCELLANEOUS
         (1, 64, 16, 115, 115, 4, 4, 1, 1, 0),
         (1, 64, 64, 8, 8, 3, 3, 1, 1, 1),
         (1, 64, 64, 16, 16, 3, 3, 1, 1, 1),
         (1, 256, 256, 7, 7, 3, 3, 1, 1, 1),
         (1, 256, 64, 56, 56, 1, 1, 2, 2, 0),
-    ),
+    ],
 )
 def test_conv2d(
     batch_size,
diff --git a/tests/TTIR/test_maxpool2d.py b/tests/TTIR/test_maxpool2d.py
index ba2ebce..95e38dc 100644
--- a/tests/TTIR/test_maxpool2d.py
+++ b/tests/TTIR/test_maxpool2d.py
@@ -3,7 +3,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import pytest
-import jax
 import jax.numpy as jnp
 import flax
 
@@ -42,9 +41,7 @@
         (1, 128, 128, 128),
     ],
 )
-def test_maxpool2d(
-    act_shape,
-):
+def test_maxpool2d(act_shape):
     def module_maxpool(img):
         return flax.linen.max_pool(
             img, window_shape=(2, 2), strides=(2, 2), padding=((0, 0), (0, 0))
@@ -59,7 +56,8 @@ def module_maxpool(img):
     )
 
 
-def test_resnet_maxpool2d():
+@pytest.mark.parametrize("act_shape", [(1, 112, 112, 64)])
+def test_resnet_maxpool2d(act_shape):
     def module_resnet_maxpool(x):
         x = flax.linen.max_pool(
             x, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1))
@@ -68,7 +66,7 @@ def module_resnet_maxpool(x):
 
     verify_module(
         module_resnet_maxpool,
-        [(1, 112, 112, 64)],
+        [act_shape],
         required_pcc=0.95,
         required_atol=float("inf"),
         dtype=jnp.bfloat16,
diff --git a/tests/TTIR/test_mnist.py b/tests/TTIR/test_mnist.py
index 64f862f..f87fc54 100644
--- a/tests/TTIR/test_mnist.py
+++ b/tests/TTIR/test_mnist.py
@@ -9,48 +9,58 @@
 from infrastructure import verify_module
 
 
-def test_matmul():
+@pytest.mark.parametrize("input_shapes", [[(32, 32), (32, 32)]])
+def test_matmul(input_shapes):
     def module_matmul(a, b):
         return jnp.matmul(a, b)
 
-    verify_module(module_matmul, [(32, 32), (32, 32)], required_atol=3e-2)
+    verify_module(module_matmul, input_shapes, required_atol=3e-2)
 
 
-def test_matmul_with_bias():
+@pytest.mark.parametrize("input_shapes", [[(32, 32), (32, 32), (1, 32)]])
+def test_matmul_with_bias(input_shapes):
     def module_matmul(a, b, bias):
         return jnp.matmul(a, b) + bias
 
-    verify_module(module_matmul, [(32, 32), (32, 32), (1, 32)], required_atol=3e-2)
+    verify_module(module_matmul, input_shapes, required_atol=3e-2)
 
 
-def test_relu_no_broadcast():
+@pytest.mark.parametrize("input_shapes", [[(32, 32), (32, 32)]])
+def test_relu_no_broadcast(input_shapes):
     def module_relu(a, b):
         return jnp.maximum(a, b)
 
-    verify_module(module_relu, [(32, 32), (32, 32)])
+    verify_module(module_relu, input_shapes)
 
 
-def test_relu():
-    pytest.skip("Asserts")
-
+@pytest.mark.parametrize("input_shapes", [[(32, 32)]])
+@pytest.mark.skip(
+    "ttnn::operations::binary::BinaryDeviceOperation: unsupported broadcast"
+)
+def test_relu(input_shapes):
     def module_relu(a):
         return jnp.maximum(a, 0)
 
-    verify_module(module_relu, [(32, 32)])
+    verify_module(module_relu, input_shapes)
 
 
-@pytest.mark.skip("keepdims=False in runtime")
-def test_softmax():
+@pytest.mark.parametrize("input_shapes", [[(32, 32)]])
+@pytest.mark.skip("keepdim=False is not supported")
+def test_softmax(input_shapes):
     def module_softmax(a):
         return jax.nn.softmax(a)
 
-    verify_module(module_softmax, [(32, 32)])
+    verify_module(module_softmax, input_shapes)
 
 
+@pytest.mark.parametrize(
+    ["act", "w0", "b0", "w1", "b1", "w2", "b2"],
+    [[(32, 784), (784, 128), (1, 128), (128, 128), (1, 128), (128, 10), (1, 10)]],
+)
 @pytest.mark.skip(
-    "Index is out of bounds for the rank, should be between 0 and 0 however is 18446744073709551615"
+    "ttnn::operations::binary::BinaryDeviceOperation: unsupported broadcast"
 )
-def test_mnist():
+def test_mnist(act, w0, b0, w1, b1, w2, b2):
     def module_mnist(act, w0, b0, w1, b1, w2, b2):
         x = jnp.matmul(act, w0) + b0
         x = jnp.maximum(x, 0)
@@ -62,5 +72,5 @@ def module_mnist(act, w0, b0, w1, b1, w2, b2):
 
     verify_module(
         module_mnist,
-        [(32, 784), (784, 128), (1, 128), (128, 128), (1, 128), (128, 10), (1, 10)],
+        [act, w0, b0, w1, b1, w2, b2],
     )
diff --git a/tests/TTIR/test_reshape.py b/tests/TTIR/test_reshape.py
index 898e3aa..dde28b1 100644
--- a/tests/TTIR/test_reshape.py
+++ b/tests/TTIR/test_reshape.py
@@ -3,15 +3,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import pytest
-import jax
 import jax.numpy as jnp
-import flax
 
 from infrastructure import verify_module
 
 
 @pytest.mark.parametrize(
-    "source_and_target_shape",
+    ["act_shape", "target_shape"],
     [
         ((8, 32, 256), (2, 4, 32, 256)),
         ((8, 32, 32), (1, 2, 4, 32, 32)),
@@ -19,9 +17,7 @@
     ],
     ids=["1", "2", "3"],
 )
-def test_reshape(source_and_target_shape):
-    act_shape, target_shape = source_and_target_shape
-
+def test_reshape(act_shape, target_shape):
     def module_reshape(act):
         return jnp.reshape(act, target_shape)
 
diff --git a/tests/TTIR/test_scalar_add.py b/tests/TTIR/test_scalar_add.py
deleted file mode 100644
index 61c3a51..0000000
--- a/tests/TTIR/test_scalar_add.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
-#
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-import jax
-import jax.numpy as jnp
-
-
-def test_scalar_add():
-    pytest.skip("Not working")
-
-    def module_add(a, b):
-        return a + b
-
-    a = jnp.float32(5.0)
-    b = jnp.array(6.0)
-    tt_graph = jax.jit(module_add)
-    res = tt_graph(a, b)
-    cpu_graph = jax.jit(module_add, backend="cpu")
-    res_cpu = cpu_graph(a, b)
-    assert jnp.allclose(res, res_cpu)
diff --git a/tests/TTIR/test_simple_regression.py b/tests/TTIR/test_simple_regression.py
index 2682dee..582f422 100644
--- a/tests/TTIR/test_simple_regression.py
+++ b/tests/TTIR/test_simple_regression.py
@@ -4,26 +4,27 @@
 
 import pytest
 import jax
-import jax.numpy as jnp
 
 from infrastructure import verify_module
 
 
-@pytest.mark.skip(
-    "Module contains function used inside the main function. Cannot compile Flatbuffer."
-)
-def test_gradient():
+@pytest.mark.parametrize("input_shapes", [[(2, 2)]])
+@pytest.mark.skip("Inputs to eltwise binary must be tilized")
+def test_gradient(input_shapes):
     def simple_gradient(a):
         def gradient(a):
             return (a**2).sum()
 
         return jax.grad(gradient)(a)
 
-    verify_module(simple_gradient, [(2, 2)])
+    verify_module(simple_gradient, input_shapes)
 
 
-@pytest.mark.skip("TT_METAL_HOME is not set.")
-def test_simple_regression():
+@pytest.mark.parametrize(
+    ["weights", "bias", "X", "y"], [[(1, 2), (1, 1), (2, 1), (1, 1)]]
+)
+@pytest.mark.skip("failed to legalize operation 'stablehlo.dot_general'")
+def test_simple_regression(weights, bias, X, y):
     def simple_regression(weights, bias, X, y):
         def loss(weights, bias, X, y):
             predict = X.dot(weights) + bias
@@ -34,4 +35,4 @@ def loss(weights, bias, X, y):
 
         return weights
 
-    verify_module(simple_regression, [(1, 2), (1, 1), (2, 1), (1, 1)])
+    verify_module(simple_regression, [weights, bias, X, y])
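
Note on the harness used throughout this patch: every test funnels through `infrastructure.verify_module`, which runs the module under test on both the Tenstorrent device and the CPU and compares the outputs. The sketch below is an assumption reconstructed from the call sites and assertion messages above (the `required_atol`/`required_pcc` values of 0.01 and 0.99, and the "ATOL is ... greater than" / "PCC is ... less than" wording); it is not the actual implementation, and `random_input` is a hypothetical helper.

    import jax
    import jax.numpy as jnp


    def random_input(shape, seed, dtype=jnp.float32):
        # Hypothetical input generator; the real infrastructure may differ.
        return jax.random.uniform(jax.random.PRNGKey(seed), shape).astype(dtype)


    def verify_module(
        module, input_shapes, required_pcc=0.99, required_atol=1e-2, dtype=jnp.float32
    ):
        inputs = [random_input(shape, i, dtype) for i, shape in enumerate(input_shapes)]
        tt_res = jax.jit(module)(*inputs)  # default backend: the TT device
        cpu_res = jax.jit(module, backend="cpu")(*inputs)  # CPU golden result
        tt_f32, cpu_f32 = tt_res.astype(jnp.float32), cpu_res.astype(jnp.float32)
        # ATOL check: maximum elementwise deviation between device and CPU.
        atol = jnp.max(jnp.abs(tt_f32 - cpu_f32))
        assert atol <= required_atol, f"ATOL is {atol} which is greater than {required_atol}"
        # PCC check: Pearson correlation coefficient of the flattened outputs.
        pcc = jnp.corrcoef(tt_f32.flatten(), cpu_f32.flatten())[0, 1]
        assert pcc >= required_pcc, f"PCC is {pcc} which is less than {required_pcc}"

Converting each test to a single parametrized `verify_module` call, as this patch does, makes every shape/tolerance combination a separately reported test case, which is what allows the precise per-case skip reasons and the `pytest.param(..., marks=pytest.mark.skip(...))` entry recorded above.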