diff --git a/inseq/attr/feat/ops/__init__.py b/inseq/attr/feat/ops/__init__.py
index a40b9dba..7a4d726f 100644
--- a/inseq/attr/feat/ops/__init__.py
+++ b/inseq/attr/feat/ops/__init__.py
@@ -2,6 +2,7 @@
from .lime import Lime
from .monotonic_path_builder import MonotonicPathBuilder
from .reagent import Reagent
+from .rollout import rollout_fn
from .sequential_integrated_gradients import SequentialIntegratedGradients
from .value_zeroing import ValueZeroing
@@ -12,4 +13,5 @@
"Lime",
"Reagent",
"SequentialIntegratedGradients",
+ "rollout_fn",
]
diff --git a/inseq/attr/feat/ops/rollout.py b/inseq/attr/feat/ops/rollout.py
new file mode 100644
index 00000000..2d8b13a3
--- /dev/null
+++ b/inseq/attr/feat/ops/rollout.py
@@ -0,0 +1,180 @@
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+
+from ....utils import normalize as normalize_fn
+from ....utils.typing import (
+ MultiUnitScoreTensor,
+ ScoreTensor,
+)
+
+
+def _check_matrix_shape(
+    scores: Union[MultiUnitScoreTensor, tuple[MultiUnitScoreTensor, MultiUnitScoreTensor, MultiUnitScoreTensor]],
+) -> tuple[Union[MultiUnitScoreTensor, tuple[MultiUnitScoreTensor, MultiUnitScoreTensor, MultiUnitScoreTensor]], bool]:
+    """Validates the shape of the provided scores for rollout aggregation, returning the scores (with a prefix
+    column added to the target scores if needed) and a flag signaling whether such a prefix column was added."""
+
+    def fix_target_scores(target_scores: MultiUnitScoreTensor) -> tuple[MultiUnitScoreTensor, bool]:
+ has_prefix_target = False
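+        # A target matrix with one more row than columns indicates a prefix position (e.g. BOS) that was
+        # attributed but never generated: prepend a zero column for it with self-importance 1 so that the
+        # matrix becomes square and can be rolled out, and flag it for removal after the rollout.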
+ if target_scores.size(-2) - target_scores.size(-1) == 1:
+ target_scores = torch.cat([torch.zeros_like(target_scores[..., -1])[..., None], target_scores], dim=-1)
+ target_scores[..., 0, 0] = 1.0
+ has_prefix_target = True
+ if target_scores.size(-1) != target_scores.size(-2):
+ raise ValueError(
+ "Expected scores to be a tensor of shape (T, T) but got shape "
+ f"{target_scores.size(-2), target_scores.size(-1)}. {msg}"
+ )
+ target_scores[target_scores.isnan()] = 0.0
+ return target_scores, has_prefix_target
+
+ msg = (
+ "This can be due to a non-zero starting index used in generation, which is not supported by the rollout "
+ "aggregation method. Use attribute_full_target=True in model.attribute to attribute the full target sequence."
+ )
+ if isinstance(scores, tuple):
+ source_scores, cross_scores, target_scores = scores
+ dim0, dim1 = -2, -1
+ source_dim = source_scores.size(dim0)
+ target_dim = target_scores.size(dim0)
+ try:
+ assert source_scores.size(dim1) == source_dim # source scores S x S
+ assert cross_scores.size(dim0) == source_dim and cross_scores.size(dim1) == target_dim # x-scores S x T
+ assert target_scores.size(dim1) == target_dim # target scores T x T
+ except AssertionError as e:
+ raise ValueError(
+ "Expected scores to be a tuple of tensors of shape (S, S), (S, T), (T, T) but got shapes "
+ f"{source_dim, source_scores.size(dim1)}, {cross_scores.size(dim0), cross_scores.size(dim1)}, "
+ f"{target_dim, target_scores.size(dim1)}. {msg}"
+ ) from e
+ target_scores, has_prefix_target = fix_target_scores(target_scores)
+ return (source_scores, cross_scores, target_scores), has_prefix_target
+ else:
+ return fix_target_scores(scores)
+
+
+def _rollout_single(
+ scores: MultiUnitScoreTensor,
+ normalize: bool = False,
+) -> MultiUnitScoreTensor:
+ """Performs rollout aggregation by `Abnar and Zuidema (2020) `__
+ This is a helper function used in :func:`~inseq.attr.feat.ops.rollout` to rollout a single layer stack.
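+
+    For a stack of per-layer score maps ``A_1, ..., A_L``, the returned tensor stacks the cumulative products
+    ``R_1 = A_1`` and ``R_i = A_i @ R_{i-1}`` (optionally L1-normalized over the last dimension), so ``R_L`` is
+    the fully rolled-out map.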
+ """
+ rollout_scores = torch.zeros_like(scores)
+ rollout_scores[:, 0, ...] = scores[:, 0, ...]
+ for i in range(1, scores.size(1)):
+        # Roll out scores at layer i by multiplying them with the rolled-out scores at layer i-1
+ layer_rollout_scores = scores[:, i, ...] @ rollout_scores[:, i - 1, ...]
+ if normalize:
+ rollout_scores[:, i, ...] = F.normalize(layer_rollout_scores, p=1, dim=-1)
+ else:
+ rollout_scores[:, i, ...] = layer_rollout_scores
+ return rollout_scores
+
+
+def _rollout_joint(
+ final_source_scores: ScoreTensor,
+ cross_scores: MultiUnitScoreTensor,
+ target_scores: MultiUnitScoreTensor,
+) -> tuple[ScoreTensor, ScoreTensor]:
+ """Performs the rollout aggregation adapted for an encoder-decoder architecture with cross-importance scores."""
+ target_scores = (target_scores.mT * cross_scores[..., -1, :]).mT
+ joint_source_cross_scores = torch.einsum("bl...ij, b...jk -> bl...ik", cross_scores, final_source_scores)
+ source_rollout_scores = torch.zeros_like(joint_source_cross_scores)
+ source_rollout_scores[:, 0, ...] = joint_source_cross_scores[:, 0, ...]
+ target_rollout_scores = torch.zeros_like(target_scores)
+ target_rollout_scores[:, 0, ...] = target_scores[:, 0, ...]
+ for i in range(1, target_scores.size(1)):
+        # Target scores x previous source (cross) rollout scores, plus the joint source-cross scores at layer i
+ source_rollout_scores[:, i, ...] = (
+ target_scores[:, i, ...] @ source_rollout_scores[:, i - 1, ...]
+ ) + joint_source_cross_scores[:, i, ...]
+ # Target scores x previous target rollout scores
+ target_rollout_scores[:, i, ...] = target_scores[:, i, ...] @ target_rollout_scores[:, i - 1, ...]
+ # Normalize scores across source and target
+ source_rollout_scores, target_rollout_scores = normalize_fn(
+ (source_rollout_scores, target_rollout_scores), cat_dim=-1
+ )
+ return source_rollout_scores, target_rollout_scores
+
+
+def rollout_fn(
+ scores: Union[MultiUnitScoreTensor, tuple[MultiUnitScoreTensor, MultiUnitScoreTensor, MultiUnitScoreTensor]],
+ dim: int = 1,
+) -> Union[ScoreTensor, tuple[ScoreTensor, ScoreTensor]]:
+ """Reference implementations:
+ * `samiraabnar/attention-flow
+ `__
+ * `mt-upc/transformer-contributions-nmt
+ `__.
+
+ Args:
+ scores (:obj:`torch.Tensor` or :obj:`tuple(torch.Tensor, torch.Tensor, torch.Tensor)`):
+            Tensor of shape `(num_layers, ...)`, or a tuple of such tensors containing the scores computed for
+            different layers. If a tuple is passed, rollout will be performed assuming the tensors are
+            (source_scores, cross_scores, target_scores) produced by a Transformer-like encoder-decoder architecture
+            (i.e. the rolled-out importance of the source in the encoder is modulated by cross_scores at every layer
+            of the decoder). For an encoder-decoder architecture, rollout follows the procedure described by
+ `Ferrando et al. (2022) `__.
+ dim (:obj:`int`, `optional`, defaults to 1): The dimension along which to perform the rollout aggregation.
+
+ Returns:
+ :obj:`torch.Tensor` or :obj:`tuple(torch.Tensor, torch.Tensor)`:
+        An aggregated score tensor of shape `(batch_size, ...)` containing the scores aggregated using rollout up to
+        the topmost provided layer (e.g. for ``layers=[1,2,4]`` the rollout skips layer 3, and only the rolled-out
+        scores at layer 4 are returned). If encoder-decoder rollout is performed, a tuple of tensors
+        ``(source_scores, target_scores)`` is returned instead.
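+
+    Example:
+        A minimal illustrative sketch (shapes chosen only for demonstration) rolling out a single stack of
+        row-normalized score maps with the layer dimension at position 1:
+
+        >>> import torch
+        >>> scores = torch.softmax(torch.randn(1, 2, 4, 4), dim=-1)  # (batch, num_layers, tgt_len, tgt_len)
+        >>> rollout_fn(scores, dim=1).shape
+        torch.Size([1, 4, 4])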
+ """
+ squeeze_batch_dim = False
+ if isinstance(scores, tuple):
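+        # Convert the rollout dim to a positive index, add a batch dimension if missing, and move
+        # the layer stack to position 1 so that the rollout helpers can iterate over it.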
+ if dim < 0:
+ dim = scores[0].ndim + dim
+ if scores[0].ndim < 4:
+ scores = tuple(t.unsqueeze(0) for t in scores)
+ squeeze_batch_dim = True
+ dim += 1
+ if dim != 1:
+ swap_dim = dim if dim < 2 else dim + 1
+ scores = tuple(s[:, None, ...].transpose(swap_dim, 1).squeeze(swap_dim) for s in scores)
+ scores, has_target_prefix = _check_matrix_shape(scores)
+ source_scores, cross_scores, target_scores = scores
+
+ # Get rolled out scores of encoder last layer with respect to source input
+ final_source_scores = _rollout_single(source_scores.mT)[:, -1, ...].mT
+
+ source_rollout_scores, target_rollout_scores = _rollout_joint(final_source_scores, cross_scores, target_scores)
+ if has_target_prefix:
+ target_rollout_scores = target_rollout_scores[..., 1:]
+ source_rollout_scores = source_rollout_scores[:, -1, ...].unsqueeze(1)
+ target_rollout_scores = target_rollout_scores[:, -1, ...].unsqueeze(1)
+ if dim != 1:
+ source_rollout_scores = source_rollout_scores.transpose(1, dim)
+ target_rollout_scores = target_rollout_scores.transpose(1, dim)
+ source_rollout_scores = source_rollout_scores.squeeze(dim)
+ target_rollout_scores = target_rollout_scores.squeeze(dim)
+ if squeeze_batch_dim:
+ source_rollout_scores = source_rollout_scores.squeeze(0)
+ target_rollout_scores = target_rollout_scores.squeeze(0)
+ return source_rollout_scores, target_rollout_scores
+ else:
+ # Convert rollout dim to positive index to account for new dim insertions.
+ if dim < 0:
+ dim = scores.ndim + dim
+ # Add batch dimension if not present. Assumed shape (batch_size, ...) with num_layers at position dim and at
+ # least two dimensions representing scores that will be rolled out.
+ if scores.ndim < 4:
+ scores = scores[None, ...]
+ squeeze_batch_dim = True
+ dim += 1
+ if dim != 1:
+ scores = scores[:, None, ...]
+ swap_dim = dim if dim < 2 else dim + 1
+ scores = scores.transpose(swap_dim, 1).squeeze(swap_dim)
+ scores, has_target_prefix = _check_matrix_shape(scores)
+ target_rollout_scores = _rollout_single(scores.mT)[:, -1, ...].mT
+ if has_target_prefix:
+ target_rollout_scores = target_rollout_scores[..., 1:]
+ if squeeze_batch_dim:
+ target_rollout_scores = target_rollout_scores.squeeze(0)
+ return target_rollout_scores
diff --git a/inseq/data/aggregation_functions.py b/inseq/data/aggregation_functions.py
index ac35dd84..8f7921d4 100644
--- a/inseq/data/aggregation_functions.py
+++ b/inseq/data/aggregation_functions.py
@@ -19,6 +19,7 @@
import torch
from torch.linalg import vector_norm
+from ..attr.feat.ops import rollout_fn
from ..utils import Registry, available_classes
from ..utils.typing import (
ScoreTensor,
@@ -93,6 +94,45 @@ def __call__(self, scores: torch.Tensor, dim: int, vnorm_ord: int = 2) -> ScoreT
return vector_norm(scores, ord=vnorm_ord, dim=dim)
+class RolloutAggregationFunction(AggregationFunction):
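+    """Aggregates multi-layer importance scores with the rollout procedure of
+    :func:`~inseq.attr.feat.ops.rollout_fn`, using encoder and decoder self-importance entries from
+    ``sequence_scores`` to perform encoder-decoder rollout when they are available.
+    """
+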
+ aggregation_function_name = "rollout"
+
+ def __init__(self):
+ super().__init__()
+ self.takes_single_tensor: bool = False
+ self.takes_sequence_scores: bool = True
+
+ def __call__(
+ self,
+ scores: Union[torch.Tensor, tuple[torch.Tensor, ...]],
+ dim: int,
+ sequence_scores: dict[str, torch.Tensor] = {},
+ ) -> ScoreTensor:
+ dec_self_prefix = "decoder_self"
+ enc_self_prefix = "encoder_self"
+ dec_match = [name for name in sequence_scores.keys() if name.startswith(dec_self_prefix)]
+ enc_match = [name for name in sequence_scores.keys() if name.startswith(enc_self_prefix)]
+ if isinstance(scores, torch.Tensor):
+            # If either self-importance entry is missing, we assume the decoder-only (target-only) rollout case
+ if not dec_match or not enc_match:
+ return rollout_fn(scores, dim=dim)
+ # If both prefixes are found, we assume the encoder-decoder source-only rollout case
+ else:
+ enc_match = sequence_scores[enc_match[0]]
+ dec_match = sequence_scores[dec_match[0]]
+ return rollout_fn((enc_match, scores, dec_match), dim=dim)[0]
+ elif not enc_match:
+ raise KeyError(
+ "Could not find encoder self-importance scores in sequence scores. "
+ "Encoder self-importance scores are required for encoder-decoder rollout. They should be provided "
+ f"as an entry in the sequence scores dictionary with key starting with '{enc_self_prefix}', and "
+ "value being a tensor of shape (src_seq_len, src_seq_len, ..., rollout_dim)."
+ )
+ else:
+ enc_match = sequence_scores[enc_match[0]]
+ return rollout_fn((enc_match,) + scores, dim=dim)
+
+
DEFAULT_ATTRIBUTION_AGGREGATE_DICT = {
"source_attributions": {"spans": "absmax"},
"target_attributions": {"spans": "absmax"},
diff --git a/tests/attr/feat/ops/test_rollout.py b/tests/attr/feat/ops/test_rollout.py
new file mode 100644
index 00000000..b38339e2
--- /dev/null
+++ b/tests/attr/feat/ops/test_rollout.py
@@ -0,0 +1,24 @@
+import torch
+
+from inseq.attr.feat.ops.rollout import rollout_fn
+
+
+def test_rollout_consistency_custom_dim():
+ original = torch.randn(1, 2, 3, 4, 5, 6, 2, 2)
+ rolled_out_original = rollout_fn(original)
+ dim_changed_2 = original.permute(0, 2, 1, 3, 4, 5, 6, 7)
+ rolled_out_dim_changed_2 = rollout_fn(dim_changed_2, dim=2)
+ assert rolled_out_original.shape == rolled_out_dim_changed_2.shape
+ assert torch.allclose(rolled_out_original, rolled_out_dim_changed_2)
+ dim_changed_3 = original.permute(0, 2, 3, 1, 4, 5, 6, 7)
+ rolled_out_dim_changed_3 = rollout_fn(dim_changed_3, dim=3)
+ assert rolled_out_original.shape == rolled_out_dim_changed_3.shape
+ assert torch.allclose(rolled_out_original, rolled_out_dim_changed_3)
+ dim_changed_4 = original.permute(0, 2, 3, 4, 1, 5, 6, 7)
+ rolled_out_dim_changed_4 = rollout_fn(dim_changed_4, dim=4)
+ assert rolled_out_original.shape == rolled_out_dim_changed_4.shape
+ assert torch.allclose(rolled_out_original, rolled_out_dim_changed_4)
+ dim_changed_5 = original.permute(0, 2, 3, 4, 5, 1, 6, 7)
+ rolled_out_dim_changed_5 = rollout_fn(dim_changed_5, dim=5)
+ assert rolled_out_original.shape == rolled_out_dim_changed_5.shape
+ assert torch.allclose(rolled_out_original, rolled_out_dim_changed_5)
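+
+
+def test_rollout_shapes_decoder_only():
+    # Minimal shape sanity check (illustrative sketch): a batch of 2 stacks of 3 row-normalized
+    # 4x4 score maps with the layer dimension at position 1 rolls out to a single 4x4 map per item.
+    scores = torch.softmax(torch.randn(2, 3, 4, 4), dim=-1)
+    rolled = rollout_fn(scores, dim=1)
+    assert rolled.shape == (2, 4, 4)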