[WIP] Fix batching in DIG #114

Open · wants to merge 25 commits into base: main

Commits (25)
f0465b7  Minor fixes, problem persists (gsarti, Dec 2, 2021)
3bf685e  Polishing model loading with pre-loaded models support (gsarti, Oct 20, 2022)
edfb5db  Decoupled decoder-only models (gsarti, Oct 26, 2022)
a384a21  Decoupled decoder-only attribute args formatting, minor typing adjust… (gsarti, Oct 27, 2022)
b79cd28  Merge remote-tracking branch 'origin/main' into decoder-only-support (gsarti, Oct 31, 2022)
ba8d041  MPS GPU support (gsarti, Nov 1, 2022)
dc85873  Fixed GPU backend check, optional deps separation (gsarti, Nov 2, 2022)
bbe1f5f  Fix safety (gsarti, Nov 2, 2022)
ac92bd6  Architecture classes, fix partial attribution (gsarti, Nov 3, 2022)
8acb6cf  Further refactoring of architecture-specific steps (gsarti, Nov 4, 2022)
ff43fbd  Further refactoring of architecture-specific steps (gsarti, Nov 4, 2022)
07224fd  Decoder attribution working (buggy) (gsarti, Nov 7, 2022)
90a47ca  Fix encoder-decoder input (gsarti, Nov 8, 2022)
21dbe42  Adjustments for decoder-only attribution visualization (gsarti, Nov 9, 2022)
27d1a77  Decoder-only attribution working (gsarti, Nov 10, 2022)
9a79f13  Fixed batched constrained decoder-only attribution and single decoder… (gsarti, Nov 15, 2022)
1961198  Support for missing unk/eos tokens in HF model (gsarti, Nov 17, 2022)
92bcce3  Support for missing unk/eos tokens in HF model (gsarti, Nov 17, 2022)
d9bca12  Fixed attr pos for batch decoder-only, update readme (gsarti, Nov 18, 2022)
4141171  Arch methods reorg, fix seq attr conversion bug for n=1, extra tests (gsarti, Nov 28, 2022)
8649606  Update deps (gsarti, Nov 29, 2022)
b301363  Merge decoder-only-support into fix-batch-dig (gsarti, Nov 29, 2022)
fd53bc4  Fix compat with decoder-only, batch result mismatch (gsarti, Nov 30, 2022)
8536fd2  Fix style (gsarti, Dec 1, 2022)
365eb12  Merge remote-tracking branch 'origin/main' into fix-batch-dig (gsarti, Dec 6, 2022)

33 changes: 17 additions & 16 deletions inseq/attr/feat/feature_attribution.py
@@ -84,16 +84,20 @@ def __init__(self, attribution_model: "AttributionModel", hook_to_model: bool =
**kwargs: Additional keyword arguments to pass to the hook method.

Attributes:
is_layer_attribution (:obj:`bool`, default `False`): If True, the attribution method maps saliency
scores to the output of a layer instead of model inputs. Layer attribution methods do not require
interpretable embeddings unless intermediate features before the embedding layer are attributed.
attribute_batch_ids (:obj:`bool`, default `False`): If True, the attribution method will receive batch ids
instead of batch embeddings for attribution. Used by layer gradient-based attribution methods mapping
saliency scores to the output of a layer instead of model inputs.
forward_batch_embeds (:obj:`bool`, default `True`): If True, the model will use embeddings in the
forward pass instead of token ids. Using this in combination with `attribute_batch_ids` will allow for
custom conversion of ids into embeddings inside the attribution method.
target_layer (:obj:`torch.nn.Module`, default `None`): The layer on which attribution should be
performed if is_layer_attribution is True.
performed for layer attribution methods.
use_baseline (:obj:`bool`, default `False`): Whether a baseline should be used for the attribution method.
"""
super().__init__()
self.attribution_model = attribution_model
self.is_layer_attribution: bool = False
self.attribute_batch_ids: bool = False
self.forward_batch_embeds: bool = True
self.target_layer = None
self.use_baseline: bool = False
if hook_to_model:
@@ -131,6 +135,9 @@ def load(
"""
from ...models import load_model

methods = cls.available_classes()
if method_name not in methods:
raise UnknownAttributionMethodError(method_name)
if model_name_or_path is not None:
model = load_model(model_name_or_path)
elif attribution_model is not None:
@@ -140,9 +147,6 @@ def load(
"Only one among an initialized model and a model identifier "
"must be defined when loading the attribution method."
)
methods = cls.available_classes()
if method_name not in methods:
raise UnknownAttributionMethodError(method_name)
return methods[method_name](model, **kwargs)

@batched
@@ -210,11 +214,7 @@ def prepare_and_attribute(
# We do this here to support separate attr_pos_start for different sentences when batching
if attr_pos_start is None or attr_pos_start < encoded_sources.input_ids.shape[1]:
attr_pos_start = encoded_sources.input_ids.shape[1]
batch = self.attribution_model.prepare_inputs_for_attribution(
inputs,
include_eos_baseline,
self.is_layer_attribution,
)
batch = self.attribution_model.prepare_inputs_for_attribution(inputs, include_eos_baseline)
# If prepare_and_attribute was called from AttributionModel.attribute,
# attributed_fn is already a Callable. Keep here to allow for usage independently
# of AttributionModel.attribute.
@@ -289,9 +289,9 @@ def attribute(
an optional added list of single :class:`~inseq.data.FeatureAttributionStepOutput` for each step and
extra information regarding the attribution parameters.
"""
if self.is_layer_attribution and attribute_target:
if self.attribute_batch_ids and not self.forward_batch_embeds and attribute_target:
raise ValueError(
"Layer attribution methods do not support attribute_target=True. Use regular ones instead."
"Layer attribution methods do not support attribute_target=True. Use regular attributions instead."
)
attr_pos_start, attr_pos_end = check_attribute_positions(
batch.max_generation_length,
@@ -524,7 +524,8 @@ def format_attribute_args(
target_ids=target_ids,
attributed_fn=attributed_fn,
attributed_fn_args=attributed_fn_args,
is_layer_attribution=self.is_layer_attribution,
attribute_batch_ids=self.attribute_batch_ids,
forward_batch_embeds=self.forward_batch_embeds,
**kwargs,
)
if self.use_baseline:
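
For reference, a small usage sketch (not part of the diff) of the reordered check in FeatureAttribution.load: the method name is now validated before any model is loaded. The argument names follow the hunk above; the method name "saliency" and the model identifier are only examples.

from inseq.attr.feat.feature_attribution import FeatureAttribution

# An unknown method name now fails fast, before any model weights are loaded:
try:
    FeatureAttribution.load("not_a_method", model_name_or_path="Helsinki-NLP/opus-mt-en-it")
except Exception as err:  # UnknownAttributionMethodError, per the hunk above
    print(f"rejected before model loading: {err}")

# A valid method name loads the model (or reuses one passed via attribution_model=...)
# and hooks the attribution method to it:
saliency = FeatureAttribution.load("saliency", model_name_or_path="Helsinki-NLP/opus-mt-en-it")
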
70 changes: 19 additions & 51 deletions inseq/attr/feat/gradient_attribution.py
@@ -13,7 +13,7 @@
# limitations under the License.
""" Gradient-based feature attribution methods. """

from typing import Any, Callable, Dict, Optional
from typing import Any, Dict

import logging

@@ -27,9 +27,8 @@
Saliency,
)

from ...data import EncoderDecoderBatch, GradientFeatureAttributionStepOutput
from ...data import GradientFeatureAttributionStepOutput
from ...utils import Registry, extract_signature_args, rgetattr
from ...utils.typing import SingleScorePerStepTensor, TargetIdsTensor
from ..attribution_decorators import set_hook, unset_hook
from .attribution_utils import get_source_target_attributions
from .feature_attribution import FeatureAttribution
@@ -45,23 +44,23 @@ class GradientAttribution(FeatureAttribution, Registry):
@set_hook
def hook(self, **kwargs):
r"""
Hooks the attribution method to the model by replacing normal :obj:`nn.Embedding`
with Captum's `InterpretableEmbeddingBase <https://captum.ai/api/utilities.html#captum.attr.InterpretableEmbeddingBase>`__.
""" # noqa: E501
if self.is_layer_attribution:
Hooks the attribution method to the model by replacing normal :obj:`nn.Embedding` with Captum's
`InterpretableEmbeddingBase <https://captum.ai/api/utilities.html#captum.attr.InterpretableEmbeddingBase>`__.
"""
if self.attribute_batch_ids and not self.forward_batch_embeds:
self.target_layer = kwargs.pop("target_layer", self.attribution_model.get_embedding_layer())
logger.debug(f"target_layer={self.target_layer}")
if isinstance(self.target_layer, str):
self.target_layer = rgetattr(self.attribution_model.model, self.target_layer)
if not self.is_layer_attribution:
if not self.attribute_batch_ids:
self.attribution_model.configure_interpretable_embeddings()

@unset_hook
def unhook(self, **kwargs):
r"""
Unhook the attribution method by restoring the model's original embeddings.
"""
if self.is_layer_attribution:
if self.attribute_batch_ids and not self.forward_batch_embeds:
self.target_layer = None
else:
self.attribution_model.remove_interpretable_embeddings()
@@ -130,9 +129,11 @@ class DiscretizedIntegratedGradientsAttribution(GradientAttribution):

method_name = "discretized_integrated_gradients"

def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
def __init__(self, attribution_model, multiply_by_inputs: bool = False, **kwargs):
super().__init__(attribution_model, hook_to_model=False)
self.attribution_model = attribution_model
self.attribute_batch_ids = True
self.use_baseline = True
self.method = DiscretetizedIntegratedGradients(
self.attribution_model,
multiply_by_inputs,
@@ -155,42 +156,6 @@ def hook(self, **kwargs):
)
super().hook(**other_kwargs)

def format_attribute_args(
self,
batch: EncoderDecoderBatch,
target_ids: TargetIdsTensor,
attributed_fn: Callable[..., SingleScorePerStepTensor],
attribute_target: bool = False,
attributed_fn_args: Dict[str, Any] = {},
n_steps: Optional[int] = None,
strategy: Optional[str] = None,
) -> Dict[str, Any]:
attribute_fn_args = super().format_attribute_args(
batch=batch,
target_ids=target_ids,
attributed_fn=attributed_fn,
attribute_target=attribute_target,
attributed_fn_args=attributed_fn_args,
)
attribute_fn_args["inputs"] = (
self.method.path_builder.scale_inputs(
batch.sources.input_ids,
batch.sources.baseline_ids,
n_steps=n_steps,
scale_strategy=strategy,
),
)
if attribute_target:
attribute_fn_args["inputs"] += (
self.method.path_builder.scale_inputs(
batch.targets.input_ids,
batch.targets.baseline_ids,
n_steps=n_steps,
scale_strategy=strategy,
),
)
return attribute_fn_args


class IntegratedGradientsAttribution(GradientAttribution):
"""Integrated Gradients attribution method.
Expand Down Expand Up @@ -249,7 +214,8 @@ class LayerIntegratedGradientsAttribution(GradientAttribution):

def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
super().__init__(attribution_model, hook_to_model=False)
self.is_layer_attribution = True
self.attribute_batch_ids = True
self.forward_batch_embeds = False
self.use_baseline = True
self.hook(**kwargs)
self.method = LayerIntegratedGradients(
@@ -263,14 +229,15 @@ class LayerGradientXActivationAttribution(GradientAttribution):
"""Layer Integrated Gradients attribution method.

Reference implementation:
`https://captum.ai/api/layer.html#layer-integrated-gradients <https://captum.ai/api/layer.html#layer-integrated-gradients>`__.
`https://captum.ai/api/layer.html#layer-gradient-x-activation <https://captum.ai/api/layer.html#layer-gradient-x-activation>`__.
""" # noqa E501

method_name = "layer_gradient_x_activation"

def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
super().__init__(attribution_model, hook_to_model=False)
self.is_layer_attribution = True
self.attribute_batch_ids = True
self.forward_batch_embeds = False
self.use_baseline = False
self.hook(**kwargs)
self.method = LayerGradientXActivation(
@@ -285,13 +252,14 @@ class LayerDeepLiftAttribution(GradientAttribution):

Reference implementation:
`https://captum.ai/api/layer.html#layer-deeplift <https://captum.ai/api/layer.html#layer-deeplift>`__.
""" # noqa E501
"""

method_name = "layer_deeplift"

def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
super().__init__(attribution_model, hook_to_model=False)
self.is_layer_attribution = True
self.attribute_batch_ids = True
self.forward_batch_embeds = False
self.use_baseline = True
self.hook(**kwargs)
self.method = LayerDeepLift(
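
When target_layer is supplied as a string, hook() resolves it on the underlying model via rgetattr. The snippet below is an illustrative stand-in for that dotted-path lookup, not the actual inseq.utils implementation, and the layer path in the comments is only an example of a Hugging Face module path.

import functools

def resolve_layer(model, dotted_path: str):
    """Resolve a dotted attribute path, e.g. 'model.encoder.embed_tokens'."""
    return functools.reduce(getattr, dotted_path.split("."), model)

# Hypothetical usage inside a layer attribution method:
# self.target_layer = resolve_layer(self.attribution_model.model, "model.encoder.embed_tokens")
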
93 changes: 63 additions & 30 deletions inseq/attr/feat/ops/discretized_integrated_gradients.py
@@ -28,10 +28,10 @@
_format_output,
_is_tuple,
)
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._core.integrated_gradients import IntegratedGradients
from captum.attr._utils.batching import _batch_attribution
from captum.attr._utils.common import _format_input, _reshape_and_sum
from captum.attr._utils.common import _format_input_baseline, _reshape_and_sum, _validate_input
from captum.log import log_usage
from torch import Tensor

@@ -44,7 +44,7 @@ class DiscretetizedIntegratedGradients(IntegratedGradients):
def __init__(
self,
forward_func: Callable,
multiply_by_inputs: bool = True,
multiply_by_inputs: bool = False,
) -> None:
super().__init__(forward_func, multiply_by_inputs)
self.path_builder = None
@@ -61,30 +61,66 @@ def load_monotonic_path_builder(
"""Loads the Discretized Integrated Gradients (DIG) path builder."""
self.path_builder = MonotonicPathBuilder.load(
model_name,
vocabulary_embeddings=vocabulary_embeddings,
vocabulary_embeddings=vocabulary_embeddings.to("cpu"),
special_tokens=special_tokens,
cache_dir=cache_dir,
embedding_scaling=embedding_scaling,
**kwargs,
)

@staticmethod
def get_inputs_baselines(scaled_features_tpl: Tuple[Tensor, ...], n_steps: int) -> Tuple[Tensor, ...]:
# Baseline and inputs are reversed in the path builder
# For every element in the batch, the first embedding of the sub-tensor
# of shape (n_steps x embedding_dim) is the baseline, the last is the input.
n_examples = scaled_features_tpl[0].shape[0] // n_steps
baselines = tuple(
torch.cat(
[features[i, :, :].unsqueeze(0) for i in range(0, n_steps * n_examples, n_steps)],
)
for features in scaled_features_tpl
)
inputs = tuple(
torch.cat(
[features[i, :, :].unsqueeze(0) for i in range(n_steps - 1, n_steps * n_examples, n_steps)],
)
for features in scaled_features_tpl
)
return inputs, baselines

@log_usage()
def attribute( # type: ignore
self,
inputs: MultiStepEmbeddingsTensor,
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: Any = None,
n_steps: int = 50,
method: str = "greedy",
internal_batch_size: Union[None, int] = None,
return_convergence_delta: bool = False,
) -> Union[TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]]:
num_examples = inputs.shape[0] // n_steps
n_examples = inputs[0].shape[0]
# Keeps track whether original input is a tuple or not before
# converting it into a tuple.
is_inputs_tuple = _is_tuple(inputs)
scaled_features_tpl = _format_input(inputs)

inputs, baselines = _format_input_baseline(inputs, baselines)

_validate_input(inputs, baselines, n_steps)
scaled_features_tpl = tuple(
self.path_builder.scale_inputs(
input_tensor,
baseline_tensor,
n_steps=n_steps,
scale_strategy=method,
)
for input_tensor, baseline_tensor in zip(inputs, baselines)
)
if internal_batch_size is not None:
attributions = _batch_attribution(
self,
num_examples,
n_examples,
internal_batch_size,
n_steps,
scaled_features_tpl=scaled_features_tpl,
@@ -99,25 +135,7 @@ def attribute( # type: ignore
n_steps=n_steps,
)
if return_convergence_delta:
assert len(scaled_features_tpl) == 1, "More than one tuple not supported in this code!"
# Baseline and inputs are reversed in the path builder
# For every element in the batch, the first embedding of the sub-tensor
# of shape (n_steps x embedding_dim) is the baseline, the last is the input.
end_point = _format_input(
torch.cat(
[scaled_features_tpl[0][i, :, :].unsqueeze(0) for i in range(0, n_steps * num_examples, n_steps)],
dim=0,
)
)
start_point = _format_input(
torch.cat(
[
scaled_features_tpl[0][i, :, :].unsqueeze(0)
for i in range(n_steps - 1, n_steps * num_examples, n_steps)
],
dim=0,
)
)
start_point, end_point = self.get_inputs_baselines(scaled_features_tpl, n_steps)
# computes approximation error based on the completeness axiom
delta = self.compute_convergence_delta(
attributions,
@@ -152,14 +170,29 @@ def _attribute(
)
# calculate (x - x') for each interpolated point
shifted_inputs_tpl = tuple(
torch.cat([scaled_features[1:], scaled_features[-1].unsqueeze(0)])
for scaled_features in scaled_features_tpl
torch.cat(
[
torch.cat([features[idx + 1 : idx + n_steps], features[idx + n_steps - 1].unsqueeze(0)])
for idx in range(0, scaled_features_tpl[0].shape[0], n_steps)
]
)
for features in scaled_features_tpl
)
steps = tuple(shifted_inputs_tpl[i] - scaled_features_tpl[i] for i in range(len(shifted_inputs_tpl)))
scaled_grads = tuple(grads[i] * steps[i] for i in range(len(grads)))
# aggregates across all steps for each tensor in the input tuple
attributions = tuple(
# total_grads has the same dimensionality as the original inputs
total_grads = tuple(
_reshape_and_sum(scaled_grad, n_steps, grad.shape[0] // n_steps, grad.shape[1:])
for (scaled_grad, grad) in zip(scaled_grads, grads)
)
return attributions
# computes attribution for each tensor in input_tuple
# attributions has the same dimensionality as the original inputs
if not self.multiplies_by_inputs:
return total_grads
else:
inputs, baselines = self.get_inputs_baselines(scaled_features_tpl, n_steps)
return tuple(
total_grad * (input - baseline)
for (total_grad, input, baseline) in zip(total_grads, inputs, baselines)
)
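
To make the batching fix easier to follow, here is a self-contained sketch (with made-up dimensions) of the layout that get_inputs_baselines and the shifted-inputs computation in _attribute rely on: the scaled features are stacked as (n_examples * n_steps, seq_len, emb_dim), each example's interpolation path occupies n_steps contiguous rows, and within each block the first row is the baseline and the last row is the true input.

import torch

n_examples, n_steps, seq_len, emb_dim = 2, 5, 3, 4
scaled = torch.randn(n_examples * n_steps, seq_len, emb_dim)

# Same selection as get_inputs_baselines: row 0 of each block is the baseline,
# row n_steps - 1 is the input.
baselines = scaled[0::n_steps]         # shape (n_examples, seq_len, emb_dim)
inputs = scaled[n_steps - 1::n_steps]  # shape (n_examples, seq_len, emb_dim)

# Per-example shifted path used for the (x_{k+1} - x_k) steps in _attribute:
# shift each block up by one row and repeat its last row, so the shift never
# crosses over into the next example's path.
shifted = torch.cat(
    [
        torch.cat([scaled[i + 1 : i + n_steps], scaled[i + n_steps - 1].unsqueeze(0)])
        for i in range(0, n_examples * n_steps, n_steps)
    ]
)
steps = shifted - scaled  # one (x_{k+1} - x_k) per interpolation point
assert steps.shape == scaled.shape

When multiplies_by_inputs is set, the summed gradients are then rescaled by (input - baseline) per example, matching the final branch of _attribute above.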