From b6116109e3244316c6dd8df38d9a4941c1833b52 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Wed, 23 Oct 2024 14:31:42 -0400 Subject: [PATCH] Fix handling of event dimensions in `ComposeTransform` (fixes #1893). (#1894) --- numpyro/distributions/transforms.py | 2 +- test/test_transforms.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 290e504c1..05ae19efd 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -282,7 +282,7 @@ def __eq__(self, other): def _get_compose_transform_input_event_dim(parts): input_event_dim = parts[-1].domain.event_dim - for part in parts[len(parts) - 1 :: -1]: + for part in parts[:-1][::-1]: input_event_dim = part.domain.event_dim + max( input_event_dim - part.codomain.event_dim, 0 ) diff --git a/test/test_transforms.py b/test/test_transforms.py index cecbf2ca5..9d8c65946 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -397,3 +397,16 @@ def test_biject_to(constraint, shape): expected_shape = constrained.shape[: constrained.ndim - constraint.event_dim] assert passed.shape == expected_shape assert jnp.all(passed) + + +@pytest.mark.parametrize( + "transform", + [ + CorrCholeskyTransform(), + CorrCholeskyTransform().inv, + ], +) +def test_compose_domain_codomain(transform): + composed = ComposeTransform([transform]) + assert transform.domain.event_dim == composed.domain.event_dim + assert transform.codomain.event_dim == composed.codomain.event_dim