From 7db95ff0cc7290d2856312b56682c9d6c374f072 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Thu, 2 Jan 2025 22:31:02 +0100 Subject: [PATCH] Fix `backward` and `mtl_backward` bug with some tensor shapes (#227) * Add failing parametrizations to test_value_is_correct * Fix bug in _get_jac_matrix_chunk * Add [YANKED] tag next to [0.4.0] header * Add changelog entry --- CHANGELOG.md | 7 ++++++- src/torchjd/autojac/_transform/jac.py | 2 +- tests/unit/autojac/test_backward.py | 2 +- tests/unit/autojac/test_mtl_backward.py | 2 +- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 990f5d9..70ff2a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,12 @@ changes that do not affect the user. ## [Unreleased] -## [0.4.0] - 2025-01-02 +### Fixed + +- Fixed a bug introduced in v0.4.0 that could cause `backward` and `mtl_backward` to fail with some + tensor shapes. + +## [0.4.0] - 2025-01-02 [YANKED] ### Changed diff --git a/src/torchjd/autojac/_transform/jac.py b/src/torchjd/autojac/_transform/jac.py index 435fe55..8279c87 100644 --- a/src/torchjd/autojac/_transform/jac.py +++ b/src/torchjd/autojac/_transform/jac.py @@ -114,7 +114,7 @@ def _get_jac_matrix_chunk( chunk_size = jac_outputs_chunk[0].shape[0] if chunk_size == 1: - grad_outputs = [tensor.squeeze() for tensor in jac_outputs_chunk] + grad_outputs = [tensor.squeeze(0) for tensor in jac_outputs_chunk] gradient_vector = get_vjp(grad_outputs) return gradient_vector.unsqueeze(0) else: diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 0eb5486..0432d75 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -24,7 +24,7 @@ def test_various_aggregators(aggregator: Aggregator): @mark.parametrize("aggregator", [Mean(), UPGrad()]) -@mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (20, 55)]) +@mark.parametrize("shape", [(1, 3), (2, 3), (2, 6), (5, 8), (20, 55)]) @mark.parametrize("manually_specify_inputs", [True, False]) @mark.parametrize("chunk_size", [1, 2, None]) def test_value_is_correct( diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index f16edb2..3810e50 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -27,7 +27,7 @@ def test_various_aggregators(aggregator: Aggregator): @mark.parametrize("aggregator", [Mean(), UPGrad()]) -@mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (20, 55)]) +@mark.parametrize("shape", [(1, 3), (2, 3), (2, 6), (5, 8), (20, 55)]) @mark.parametrize("manually_specify_shared_params", [True, False]) @mark.parametrize("manually_specify_tasks_params", [True, False]) @mark.parametrize("chunk_size", [1, 2, None])