Skip to content

Commit

Permalink
Fix backward and mtl_backward bug with some tensor shapes (#227)
Browse files Browse the repository at this point in the history
* Add failing parametrizations to test_value_is_correct
* Fix bug in _get_jac_matrix_chunk
* Add [YANKED] tag next to [0.4.0] header
* Add changelog entry
  • Loading branch information
ValerianRey authored Jan 2, 2025
1 parent 1a8454e commit 7db95ff
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 4 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@ changes that do not affect the user.

## [Unreleased]

## [0.4.0] - 2025-01-02
### Fixed

- Fixed a bug introduced in v0.4.0 that could cause `backward` and `mtl_backward` to fail with some
tensor shapes.

## [0.4.0] - 2025-01-02 [YANKED]

### Changed

Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autojac/_transform/jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _get_jac_matrix_chunk(

chunk_size = jac_outputs_chunk[0].shape[0]
if chunk_size == 1:
grad_outputs = [tensor.squeeze() for tensor in jac_outputs_chunk]
grad_outputs = [tensor.squeeze(0) for tensor in jac_outputs_chunk]
gradient_vector = get_vjp(grad_outputs)
return gradient_vector.unsqueeze(0)
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/autojac/test_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_various_aggregators(aggregator: Aggregator):


@mark.parametrize("aggregator", [Mean(), UPGrad()])
@mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (20, 55)])
@mark.parametrize("shape", [(1, 3), (2, 3), (2, 6), (5, 8), (20, 55)])
@mark.parametrize("manually_specify_inputs", [True, False])
@mark.parametrize("chunk_size", [1, 2, None])
def test_value_is_correct(
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/autojac/test_mtl_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_various_aggregators(aggregator: Aggregator):


@mark.parametrize("aggregator", [Mean(), UPGrad()])
@mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (20, 55)])
@mark.parametrize("shape", [(1, 3), (2, 3), (2, 6), (5, 8), (20, 55)])
@mark.parametrize("manually_specify_shared_params", [True, False])
@mark.parametrize("manually_specify_tasks_params", [True, False])
@mark.parametrize("chunk_size", [1, 2, None])
Expand Down

0 comments on commit 7db95ff

Please sign in to comment.