Unify metric API #294

Merged: 62 commits merged into main from unfiy-metric-api on Nov 15, 2023.
The diff below shows changes from 54 commits.

Commits (62)
e0e090c  WIP (aaarrti, Aug 12, 2023)
8576b4d  WIP (aaarrti, Aug 12, 2023)
bafbd93  WIP (aaarrti, Aug 12, 2023)
764e415  WIP (aaarrti, Aug 12, 2023)
e8b9f71  WIP (aaarrti, Aug 13, 2023)
ae8e079  WIP (aaarrti, Aug 13, 2023)
c16ff51  WIP (aaarrti, Aug 13, 2023)
61baaa0  WIP (aaarrti, Aug 13, 2023)
121f0ef  minify git diff (aaarrti, Aug 30, 2023)
394d417  minify git diff (aaarrti, Aug 30, 2023)
41b8051  minify git diff (aaarrti, Aug 30, 2023)
4582f14  Merge branch 'main' into unfiy-metric-api (aaarrti, Aug 30, 2023)
2dff480  minify git diff (aaarrti, Aug 30, 2023)
e968896  minify git diff (aaarrti, Aug 30, 2023)
7acfd85  minify git diff (aaarrti, Aug 30, 2023)
d08dc0f  minify git diff (aaarrti, Aug 30, 2023)
6989748  mypy fixes (aaarrti, Aug 30, 2023)
4d1d25e  add missing docstrings (aaarrti, Aug 30, 2023)
58d19ff  * run black (aaarrti, Oct 5, 2023)
07790d5  Revert "* run black" (aaarrti, Oct 5, 2023)
858d6be  * run black (aaarrti, Oct 5, 2023)
c4ab6f8  * (aaarrti, Oct 5, 2023)
6eef978  * (aaarrti, Oct 5, 2023)
9ca2b7d  * code review comments (aaarrti, Oct 5, 2023)
9c4b34e  * code review comments (aaarrti, Oct 5, 2023)
c1eeae3  * code review comments (aaarrti, Oct 5, 2023)
77337fb  * code review comments (aaarrti, Oct 5, 2023)
02850ef  Update base.py (annahedstroem, Oct 9, 2023)
ecda3c5  * code review comments (aaarrti, Oct 10, 2023)
c44e2f1  * code review comments (aaarrti, Oct 11, 2023)
bcb6f8e  * test fixes (aaarrti, Oct 11, 2023)
e263a4e  * test fixes (aaarrti, Oct 11, 2023)
30a5362  Merge remote-tracking branch 'origin/unfiy-metric-api' into unfiy-met… (aaarrti, Oct 11, 2023)
ccaab33  * test fixes (aaarrti, Oct 11, 2023)
c2f368d  * cleanup (aaarrti, Oct 11, 2023)
367ad25  * test fix (aaarrti, Oct 11, 2023)
22c3b0b  * (aaarrti, Oct 11, 2023)
91c6de5  * (aaarrti, Oct 11, 2023)
1cce097  * (aaarrti, Oct 11, 2023)
f16bcea  * (aaarrti, Oct 11, 2023)
9809d63  * mypy fixes (aaarrti, Oct 17, 2023)
e37de15  * add xfail (aaarrti, Oct 17, 2023)
fdfe768  * add xfail (aaarrti, Oct 17, 2023)
adf3482  * add xfail (aaarrti, Oct 17, 2023)
f85310a  * add xfail (aaarrti, Oct 17, 2023)
0433a49  * add xfail (aaarrti, Oct 17, 2023)
e537708  * (aaarrti, Oct 17, 2023)
f0f90d9  * revert typing changes, update docs (aaarrti, Oct 18, 2023)
aea87ad  * cleanup (aaarrti, Oct 18, 2023)
c23a0ee  * typing fix (aaarrti, Oct 19, 2023)
31d4ef9  Merge branch 'main' into unfiy-metric-api (aaarrti, Oct 24, 2023)
36dcd75  * typing fix (aaarrti, Oct 24, 2023)
10a78bc  Merge branch 'main' into unfiy-metric-api (aaarrti, Oct 27, 2023)
d0abbf8  * cleanup (artem-sereda, Oct 27, 2023)
2124abf  * bump up the version (artem-sereda, Oct 27, 2023)
88e1ad0  code review fixes (aaarrti, Nov 3, 2023)
fbefe31  * code review comments (aaarrti, Nov 3, 2023)
27e908b  * code review comments (aaarrti, Nov 3, 2023)
8284075  * code review comments (aaarrti, Nov 3, 2023)
669c8ff  * remove TypedDict (aaarrti, Nov 3, 2023)
0f535f0  * (aaarrti, Nov 3, 2023)
a268d10  mypy fixes (aaarrti, Nov 3, 2023)
3 changes: 3 additions & 0 deletions .github/workflows/codecov.yml
@@ -6,6 +6,9 @@ on:
   pull_request:
   workflow_dispatch:

+concurrency:
+  group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.ref }}
+  cancel-in-progress: true

 jobs:
   run:
3 changes: 3 additions & 0 deletions .github/workflows/lint.yml
@@ -4,6 +4,9 @@ on:
   pull_request:
   workflow_dispatch:

+concurrency:
+  group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.ref }}
+  cancel-in-progress: true

 jobs:
   lint:
3 changes: 3 additions & 0 deletions .github/workflows/python-package.yml
@@ -10,6 +10,9 @@ on:
   pull_request:
   workflow_dispatch:

+concurrency:
+  group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.ref }}
+  cancel-in-progress: true


 jobs:
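Note on the workflow changes: the identical concurrency block added to codecov.yml, lint.yml, and python-package.yml groups runs by workflow name, triggering event, and git ref, and cancel-in-progress: true cancels any still-running run of the same group. In practice, pushing a new commit to a PR branch aborts CI for the superseded commit instead of running both to completion.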
8 changes: 6 additions & 2 deletions docs/Makefile
@@ -12,12 +12,16 @@ BUILDDIR = build
 help:
 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

-.PHONY: help Makefile
+.PHONY: help Makefile clean

 # Catch-all target: route all unknown targets to Sphinx using the new
 # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
 %: Makefile
 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

-rst:
+rst: clean
 	@sphinx-apidoc -o source/docs_api ../quantus --module-first --separate --force
+
+
+clean:
+	rm -rf source/docs_api
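Making rst depend on the new clean target forces sphinx-apidoc to regenerate source/docs_api from an empty directory, so stale autodoc stubs for renamed or removed modules cannot linger in the built docs.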
7 changes: 7 additions & 0 deletions docs/source/docs_api/quantus.helpers.enums.rst
@@ -0,0 +1,7 @@
+quantus.helpers.enums module
+============================
+
+.. automodule:: quantus.helpers.enums
+   :members:
+   :undoc-members:
+   :show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/docs_api/quantus.helpers.perturbation_utils.rst
@@ -0,0 +1,7 @@
+quantus.helpers.perturbation\_utils module
+==========================================
+
+.. automodule:: quantus.helpers.perturbation_utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
2 changes: 2 additions & 0 deletions docs/source/docs_api/quantus.helpers.rst
@@ -22,6 +22,8 @@ Submodules

    quantus.helpers.asserts
    quantus.helpers.constants
+   quantus.helpers.enums
+   quantus.helpers.perturbation_utils
    quantus.helpers.plotting
    quantus.helpers.utils
    quantus.helpers.warn
1 change: 1 addition & 0 deletions mypy.ini
@@ -12,6 +12,7 @@ ignore_missing_imports = True
 no_site_packages = True
 show_none_errors = False
 ignore_errors = False
+plugins = numpy.typing.mypy_plugin

 [mypy-quantus.*]
 disallow_untyped_defs = False
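For context, numpy.typing.mypy_plugin teaches mypy about NumPy specifics such as platform-dependent precision aliases and ufunc signatures. A minimal, illustrative sketch of the annotation style this supports (not part of the PR):

    import numpy as np
    import numpy.typing as npt

    def normalise_batch(a: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
        # Divide each sample by its absolute maximum along the last axis.
        return a / np.abs(a).max(axis=-1, keepdims=True)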
1 change: 1 addition & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ dependencies = [
     "scipy>=1.7.3",
     "tqdm>=4.62.3",
     "matplotlib>=3.3.4",
+    "typing_extensions; python_version <= '3.8'"
 ]

 dynamic = ["version"]
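This environment marker installs the typing_extensions backport only on older interpreters; the PR then guards its imports with the version check sketched below, a pattern that appears verbatim (for Final, Protocol, and final separately) in the files that follow:

    import sys

    if sys.version_info >= (3, 8):
        from typing import Final, Protocol, final
    else:
        from typing_extensions import Final, Protocol, final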
11 changes: 8 additions & 3 deletions quantus/helpers/constants.py
@@ -7,16 +7,21 @@
 # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
 # Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.

-from typing import List, Dict
-
+import sys
+from typing import List, Dict, Mapping, Type
 from quantus.functions.loss_func import *
 from quantus.functions.normalise_func import *
 from quantus.functions.perturb_func import *
 from quantus.functions.similarity_func import *
 from quantus.metrics import *

+if sys.version_info >= (3, 8):
+    from typing import Final
+else:
+    from typing_extensions import Final
+

-AVAILABLE_METRICS = {
+AVAILABLE_METRICS: Final[Mapping[str, Mapping[str, Type[Metric]]]] = {
     "Faithfulness": {
         "Faithfulness Correlation": FaithfulnessCorrelation,
         "Faithfulness Estimate": FaithfulnessEstimate,
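A sketch of what the new annotation rules out at type-check time; illustrative only, not part of the PR:

    from typing import Final, Mapping

    REGISTRY: Final[Mapping[str, int]] = {"example": 1}

    # Both of the following are rejected by mypy:
    # REGISTRY = {}         # error: cannot assign to a Final name
    # REGISTRY["x"] = 2     # error: Mapping defines no __setitem__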
1 change: 0 additions & 1 deletion quantus/helpers/model/pytorch_model.py
@@ -337,7 +337,6 @@ def add_mean_shift_to_first_layer(
         The resulting model with a shifted first layer.
         """
         with torch.no_grad():
-
             new_model = deepcopy(self.model)

             modules = [l for l in new_model.named_modules()]
85 changes: 85 additions & 0 deletions quantus/helpers/perturbation_utils.py
@@ -0,0 +1,85 @@
+from __future__ import annotations
+
+import sys
+from typing import List, TYPE_CHECKING, Callable, Mapping
+import numpy as np
+import functools
+
+if sys.version_info >= (3, 8):
+    from typing import Protocol
+else:
+    from typing_extensions import Protocol
+
+
+if TYPE_CHECKING:
+    from quantus.helpers.model.model_interface import ModelInterface
+
+    class PerturbFunc(Protocol):
+        def __call__(
+            self,
+            arr: np.ndarray,
+            indices: np.ndarray,
+            indexed_axes: np.ndarray,
+            **kwargs,
+        ) -> np.ndarray:
+            ...
+
+
+def make_perturb_func(
+    perturb_func: PerturbFunc, perturb_func_kwargs: Mapping[str, ...] | None, **kwargs
+) -> PerturbFunc | functools.partial:
+    """A utility function to save a few lines of code during perturbation metric initialization."""
+    if perturb_func_kwargs is not None:
+        func_kwargs = kwargs.copy()
+        func_kwargs.update(perturb_func_kwargs)
+    else:
+        func_kwargs = kwargs
+
+    return functools.partial(perturb_func, **func_kwargs)
+
+
+def make_changed_prediction_indices_func(
+    return_nan_when_prediction_changes: bool,
+) -> Callable[[ModelInterface, np.ndarray, np.ndarray], List[int]]:
+    """A utility function to improve static analysis."""
+    return functools.partial(
+        changed_prediction_indices,
+        return_nan_when_prediction_changes=return_nan_when_prediction_changes,
+    )
+
+
+def changed_prediction_indices(
+    model: ModelInterface,
+    x_batch: np.ndarray,
+    x_perturbed: np.ndarray,
+    return_nan_when_prediction_changes: bool,
+) -> List[int]:
+    """
+    Find indices in the batch for which the predicted label has changed after applying perturbation.
+    If the metric's `return_nan_when_prediction_changes` is False, returns an empty list.
+
+    Parameters
+    ----------
+    model:
+        Model wrapper used to run predictions.
+    x_batch:
+        Batch of original inputs provided by user.
+    x_perturbed:
+        Batch of inputs after applying perturbation.
+    return_nan_when_prediction_changes:
+        Instance attribute of perturbation metrics.
+
+    Returns
+    -------
+    changed_idx:
+        List of indices in the batch for which the predicted label has changed after perturbation.
+    """
+    if not return_nan_when_prediction_changes:
+        return []
+
+    labels_before = model.predict(x_batch).argmax(axis=-1)
+    labels_after = model.predict(x_perturbed).argmax(axis=-1)
+    changed_idx = np.reshape(np.argwhere(labels_before != labels_after), -1)
+    return changed_idx.tolist()
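A hypothetical usage sketch of make_perturb_func, assembled from the diffs in this PR rather than taken from it: metric-level options are frozen into a functools.partial once at initialisation, so call sites only pass per-instance arguments, exactly as Completeness does below.

    import numpy as np
    from quantus.functions.perturb_func import baseline_replacement_by_indices
    from quantus.helpers.perturbation_utils import make_perturb_func

    # Bind the metric-level option once; perturb_func_kwargs is None here.
    perturb = make_perturb_func(
        baseline_replacement_by_indices, None, perturb_baseline="black"
    )

    x = np.random.rand(1, 28, 28)
    # Call sites now supply only the per-instance arguments.
    x_perturbed = perturb(
        arr=x,
        indices=np.arange(0, x.size),
        indexed_axes=np.arange(0, x.ndim),
    )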
4 changes: 1 addition & 3 deletions quantus/metrics/__init__.py
@@ -4,10 +4,8 @@
 # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
 # Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.

-from quantus.metrics.base import *
-from quantus.metrics.base_batched import *
-from quantus.metrics.base_perturbed import *
 from quantus.metrics.axiomatic import *
+from quantus.metrics.base import Metric
 from quantus.metrics.complexity import *
 from quantus.metrics.faithfulness import *
 from quantus.metrics.localisation import *
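This is the core of the unification: the separate base modules (base_batched, and base_perturbed with its PerturbationMetric) are dropped from the public surface, leaving Metric, now generic over its score type, as the single base class. The Completeness diff below shows a metric migrating to it.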
85 changes: 60 additions & 25 deletions quantus/metrics/axiomatic/completeness.py
@@ -6,24 +6,32 @@
 # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
 # Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.

+import sys
 from typing import Any, Callable, Dict, List, Optional

 import numpy as np

-from quantus.helpers import warn
 from quantus.helpers import asserts
-from quantus.helpers.model.model_interface import ModelInterface
 from quantus.functions.normalise_func import normalise_by_max
 from quantus.functions.perturb_func import baseline_replacement_by_indices
-from quantus.metrics.base_perturbed import PerturbationMetric
+from quantus.helpers import warn
 from quantus.helpers.enums import (
-    ModelType,
     DataType,
-    ScoreDirection,
     EvaluationCategory,
+    ModelType,
+    ScoreDirection,
 )
+from quantus.helpers.model.model_interface import ModelInterface
+from quantus.helpers.perturbation_utils import make_perturb_func
+from quantus.metrics.base import Metric
+
+if sys.version_info >= (3, 8):
+    from typing import final
+else:
+    from typing_extensions import final


-class Completeness(PerturbationMetric):
+@final
+class Completeness(Metric[List[float]]):
     """
     Implementation of Completeness test by Sundararajan et al., 2017, also referred
     to as Summation to Delta by Shrikumar et al., 2017 and Conservation by
@@ -65,7 +73,7 @@ def __init__(
         normalise_func_kwargs: Optional[Dict[str, Any]] = None,
         output_func: Optional[Callable] = lambda x: x,
         perturb_baseline: str = "black",
-        perturb_func: Callable = None,
+        perturb_func: Callable = baseline_replacement_by_indices,
         perturb_func_kwargs: Optional[Dict[str, Any]] = None,
         return_aggregate: bool = False,
         aggregate_func: Callable = np.mean,
@@ -114,21 +122,11 @@ def __init__(
         """
         if normalise_func is None:
             normalise_func = normalise_by_max
-
-        if perturb_func is None:
-            perturb_func = baseline_replacement_by_indices
-
-        if perturb_func_kwargs is None:
-            perturb_func_kwargs = {}
-        perturb_func_kwargs["perturb_baseline"] = perturb_baseline
-
         super().__init__(
             abs=abs,
             normalise=normalise,
             normalise_func=normalise_func,
             normalise_func_kwargs=normalise_func_kwargs,
-            perturb_func=perturb_func,
-            perturb_func_kwargs=perturb_func_kwargs,
             return_aggregate=return_aggregate,
             aggregate_func=aggregate_func,
             default_plot_func=default_plot_func,
@@ -141,6 +139,9 @@ def __init__(
         if output_func is None:
             output_func = lambda x: x
         self.output_func = output_func
+        self.perturb_func = make_perturb_func(
+            perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline
+        )

         # Asserts and warnings.
         if not self.disable_warnings:
@@ -268,7 +269,6 @@ def evaluate_instance(
         x: np.ndarray,
         y: np.ndarray,
         a: np.ndarray,
-        s: np.ndarray,
     ) -> bool:
         """
         Evaluate instance gets model and data for a single instance as input and returns the evaluation result.
@@ -283,19 +283,14 @@ def evaluate_instance(
             The output to be evaluated on an instance-basis.
         a: np.ndarray
             The explanation to be evaluated on an instance-basis.
-        s: np.ndarray
-            The segmentation to be evaluated on an instance-basis.

         Returns
         -------
         : boolean
             The evaluation results.
         """
         x_baseline = self.perturb_func(
-            arr=x,
-            indices=np.arange(0, x.size),
-            indexed_axes=np.arange(0, x.ndim),
-            **self.perturb_func_kwargs,
+            arr=x, indices=np.arange(0, x.size), indexed_axes=np.arange(0, x.ndim)
         )

         # Predict on input.
@@ -310,3 +305,43 @@ def evaluate_instance(
             return True
         else:
             return False
+
+    def evaluate_batch(
+        self,
+        model: ModelInterface,
+        x_batch: np.ndarray,
+        y_batch: np.ndarray,
+        a_batch: np.ndarray,
+        *args,
+        **kwargs,
+    ) -> List[bool]:
+        """
+        This method performs XAI evaluation on a single batch of explanations.
+        For more information on the specific logic, we refer to the metric's initialisation docstring.
+
+        Parameters
+        ----------
+        model: ModelInterface
+            A ModelInterface that is subject to explanation.
+        x_batch: np.ndarray
+            The input to be evaluated on a batch-basis.
+        y_batch: np.ndarray
+            The output to be evaluated on a batch-basis.
+        a_batch: np.ndarray
+            The explanation to be evaluated on a batch-basis.
+        args:
+            Unused.
+        kwargs:
+            Unused.
+
+        Returns
+        -------
+        scores_batch:
+            List of booleans.
+        """
+        return [
+            self.evaluate_instance(model=model, x=x, y=y, a=a)
+            for x, y, a in zip(x_batch, y_batch, a_batch)
+        ]
Loading