Integrate AutoModelForSequenceClassification through PytorchModel #339

Merged
62 changes: 35 additions & 27 deletions pyproject.toml
@@ -52,31 +52,6 @@ dynamic = ["version"]
# $ pip install quantus[tensorflow]
#
[project.optional-dependencies]
-tests = [
-    "captum>=0.6.0",
-    "coverage>=7.2.3",
-    "flake8<=4.0.1; python_version == '3.7'",
-    "flake8>=6.0.0; python_version > '3.7'",
-    "pytest<=7.4.4",
-    "pytest-cov>=4.0.0",
-    "pytest-lazy-fixture>=0.6.3",
-    "pytest-mock==3.10.0",
-    "pytest_xdist",
-    "tf-explain>=0.3.1",
-    "keras<3",
-    "zennit>=0.4.5; python_version >= '3.7'",
-    "tensorflow>=2.5.0; python_version == '3.7'",
-    "tensorflow>=2.12.0; sys_platform != 'darwin' and python_version > '3.7'",
-    "tensorflow_macos>=2.9.0; sys_platform == 'darwin' and python_version > '3.7'",
-    "torch<=1.9.0; python_version == '3.7'",
-    "torch>=1.13.1; sys_platform != 'linux' and python_version > '3.7'",
-    "torch>=1.13.1, <2.0.0; sys_platform == 'linux' and python_version > '3.7' and python_version <= '3.10'",
-    "torch>=2.0.0; sys_platform == 'linux' and python_version >= '3.11'",
-    "torchvision<=0.12.0; python_version == '3.7'",
-    "torchvision>=0.15.1; sys_platform != 'linux' and python_version > '3.7'",
-    "torchvision>=0.14.0, <0.15.1; sys_platform == 'linux' and python_version > '3.7' and python_version <= '3.10'",
-    "torchvision>=0.15.1; sys_platform == 'linux' and python_version >= '3.11'"
-]
torch = [
"torch<=1.11.0; python_version == '3.7'",
"torch>=1.13.1; sys_platform != 'linux' and python_version > '3.7'",
@@ -85,7 +60,9 @@ torch = [
"torchvision<=0.12.0; python_version == '3.7'",
"torchvision>=0.15.1; sys_platform != 'linux' and python_version > '3.7'",
"torchvision>=0.14.0, <0.15.1; sys_platform == 'linux' and python_version > '3.7' and python_version <= '3.10'",
"torchvision>=0.15.1; sys_platform == 'linux' and python_version >= '3.11'"
"torchvision>=0.15.1; sys_platform == 'linux' and python_version >= '3.11'",
"transformers<=4.30.2; python_version == '3.7'",
Review comment (Collaborator; conversation resolved by aaarrti): please remove transformers from the torch = [...] section
"transformers>=4.38.2; python_version > '3.7'",
]
tensorflow = [
"tensorflow>=2.5.0; python_version == '3.7'",
@@ -104,8 +81,39 @@ zennit = [
"quantus[torch]",
"zennit>=0.5.1"
]
+transformers = [
+    "quantus[torch, tensorflow]",
Review comment (Collaborator): quantus[torch] should be enough
"transformers<=4.30.2; python_version == '3.7'",
"transformers>=4.38.2; python_version > '3.7'",
]
+tests = [
+    "captum>=0.6.0",
+    "coverage>=7.2.3",
+    "flake8<=4.0.1; python_version == '3.7'",
+    "flake8>=6.0.0; python_version > '3.7'",
+    "pytest<=7.4.4",
+    "pytest-cov>=4.0.0",
+    "pytest-lazy-fixture>=0.6.3",
+    "pytest-mock==3.10.0",
+    "pytest_xdist",
+    "tf-explain>=0.3.1",
+    "keras<3",
+    "zennit>=0.4.5; python_version >= '3.7'",
+    "tensorflow>=2.5.0; python_version == '3.7'",
+    "tensorflow>=2.12.0; sys_platform != 'darwin' and python_version > '3.7'",
+    "tensorflow_macos>=2.9.0; sys_platform == 'darwin' and python_version > '3.7'",
+    "torch<=1.9.0; python_version == '3.7'",
+    "torch>=1.13.1; sys_platform != 'linux' and python_version > '3.7'",
+    "torch>=1.13.1, <2.0.0; sys_platform == 'linux' and python_version > '3.7' and python_version <= '3.10'",
+    "torch>=2.0.0; sys_platform == 'linux' and python_version >= '3.11'",
+    "torchvision<=0.12.0; python_version == '3.7'",
+    "torchvision>=0.15.1; sys_platform != 'linux' and python_version > '3.7'",
+    "torchvision>=0.14.0, <0.15.1; sys_platform == 'linux' and python_version > '3.7' and python_version <= '3.10'",
+    "torchvision>=0.15.1; sys_platform == 'linux' and python_version >= '3.11'",
+    "quantus[captum,tf_explain,zennit, transformers]"
+]
full = [
"quantus[captum,tf_explain,zennit]"
"quantus[captum,tf_explain,zennit, transformers]"
]

[build-system]
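With the extras above in place, transformers support installs like the other optional backends. A minimal sketch, following the pip-install comment convention at the top of this file (the extra names are the keys defined above; note that at this commit quantus[transformers] pulls in quantus[torch, tensorflow], which the review suggests narrowing to quantus[torch]):

# $ pip install "quantus[transformers]"
# $ pip install "quantus[full]"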
71 changes: 52 additions & 19 deletions quantus/helpers/model/pytorch_model.py
@@ -6,16 +6,19 @@
# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.
 import copy
+import logging
+import warnings
 from contextlib import suppress
 from copy import deepcopy
-from typing import Any, Dict, Optional, Tuple, List, Union, Generator
-import warnings
-import logging
+from functools import lru_cache
+from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union

 import numpy as np
+import numpy.typing as npt
 import torch
 from torch import nn
-from functools import lru_cache
+from transformers import PreTrainedModel
Review comment (Collaborator): This will cause a ModuleNotFoundError when a user tries to import Quantus without transformers installed.
+from transformers.tokenization_utils import BatchEncoding

from quantus.helpers import utils
from quantus.helpers.model.model_interface import ModelInterface
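The review comment above points at a real failure mode: a top-level transformers import makes the package a hard dependency of this module. One conventional guard, sketched here as an illustration rather than as what this commit does, is to make the import optional and fall back to a sentinel:

try:
    from transformers import PreTrainedModel
    from transformers.tokenization_utils import BatchEncoding
except ImportError:  # transformers is an optional extra
    PreTrainedModel = None
    BatchEncoding = None

The isinstance checks further down would then need to tolerate the sentinels, e.g. by testing PreTrainedModel is not None before dispatching on it.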
@@ -97,7 +100,33 @@ def _get_model_with_linear_top(self) -> torch.nn:

return linear_model

-    def get_softmax_arg_model(self) -> torch.nn:
+    def _obtain_predictions(self, x, model_predict_kwargs):
+        pred = None
+        if isinstance(self.model, PreTrainedModel):
+            # BatchEncoding is the default output from Tokenizers which contains
+            # necessary keys such as `input_ids` and `attention_mask`.
+            # It is also possible to pass a Dict with those keys.
+            if not (
+                isinstance(x, BatchEncoding)
+                or (
+                    isinstance(x, dict)
+                    and ("input_ids" in x.keys() and "attention_mask" in x.keys())
+                )
+            ):
+                raise ValueError(
+                    "When using HuggingFace pretrained models, please use Tokenizers output for `x` "
+                    "or make sure you're passing a dict with input_ids and attention_mask as keys"
+                )
+            pred = self.model(**x, **model_predict_kwargs).logits
Review thread on this line:
- Member: Should we also enable softmax here (after accessing the logits), so that we convert the pred to softmax if softmax=True? (We could add it as a class attribute above.)
- abarbosa94 (Collaborator, Author): I implemented it this way, please see if you agree: 9a67c4c
- Collaborator: just return self.model(**x, **model_predict_kwargs).logits
- abarbosa94 (Collaborator, Author, Mar 18, 2024): I did it slightly differently (in 179da1e) to handle and raise on the softmax param properly. Could you see if you agree? Thanks
- Member: I think that looks great :D @abarbosa94
- Collaborator: you're right, it looks a bit different now. Can we also remove pred = None at the top?
+            if self.softmax:
+                return torch.softmax(pred, dim=-1)
+            return pred
+        elif isinstance(self.model, nn.Module):
+            pred_model = self.get_softmax_arg_model()
+            return pred_model(torch.Tensor(x).to(self.device), **model_predict_kwargs)
+        raise ValueError("Predictions can't be null")

+    def get_softmax_arg_model(self) -> torch.nn.Module:
"""
Returns model with last layer adjusted accordingly to softmax argument.
If the original model has softmax activation as the last layer and softmax=false,
Expand Down Expand Up @@ -156,14 +185,20 @@ def get_softmax_arg_model(self) -> torch.nn:

return self.model # Case 5

-    def predict(self, x: np.ndarray, grad: bool = False, **kwargs) -> np.array:
+    def predict(
+        self,
+        x: Union[npt.ArrayLike, Mapping[str, npt.ArrayLike]],
+        grad: bool = False,
+        **kwargs,
+    ) -> np.ndarray:
"""
Predict on the given input.

Parameters
----------
-        x: np.ndarray
-            A given input that the wrapped model predicts on.
+        x: np.ndarray, BatchEncoding
+            A given input that the wrapped model predicts on. This can be either a numpy
+            array or a BatchEncoding (the output of HuggingFace's Tokenizer library).
grad: boolean
Indicates if gradient-calculation is disabled or not.
kwargs: optional
@@ -177,15 +212,13 @@ def predict(self, x: np.ndarray, grad: bool = False, **kwargs) -> np.array:

# Use kwargs of predict call if specified, but don't overwrite object attribute
model_predict_kwargs = {**self.model_predict_kwargs, **kwargs}

if self.model.training:
raise AttributeError("Torch model needs to be in the evaluation mode.")

grad_context = torch.no_grad() if not grad else suppress()

with grad_context:
-            pred_model = self.get_softmax_arg_model()
-            pred = pred_model(torch.Tensor(x).to(self.device), **model_predict_kwargs)
+            pred = self._obtain_predictions(x, model_predict_kwargs)
if pred.requires_grad:
return pred.detach().cpu().numpy()
return pred.cpu().numpy()
@@ -265,9 +298,9 @@ def get_random_layer_generator(
random_layer_model = deepcopy(self.model)

modules = [
-            l
-            for l in random_layer_model.named_modules()
-            if (hasattr(l[1], "reset_parameters"))
+            layer
+            for layer in random_layer_model.named_modules()
+            if (hasattr(layer[1], "reset_parameters"))
]

if order == "top_down":
@@ -350,7 +383,7 @@ def add_mean_shift_to_first_layer(
with torch.no_grad():
new_model = deepcopy(self.model)

-            modules = [l for l in new_model.named_modules()]
+            modules = [layer for layer in new_model.named_modules()]
module = modules[1]

delta = torch.zeros(size=shape).fill_(input_shift)
@@ -377,8 +410,8 @@ def get_hidden_representations(
"""
Compute the model's internal representation of input x.
In practice, this means, executing a forward pass and then, capturing the output of layers (of interest).
-        As the exact definition of "internal model representation" is left out in the original paper (see: https://arxiv.org/pdf/2203.06877.pdf),
-        we make the implementation flexible.
+        As the exact definition of "internal model representation" is left out in the original paper
+        (see: https://arxiv.org/pdf/2203.06877.pdf), we make the implementation flexible.
It is up to the user whether all layers are used, or specific ones should be selected.
The user can therefore select a layer by providing 'layer_names' (exclusive) or 'layer_indices'.

@@ -422,8 +455,8 @@ def is_layer_of_interest(layer_index: int, layer_name: str):
# skip modules defined by subclassing API.
hidden_layers = list( # type: ignore
filter(
-                lambda l: not isinstance(
-                    l[1], (self.model.__class__, torch.nn.Sequential)
+                lambda layer: not isinstance(
+                    layer[1], (self.model.__class__, torch.nn.Sequential)
),
all_layers,
)
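Taken together, the changes let PyTorchModel.predict consume tokenizer output directly. A minimal usage sketch against the API as it stands in this diff (the model name and input mirror the test fixtures below; this block is illustration, not part of the PR):

import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from quantus.helpers.model.pytorch_model import PyTorchModel

# Wrap a HuggingFace sequence classifier; with softmax=True, _obtain_predictions
# applies torch.softmax to the logits before returning.
hf_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
model = PyTorchModel(model=hf_model, softmax=True)

# x may be a BatchEncoding, or a plain dict with input_ids and attention_mask.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
x = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt")

probs = model.predict(x=x)  # np.ndarray of shape (1, num_labels)
assert np.allclose(probs.sum(axis=-1), 1.0)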
55 changes: 40 additions & 15 deletions tests/conftest.py
@@ -1,26 +1,28 @@
-import pytest
 import os
 import pickle
-import torch

 import numpy as np
-from keras.datasets import cifar10
 import pandas as pd
+import pytest
+import torch
+from keras.datasets import cifar10
+from quantus.helpers.model.models import (CifarCNNModel, ConvNet1D,
+                                          ConvNet1DTF, LeNet, LeNetTF,
+                                          TitanicSimpleTFModel,
+                                          TitanicSimpleTorchModel)
 from sklearn.model_selection import train_test_split
-import os

-from quantus.helpers.model.models import (
-    LeNet,
-    LeNetTF,
-    CifarCNNModel,
-    ConvNet1D,
-    ConvNet1DTF,
-    TitanicSimpleTFModel,
-    TitanicSimpleTorchModel,
-)
+from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
+                          set_seed)

CIFAR_IMAGE_SIZE = 32
MNIST_IMAGE_SIZE = 28
BATCH_SIZE = 124
MINI_BATCH_SIZE = 8
RANDOM_SEED = 42

@pytest.fixture(scope='function', autouse=True)
def reset_prngs():
set_seed(42)


@pytest.fixture(scope="session", autouse=True)
@@ -236,7 +238,30 @@ def load_mnist_model_softmax():
return model


@pytest.fixture(scope="session", autouse=False)
def load_hf_distilbert_sequence_classifier():
"""
Load a pretrained DistilBERT sequence-classification model from the HuggingFace hub, cached under /tmp/.
"""
DISTILBERT_BASE = "distilbert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(
DISTILBERT_BASE, cache_dir="/tmp/"
)
return model


@pytest.fixture(scope="session", autouse=False)
def dummy_hf_tokenizer():
Review comment (Collaborator): autouse=False is the default
"""
Tokenize a fixed reference sentence with the DistilBERT tokenizer and return the resulting BatchEncoding.
"""
DISTILBERT_BASE = "distilbert-base-uncased"
REFERENCE_TEXT = "The quick brown fox jumps over the lazy dog"
tokenizer = AutoTokenizer.from_pretrained(DISTILBERT_BASE, cache_dir="/tmp/")
return tokenizer(REFERENCE_TEXT, return_tensors="pt")


@pytest.fixture(scope="session", autouse=True)
def set_env():
"""Set ENV var, so test outputs are not polluted by progress bars and warnings."""
os.environ["PYTEST"] = "1"
os.environ["PYTEST"] = "1"
53 changes: 51 additions & 2 deletions tests/functions/test_pytorch_model.py
@@ -1,13 +1,13 @@
 from collections import OrderedDict
+from contextlib import nullcontext
 from typing import Union

 import numpy as np
 import pytest
 import torch
 from pytest_lazyfixture import lazy_fixture
+from scipy.special import softmax

 from quantus.helpers.model.pytorch_model import PyTorchModel
-from scipy.special import softmax


@pytest.fixture
@@ -242,3 +242,52 @@ def test_add_mean_shift_to_first_layer(load_mnist_model):
a1 = model.model(X)
a2 = new_model(X_shift)
assert torch.all(torch.isclose(a1, a2, atol=1e-04))


@pytest.mark.pytorch_model
@pytest.mark.parametrize(
"hf_model,data,softmax,model_kwargs,expected",
[
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
lazy_fixture("dummy_hf_tokenizer"),
False,
{},
nullcontext(np.array([[0.00424026, -0.03878461]])),
),
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
lazy_fixture("dummy_hf_tokenizer"),
False,
{"labels": torch.tensor([1]), "output_hidden_states": True},
nullcontext(np.array([[0.00424026, -0.03878461]])),
),
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
{'input_ids': torch.tensor([[101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102]]),
 'attention_mask': torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])},
False,
{"labels": torch.tensor([1]), "output_hidden_states": True},
nullcontext(np.array([[0.00424026, -0.03878461]])),
),
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
lazy_fixture("dummy_hf_tokenizer"),
True,
{},
nullcontext(np.array([[0.51075452, 0.4892454]])),
),
(
lazy_fixture("load_hf_distilbert_sequence_classifier"),
np.array([1, 2, 3]),
False,
{},
pytest.raises(ValueError),
),
],
)
def test_huggingface_classifier_predict(hf_model, data, softmax, model_kwargs, expected):
model = PyTorchModel(model=hf_model, softmax=softmax, model_predict_kwargs=model_kwargs)
Review comment (Collaborator): I thought softmax must be a bool, or?

with expected:
out = model.predict(x=data)
assert np.allclose(out, expected.enter_result), "Test failed."
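A note on the pattern above: wrapping expected values in nullcontext(...) lets one parametrized test cover both the success path and the ValueError path, because nullcontext and pytest.raises are both context managers and nullcontext exposes the wrapped value as enter_result. A minimal sketch of the same idiom, using a hypothetical function under test:

import math
from contextlib import nullcontext

import pytest


@pytest.mark.parametrize(
    "x,expected",
    [
        (4.0, nullcontext(2.0)),  # happy path: enter_result carries the expected value
        (-1.0, pytest.raises(ValueError)),  # error path: the context asserts the raise
    ],
)
def test_sqrt(x, expected):
    with expected:
        # math.sqrt(-1.0) raises before enter_result is ever touched.
        assert math.sqrt(x) == expected.enter_result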
4 changes: 4 additions & 0 deletions tox.ini
@@ -66,3 +66,7 @@ python =
3.9 = py39
3.10 = py310
3.11 = py311

[flake8]
max-line-length = 127
max-complexity = 10