diff --git a/pyproject.toml b/pyproject.toml
index c1c733ae..2eb0b117 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,31 +52,6 @@ dynamic = ["version"]
# $ pip install quantus[tensorflow]
#
[project.optional-dependencies]
-tests = [
- "captum>=0.6.0",
- "coverage>=7.2.3",
- "flake8<=4.0.1; python_version == '3.7'",
- "flake8>=6.0.0; python_version > '3.7'",
- "pytest<=7.4.4",
- "pytest-cov>=4.0.0",
- "pytest-lazy-fixture>=0.6.3",
- "pytest-mock==3.10.0",
- "pytest_xdist",
- "tf-explain>=0.3.1",
- "keras<3",
- "zennit>=0.4.5; python_version >= '3.7'",
- "tensorflow>=2.5.0; python_version == '3.7'",
- "tensorflow>=2.12.0; sys_platform != 'darwin' and python_version > '3.7'",
- "tensorflow_macos>=2.9.0; sys_platform == 'darwin' and python_version > '3.7'",
- "torch<=1.9.0; python_version == '3.7'",
- "torch>=1.13.1; sys_platform != 'linux' and python_version > '3.7'",
- "torch>=1.13.1, <2.0.0; sys_platform == 'linux' and python_version > '3.7' and python_version <= '3.10'",
- "torch>=2.0.0; sys_platform == 'linux' and python_version >= '3.11'",
- "torchvision<=0.12.0; python_version == '3.7'",
- "torchvision>=0.15.1; sys_platform != 'linux' and python_version > '3.7'",
- "torchvision>=0.14.0, <0.15.1; sys_platform == 'linux' and python_version > '3.7' and python_version <= '3.10'",
- "torchvision>=0.15.1; sys_platform == 'linux' and python_version >= '3.11'"
-]
torch = [
"torch<=1.11.0; python_version == '3.7'",
"torch>=1.13.1; sys_platform != 'linux' and python_version > '3.7'",
@@ -85,7 +60,9 @@ torch = [
"torchvision<=0.12.0; python_version == '3.7'",
"torchvision>=0.15.1; sys_platform != 'linux' and python_version > '3.7'",
"torchvision>=0.14.0, <0.15.1; sys_platform == 'linux' and python_version > '3.7' and python_version <= '3.10'",
- "torchvision>=0.15.1; sys_platform == 'linux' and python_version >= '3.11'"
+ "torchvision>=0.15.1; sys_platform == 'linux' and python_version >= '3.11'",
+ "transformers<=4.30.2; python_version == '3.7'",
+ "transformers>=4.38.2; python_version > '3.7'",
]
tensorflow = [
"tensorflow>=2.5.0; python_version == '3.7'",
@@ -104,8 +81,39 @@ zennit = [
"quantus[torch]",
"zennit>=0.5.1"
]
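+# For transformers support (hedged mirror of the install hints above):
+# $ pip install quantus[transformers]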
+transformers = [
+ "quantus[torch, tensorflow]",
+ "transformers<=4.30.2; python_version == '3.7'",
+ "transformers>=4.38.2; python_version > '3.7'",
+]
+tests = [
+ "captum>=0.6.0",
+ "coverage>=7.2.3",
+ "flake8<=4.0.1; python_version == '3.7'",
+ "flake8>=6.0.0; python_version > '3.7'",
+ "pytest<=7.4.4",
+ "pytest-cov>=4.0.0",
+ "pytest-lazy-fixture>=0.6.3",
+ "pytest-mock==3.10.0",
+ "pytest_xdist",
+ "tf-explain>=0.3.1",
+ "keras<3",
+ "zennit>=0.4.5; python_version >= '3.7'",
+ "tensorflow>=2.5.0; python_version == '3.7'",
+ "tensorflow>=2.12.0; sys_platform != 'darwin' and python_version > '3.7'",
+ "tensorflow_macos>=2.9.0; sys_platform == 'darwin' and python_version > '3.7'",
+ "torch<=1.9.0; python_version == '3.7'",
+ "torch>=1.13.1; sys_platform != 'linux' and python_version > '3.7'",
+ "torch>=1.13.1, <2.0.0; sys_platform == 'linux' and python_version > '3.7' and python_version <= '3.10'",
+ "torch>=2.0.0; sys_platform == 'linux' and python_version >= '3.11'",
+ "torchvision<=0.12.0; python_version == '3.7'",
+ "torchvision>=0.15.1; sys_platform != 'linux' and python_version > '3.7'",
+ "torchvision>=0.14.0, <0.15.1; sys_platform == 'linux' and python_version > '3.7' and python_version <= '3.10'",
+ "torchvision>=0.15.1; sys_platform == 'linux' and python_version >= '3.11'",
+ "quantus[captum,tf_explain,zennit, transformers]"
+]
full = [
- "quantus[captum,tf_explain,zennit]"
+ "quantus[captum,tf_explain,zennit, transformers]"
]
[build-system]
diff --git a/quantus/helpers/model/pytorch_model.py b/quantus/helpers/model/pytorch_model.py
index c60d90a5..4760c81b 100644
--- a/quantus/helpers/model/pytorch_model.py
+++ b/quantus/helpers/model/pytorch_model.py
@@ -6,16 +6,19 @@
# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.
import copy
+import logging
+import warnings
from contextlib import suppress
from copy import deepcopy
-from typing import Any, Dict, Optional, Tuple, List, Union, Generator
-import warnings
-import logging
+from functools import lru_cache
+from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
import numpy as np
+import numpy.typing as npt
import torch
from torch import nn
-from functools import lru_cache
+from transformers import PreTrainedModel
+from transformers.tokenization_utils import BatchEncoding
from quantus.helpers import utils
from quantus.helpers.model.model_interface import ModelInterface
@@ -97,7 +100,33 @@ def _get_model_with_linear_top(self) -> torch.nn:
return linear_model
- def get_softmax_arg_model(self) -> torch.nn:
+    def _obtain_predictions(
+        self,
+        x: Union[npt.ArrayLike, Mapping[str, npt.ArrayLike]],
+        model_predict_kwargs: Dict[str, Any],
+    ) -> torch.Tensor:
+ if isinstance(self.model, PreTrainedModel):
+            # BatchEncoding is the output type of HuggingFace tokenizers; it
+            # contains the required keys `input_ids` and `attention_mask`.
+            # A plain dict with those keys is also accepted.
+ if not (
+ isinstance(x, BatchEncoding)
+ or (
+                    isinstance(x, dict)
+                    and "input_ids" in x
+                    and "attention_mask" in x
+                )
+ ):
+ raise ValueError(
+ "When using HuggingFace pretrained models, please use Tokenizers output for `x` "
+ "or make sure you're passing a dict with input_ids and attention_mask as keys"
+ )
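+            # HuggingFace models return a ModelOutput object; the raw
+            # classification scores live in its `logits` attribute.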
+ pred = self.model(**x, **model_predict_kwargs).logits
+ if self.softmax:
+ return torch.softmax(pred, dim=-1)
+ return pred
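+        # Plain torch modules: get_softmax_arg_model() adjusts the top layer
+        # according to the softmax flag before the forward pass on self.device.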
+ elif isinstance(self.model, nn.Module):
+ pred_model = self.get_softmax_arg_model()
+ return pred_model(torch.Tensor(x).to(self.device), **model_predict_kwargs)
+ raise ValueError("Predictions cant be null")
+
+ def get_softmax_arg_model(self) -> torch.nn.Module:
"""
Returns model with last layer adjusted accordingly to softmax argument.
If the original model has softmax activation as the last layer and softmax=false,
@@ -156,14 +185,20 @@ def get_softmax_arg_model(self) -> torch.nn:
return self.model # Case 5
- def predict(self, x: np.ndarray, grad: bool = False, **kwargs) -> np.array:
+ def predict(
+ self,
+ x: Union[npt.ArrayLike, Mapping[str, npt.ArrayLike]],
+ grad: bool = False,
+ **kwargs,
+ ) -> np.ndarray:
"""
Predict on the given input.
Parameters
----------
- x: np.ndarray
- A given input that the wrapped model predicts on.
+ x: np.ndarray, BatchEncoding
+            A given input that the wrapped model predicts on. This can be either a numpy
+            array or a BatchEncoding (the output of a Hugging Face tokenizer).
grad: boolean
Indicates if gradient-calculation is disabled or not.
kwargs: optional
@@ -177,15 +212,13 @@ def predict(self, x: np.ndarray, grad: bool = False, **kwargs) -> np.array:
# Use kwargs of predict call if specified, but don't overwrite object attribute
model_predict_kwargs = {**self.model_predict_kwargs, **kwargs}
-
if self.model.training:
raise AttributeError("Torch model needs to be in the evaluation mode.")
grad_context = torch.no_grad() if not grad else suppress()
with grad_context:
- pred_model = self.get_softmax_arg_model()
- pred = pred_model(torch.Tensor(x).to(self.device), **model_predict_kwargs)
+ pred = self._obtain_predictions(x, model_predict_kwargs)
if pred.requires_grad:
return pred.detach().cpu().numpy()
return pred.cpu().numpy()
@@ -265,9 +298,9 @@ def get_random_layer_generator(
random_layer_model = deepcopy(self.model)
modules = [
- l
- for l in random_layer_model.named_modules()
- if (hasattr(l[1], "reset_parameters"))
+ layer
+ for layer in random_layer_model.named_modules()
+ if (hasattr(layer[1], "reset_parameters"))
]
if order == "top_down":
@@ -350,7 +383,7 @@ def add_mean_shift_to_first_layer(
with torch.no_grad():
new_model = deepcopy(self.model)
- modules = [l for l in new_model.named_modules()]
+        modules = list(new_model.named_modules())
module = modules[1]
delta = torch.zeros(size=shape).fill_(input_shift)
@@ -377,8 +410,8 @@ def get_hidden_representations(
"""
Compute the model's internal representation of input x.
In practice, this means, executing a forward pass and then, capturing the output of layers (of interest).
- As the exact definition of "internal model representation" is left out in the original paper (see: https://arxiv.org/pdf/2203.06877.pdf),
- we make the implementation flexible.
+ As the exact definition of "internal model representation" is left out in the original paper
+ (see: https://arxiv.org/pdf/2203.06877.pdf), we make the implementation flexible.
It is up to the user whether all layers are used, or specific ones should be selected.
The user can therefore select a layer by providing 'layer_names' (exclusive) or 'layer_indices'.
@@ -422,8 +455,8 @@ def is_layer_of_interest(layer_index: int, layer_name: str):
# skip modules defined by subclassing API.
hidden_layers = list( # type: ignore
filter(
- lambda l: not isinstance(
- l[1], (self.model.__class__, torch.nn.Sequential)
+ lambda layer: not isinstance(
+ layer[1], (self.model.__class__, torch.nn.Sequential)
),
all_layers,
)
diff --git a/tests/conftest.py b/tests/conftest.py
index c81be012..6e87b554 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,26 +1,28 @@
-import pytest
+import os
import pickle
-import torch
+
import numpy as np
-from keras.datasets import cifar10
import pandas as pd
+import pytest
+import torch
+from keras.datasets import cifar10
+from quantus.helpers.model.models import (CifarCNNModel, ConvNet1D,
+ ConvNet1DTF, LeNet, LeNetTF,
+ TitanicSimpleTFModel,
+ TitanicSimpleTorchModel)
from sklearn.model_selection import train_test_split
-import os
-
-from quantus.helpers.model.models import (
- LeNet,
- LeNetTF,
- CifarCNNModel,
- ConvNet1D,
- ConvNet1DTF,
- TitanicSimpleTFModel,
- TitanicSimpleTorchModel,
-)
+from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
+ set_seed)
CIFAR_IMAGE_SIZE = 32
MNIST_IMAGE_SIZE = 28
BATCH_SIZE = 124
MINI_BATCH_SIZE = 8
+RANDOM_SEED = 42
+
+@pytest.fixture(scope="function", autouse=True)
+def reset_prngs():
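+    """Re-seed PRNGs before each test so pretrained-model outputs stay deterministic."""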
+    set_seed(RANDOM_SEED)
@pytest.fixture(scope="session", autouse=True)
@@ -236,7 +238,30 @@ def load_mnist_model_softmax():
return model
+@pytest.fixture(scope="session", autouse=False)
+def load_hf_distilbert_sequence_classifier():
+ """
+    Load a pretrained DistilBERT sequence-classification model from the HuggingFace hub.
+ """
+ DISTILBERT_BASE = "distilbert-base-uncased"
+ model = AutoModelForSequenceClassification.from_pretrained(
+ DISTILBERT_BASE, cache_dir="/tmp/"
+ )
+ return model
+
+
+@pytest.fixture(scope="session", autouse=False)
+def dummy_hf_tokenizer():
+ """
+    Tokenize a fixed reference sentence with the DistilBERT tokenizer and return the BatchEncoding.
+ """
+ DISTILBERT_BASE = "distilbert-base-uncased"
+ REFERENCE_TEXT = "The quick brown fox jumps over the lazy dog"
+ tokenizer = AutoTokenizer.from_pretrained(DISTILBERT_BASE, cache_dir="/tmp/")
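+    # The returned BatchEncoding carries input_ids and attention_mask as torch tensors.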
+ return tokenizer(REFERENCE_TEXT, return_tensors="pt")
+
+
@pytest.fixture(scope="session", autouse=True)
def set_env():
"""Set ENV var, so test outputs are not polluted by progress bars and warnings."""
- os.environ["PYTEST"] = "1"
+ os.environ["PYTEST"] = "1"
\ No newline at end of file
diff --git a/tests/functions/test_pytorch_model.py b/tests/functions/test_pytorch_model.py
index 53fb0405..f995fb8c 100644
--- a/tests/functions/test_pytorch_model.py
+++ b/tests/functions/test_pytorch_model.py
@@ -1,13 +1,13 @@
from collections import OrderedDict
+from contextlib import nullcontext
from typing import Union
import numpy as np
import pytest
import torch
from pytest_lazyfixture import lazy_fixture
-from scipy.special import softmax
-
from quantus.helpers.model.pytorch_model import PyTorchModel
+from scipy.special import softmax
@pytest.fixture
@@ -242,3 +242,52 @@ def test_add_mean_shift_to_first_layer(load_mnist_model):
a1 = model.model(X)
a2 = new_model(X_shift)
assert torch.all(torch.isclose(a1, a2, atol=1e-04))
+
+
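+# The expected arrays below are reference outputs of distilbert-base-uncased on the
+# tokenized reference sentence (logits, or probabilities when softmax=True); wrapping
+# them in nullcontext lets success and error cases share one parametrize table.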
+@pytest.mark.pytorch_model
+@pytest.mark.parametrize(
+ "hf_model,data,softmax,model_kwargs,expected",
+ [
+ (
+ lazy_fixture("load_hf_distilbert_sequence_classifier"),
+ lazy_fixture("dummy_hf_tokenizer"),
+ False,
+ {},
+ nullcontext(np.array([[0.00424026, -0.03878461]])),
+ ),
+ (
+ lazy_fixture("load_hf_distilbert_sequence_classifier"),
+ lazy_fixture("dummy_hf_tokenizer"),
+ False,
+ {"labels": torch.tensor([1]), "output_hidden_states": True},
+ nullcontext(np.array([[0.00424026, -0.03878461]])),
+ ),
+ (
+ lazy_fixture("load_hf_distilbert_sequence_classifier"),
+            {
+                "input_ids": torch.tensor(
+                    [[101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102]]
+                ),
+                "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
+            },
+ False,
+ {"labels": torch.tensor([1]), "output_hidden_states": True},
+ nullcontext(np.array([[0.00424026, -0.03878461]])),
+ ),
+ (
+ lazy_fixture("load_hf_distilbert_sequence_classifier"),
+ lazy_fixture("dummy_hf_tokenizer"),
+ True,
+ {},
+ nullcontext(np.array([[0.51075452, 0.4892454]])),
+ ),
+ (
+ lazy_fixture("load_hf_distilbert_sequence_classifier"),
+ np.array([1, 2, 3]),
+ False,
+ {},
+ pytest.raises(ValueError),
+ ),
+ ],
+)
+def test_huggingface_classifier_predict(hf_model, data, softmax, model_kwargs, expected):
+ model = PyTorchModel(model=hf_model, softmax=softmax, model_predict_kwargs=model_kwargs)
+    with expected as expected_out:
+        out = model.predict(x=data)
+        assert np.allclose(out, expected_out), "Prediction does not match reference output."
diff --git a/tox.ini b/tox.ini
index b1d3c9b0..ea578d32 100644
--- a/tox.ini
+++ b/tox.ini
@@ -66,3 +66,7 @@ python =
3.9 = py39
3.10 = py310
3.11 = py311
+
+[flake8]
+max-line-length = 127
+max-complexity = 10