diff --git a/tests/unit/losses/providers/torch/test_segmentation.py b/tests/unit/losses/providers/torch/test_segmentation.py index f996a1b0..31e194e8 100644 --- a/tests/unit/losses/providers/torch/test_segmentation.py +++ b/tests/unit/losses/providers/torch/test_segmentation.py @@ -66,7 +66,7 @@ def test_iou_batch(): def test__threshold(): - x = torch.rand(2, 3) + x = torch.rand(10, 10) out = _threshold(x, threshold=0.5) assert torch.max(out) == 1 assert torch.min(out) == 0