Rollout aggregation function #254

Open · wants to merge 2 commits into base: main
2 changes: 2 additions & 0 deletions inseq/attr/feat/ops/__init__.py
@@ -2,6 +2,7 @@
from .lime import Lime
from .monotonic_path_builder import MonotonicPathBuilder
from .reagent import Reagent
from .rollout import rollout_fn
from .sequential_integrated_gradients import SequentialIntegratedGradients
from .value_zeroing import ValueZeroing

@@ -12,4 +13,5 @@
"Lime",
"Reagent",
"SequentialIntegratedGradients",
"rollout_fn",
]
180 changes: 180 additions & 0 deletions inseq/attr/feat/ops/rollout.py
@@ -0,0 +1,180 @@
from typing import Union

import torch
import torch.nn.functional as F

from ....utils import normalize as normalize_fn
from ....utils.typing import (
MultiUnitScoreTensor,
ScoreTensor,
)


def _check_matrix_shape(
    scores: Union[MultiUnitScoreTensor, tuple[MultiUnitScoreTensor, MultiUnitScoreTensor, MultiUnitScoreTensor]],
) -> tuple[Union[MultiUnitScoreTensor, tuple[MultiUnitScoreTensor, ...]], bool]:
    """Checks that the shape of the provided scores is compatible with the rollout aggregation method,
    returning the (possibly prefix-padded) scores and a flag marking whether a target prefix was added.
    """

    def fix_target_scores(target_scores: MultiUnitScoreTensor) -> tuple[MultiUnitScoreTensor, bool]:
has_prefix_target = False
if target_scores.size(-2) - target_scores.size(-1) == 1:
target_scores = torch.cat([torch.zeros_like(target_scores[..., -1])[..., None], target_scores], dim=-1)
target_scores[..., 0, 0] = 1.0
has_prefix_target = True
if target_scores.size(-1) != target_scores.size(-2):
raise ValueError(
"Expected scores to be a tensor of shape (T, T) but got shape "
f"{target_scores.size(-2), target_scores.size(-1)}. {msg}"
)
target_scores[target_scores.isnan()] = 0.0
return target_scores, has_prefix_target

msg = (
"This can be due to a non-zero starting index used in generation, which is not supported by the rollout "
"aggregation method. Use attribute_full_target=True in model.attribute to attribute the full target sequence."
)
if isinstance(scores, tuple):
source_scores, cross_scores, target_scores = scores
dim0, dim1 = -2, -1
source_dim = source_scores.size(dim0)
target_dim = target_scores.size(dim0)
try:
assert source_scores.size(dim1) == source_dim # source scores S x S
assert cross_scores.size(dim0) == source_dim and cross_scores.size(dim1) == target_dim # x-scores S x T
assert target_scores.size(dim1) == target_dim # target scores T x T
except AssertionError as e:
raise ValueError(
"Expected scores to be a tuple of tensors of shape (S, S), (S, T), (T, T) but got shapes "
f"{source_dim, source_scores.size(dim1)}, {cross_scores.size(dim0), cross_scores.size(dim1)}, "
f"{target_dim, target_scores.size(dim1)}. {msg}"
) from e
target_scores, has_prefix_target = fix_target_scores(target_scores)
return (source_scores, cross_scores, target_scores), has_prefix_target
else:
return fix_target_scores(scores)


def _rollout_single(
scores: MultiUnitScoreTensor,
normalize: bool = False,
) -> MultiUnitScoreTensor:
"""Performs rollout aggregation by `Abnar and Zuidema (2020) <https://aclanthology.org/2020.acl-main.385/>`__
This is a helper function used in :func:`~inseq.attr.feat.ops.rollout` to rollout a single layer stack.
"""
rollout_scores = torch.zeros_like(scores)
rollout_scores[:, 0, ...] = scores[:, 0, ...]
for i in range(1, scores.size(1)):
        # Roll out scores at layer i by multiplying them with the rolled-out scores at layer i - 1
layer_rollout_scores = scores[:, i, ...] @ rollout_scores[:, i - 1, ...]
if normalize:
rollout_scores[:, i, ...] = F.normalize(layer_rollout_scores, p=1, dim=-1)
else:
rollout_scores[:, i, ...] = layer_rollout_scores
return rollout_scores


def _rollout_joint(
final_source_scores: ScoreTensor,
cross_scores: MultiUnitScoreTensor,
target_scores: MultiUnitScoreTensor,
) -> tuple[ScoreTensor, ScoreTensor]:
"""Performs the rollout aggregation adapted for an encoder-decoder architecture with cross-importance scores."""
    target_scores = (target_scores.mT * cross_scores[..., -1, :].unsqueeze(-2)).mT
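    # Contract the per-layer cross-importance scores with the rolled-out scores of the encoder's last layer.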
joint_source_cross_scores = torch.einsum("bl...ij, b...jk -> bl...ik", cross_scores, final_source_scores)
source_rollout_scores = torch.zeros_like(joint_source_cross_scores)
source_rollout_scores[:, 0, ...] = joint_source_cross_scores[:, 0, ...]
target_rollout_scores = torch.zeros_like(target_scores)
target_rollout_scores[:, 0, ...] = target_scores[:, 0, ...]
for i in range(1, target_scores.size(1)):
# Target scores x previous cross rollout scores
source_rollout_scores[:, i, ...] = (
target_scores[:, i, ...] @ source_rollout_scores[:, i - 1, ...]
) + joint_source_cross_scores[:, i, ...]
# Target scores x previous target rollout scores
target_rollout_scores[:, i, ...] = target_scores[:, i, ...] @ target_rollout_scores[:, i - 1, ...]
# Normalize scores across source and target
source_rollout_scores, target_rollout_scores = normalize_fn(
(source_rollout_scores, target_rollout_scores), cat_dim=-1
)
return source_rollout_scores, target_rollout_scores


def rollout_fn(
scores: Union[MultiUnitScoreTensor, tuple[MultiUnitScoreTensor, MultiUnitScoreTensor, MultiUnitScoreTensor]],
dim: int = 1,
) -> Union[ScoreTensor, tuple[ScoreTensor, ScoreTensor]]:
"""Reference implementations:
* `samiraabnar/attention-flow
<https://github.com/samiraabnar/attention_flow/blob/master/attention_graph_util.py#L104>`__
* `mt-upc/transformer-contributions-nmt
<https://github.com/mt-upc/transformer-contributions-nmt/blob/main/wrappers/transformer_wrapper.py#L506>`__.

Args:
scores (:obj:`torch.Tensor` or :obj:`tuple(torch.Tensor, torch.Tensor, torch.Tensor)`):
Tensor of shape `(num_layers, ...)`, or a tuple of tensors of the same shape containing the
            scores computed for different layers. If a tuple is passed, rollout is performed assuming the
            tensors are ``(source_scores, cross_scores, target_scores)`` produced by a Transformer-like
            encoder-decoder architecture (i.e. the rolled-out importance of the source in the encoder is
            modulated by ``cross_scores`` at every layer of the decoder). In this case, the rollout follows
            the procedure described by `Ferrando et al. (2022) <https://aclanthology.org/2022.emnlp-main.599/>`__.
dim (:obj:`int`, `optional`, defaults to 1): The dimension along which to perform the rollout aggregation.

Returns:
:obj:`torch.Tensor` or :obj:`tuple(torch.Tensor, torch.Tensor)`:
An aggregated score tensor of shape `(batch_size, ...)`, or a tuple of tensors of the same shape containing
            the scores aggregated using rollout up to the topmost provided layer (e.g. for ``layers=[1,2,4]``
            the rollout skips layer 3, and only the rolled-out scores at layer 4 are returned). If
            encoder-decoder rollout is performed, a tuple of tensors ``(source_scores, target_scores)`` is
            returned.
    """
squeeze_batch_dim = False
if isinstance(scores, tuple):
if dim < 0:
dim = scores[0].ndim + dim
if scores[0].ndim < 4:
scores = tuple(t.unsqueeze(0) for t in scores)
squeeze_batch_dim = True
dim += 1
if dim != 1:
swap_dim = dim if dim < 2 else dim + 1
scores = tuple(s[:, None, ...].transpose(swap_dim, 1).squeeze(swap_dim) for s in scores)
scores, has_target_prefix = _check_matrix_shape(scores)
source_scores, cross_scores, target_scores = scores

# Get rolled out scores of encoder last layer with respect to source input
final_source_scores = _rollout_single(source_scores.mT)[:, -1, ...].mT

source_rollout_scores, target_rollout_scores = _rollout_joint(final_source_scores, cross_scores, target_scores)
if has_target_prefix:
target_rollout_scores = target_rollout_scores[..., 1:]
source_rollout_scores = source_rollout_scores[:, -1, ...].unsqueeze(1)
target_rollout_scores = target_rollout_scores[:, -1, ...].unsqueeze(1)
if dim != 1:
source_rollout_scores = source_rollout_scores.transpose(1, dim)
target_rollout_scores = target_rollout_scores.transpose(1, dim)
source_rollout_scores = source_rollout_scores.squeeze(dim)
target_rollout_scores = target_rollout_scores.squeeze(dim)
if squeeze_batch_dim:
source_rollout_scores = source_rollout_scores.squeeze(0)
target_rollout_scores = target_rollout_scores.squeeze(0)
return source_rollout_scores, target_rollout_scores
else:
# Convert rollout dim to positive index to account for new dim insertions.
if dim < 0:
dim = scores.ndim + dim
# Add batch dimension if not present. Assumed shape (batch_size, ...) with num_layers at position dim and at
# least two dimensions representing scores that will be rolled out.
if scores.ndim < 4:
scores = scores[None, ...]
squeeze_batch_dim = True
dim += 1
if dim != 1:
scores = scores[:, None, ...]
swap_dim = dim if dim < 2 else dim + 1
scores = scores.transpose(swap_dim, 1).squeeze(swap_dim)
scores, has_target_prefix = _check_matrix_shape(scores)
target_rollout_scores = _rollout_single(scores.mT)[:, -1, ...].mT
if has_target_prefix:
target_rollout_scores = target_rollout_scores[..., 1:]
if squeeze_batch_dim:
target_rollout_scores = target_rollout_scores.squeeze(0)
return target_rollout_scores
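
For reviewers, a minimal usage sketch of the new rollout_fn (shapes are hypothetical; the tuple form follows the (source_scores, cross_scores, target_scores) convention documented in the docstring, here with matching source and target lengths):

import torch

from inseq.attr.feat.ops import rollout_fn

# Decoder-only case: per-layer target scores of shape (batch, num_layers, tgt_len, tgt_len).
target_scores = torch.rand(2, 6, 5, 5)
rolled = rollout_fn(target_scores, dim=1)  # -> (2, 5, 5)

# Encoder-decoder case: per-layer (source, cross, target) scores, assuming src_len == tgt_len == 5.
source_scores = torch.rand(2, 6, 5, 5)
cross_scores = torch.rand(2, 6, 5, 5)
src_rolled, tgt_rolled = rollout_fn((source_scores, cross_scores, target_scores), dim=1)

In both cases the layer dimension indicated by dim is consumed, so only the rolled-out scores of the topmost layer are returned.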
40 changes: 40 additions & 0 deletions inseq/data/aggregation_functions.py
@@ -19,6 +19,7 @@
import torch
from torch.linalg import vector_norm

from ..attr.feat.ops import rollout_fn
from ..utils import Registry, available_classes
from ..utils.typing import (
ScoreTensor,
@@ -93,6 +94,45 @@ def __call__(self, scores: torch.Tensor, dim: int, vnorm_ord: int = 2) -> ScoreTensor:
return vector_norm(scores, ord=vnorm_ord, dim=dim)


class RolloutAggregationFunction(AggregationFunction):
aggregation_function_name = "rollout"

def __init__(self):
super().__init__()
self.takes_single_tensor: bool = False
self.takes_sequence_scores: bool = True

    def __call__(
        self,
        scores: Union[torch.Tensor, tuple[torch.Tensor, ...]],
        dim: int,
        sequence_scores: Union[dict[str, torch.Tensor], None] = None,
    ) -> ScoreTensor:
        # Avoid a mutable default argument by normalizing None to an empty dict.
        sequence_scores = sequence_scores if sequence_scores is not None else {}
dec_self_prefix = "decoder_self"
enc_self_prefix = "encoder_self"
dec_match = [name for name in sequence_scores.keys() if name.startswith(dec_self_prefix)]
enc_match = [name for name in sequence_scores.keys() if name.startswith(enc_self_prefix)]
        if isinstance(scores, torch.Tensor):
            # If either prefix is missing, we assume the decoder-only target-only rollout case
            if not dec_match or not enc_match:
                return rollout_fn(scores, dim=dim)
            # If both prefixes are found, we assume the encoder-decoder source-only rollout case
            else:
                enc_scores = sequence_scores[enc_match[0]]
                dec_scores = sequence_scores[dec_match[0]]
                return rollout_fn((enc_scores, scores, dec_scores), dim=dim)[0]
elif not enc_match:
raise KeyError(
"Could not find encoder self-importance scores in sequence scores. "
"Encoder self-importance scores are required for encoder-decoder rollout. They should be provided "
f"as an entry in the sequence scores dictionary with key starting with '{enc_self_prefix}', and "
"value being a tensor of shape (src_seq_len, src_seq_len, ..., rollout_dim)."
)
        else:
            enc_scores = sequence_scores[enc_match[0]]
            return rollout_fn((enc_scores,) + scores, dim=dim)


DEFAULT_ATTRIBUTION_AGGREGATE_DICT = {
"source_attributions": {"spans": "absmax"},
"target_attributions": {"spans": "absmax"},
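A similar sketch for the new aggregation function, assuming it can be instantiated directly (in practice it is resolved by its aggregation_function_name through the Registry machinery imported above):

import torch

from inseq.data.aggregation_functions import RolloutAggregationFunction

rollout = RolloutAggregationFunction()
# With no "encoder_self"/"decoder_self" entries in sequence_scores, the call
# falls back to the decoder-only target-only rollout of the scores tensor.
scores = torch.rand(2, 6, 5, 5)  # hypothetical (batch, num_layers, tgt_len, tgt_len)
aggregated = rollout(scores, dim=1)  # -> (2, 5, 5)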
24 changes: 24 additions & 0 deletions tests/attr/feat/ops/test_rollout.py
@@ -0,0 +1,24 @@
import pytest
import torch

from inseq.attr.feat.ops.rollout import rollout_fn


@pytest.mark.parametrize("dim", [2, 3, 4, 5])
def test_rollout_consistency_custom_dim(dim):
    original = torch.randn(1, 2, 3, 4, 5, 6, 2, 2)
    rolled_out_original = rollout_fn(original)
    # Move the layer axis (originally at position 1) to position `dim` and check that the
    # rollout result is invariant to the position of the layer dimension.
    perm = [0, *range(2, dim + 1), 1, *range(dim + 1, 8)]
    dim_changed = original.permute(*perm)
    rolled_out_dim_changed = rollout_fn(dim_changed, dim=dim)
    assert rolled_out_original.shape == rolled_out_dim_changed.shape
    assert torch.allclose(rolled_out_original, rolled_out_dim_changed)