Commit bf7ab44 (parent 76d9fdc)
Showing 13 changed files with 7,051 additions and 1 deletion.
@@ -0,0 +1,162 @@
"""This module implements the base class for creating evaluation metrics.""" | ||
# This file is part of Quantus. | ||
# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. | ||
# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. | ||
# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>. | ||
# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>. | ||
|
||
import inspect | ||
import re | ||
from abc import abstractmethod | ||
from collections.abc import Sequence | ||
from typing import ( | ||
Any, | ||
Callable, | ||
Dict, | ||
Sequence, | ||
Optional, | ||
Tuple, | ||
Union, | ||
Collection, | ||
List, | ||
) | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from tqdm.auto import tqdm | ||
|
||
from quantus.helpers import asserts | ||
from quantus.helpers import utils | ||
from quantus.helpers import warn | ||
from quantus.helpers.model.model_interface import ModelInterface | ||
from quantus.metrics.base import Metric | ||
from quantus.helpers.enums import ( | ||
ModelType, | ||
DataType, | ||
ScoreDirection, | ||
EvaluationCategory, | ||
) | ||
|
||
|
||
class PerturbationMetric(Metric):
    """
    Implementation of the base PerturbationMetric class.

    Metric categories such as Faithfulness and Robustness share certain characteristics when it comes to perturbations,
    so this class extends Metric with additional attributes for perturbations.

    Attributes:
        - name: The name of the metric.
        - data_applicability: The data types that the metric implementation currently supports.
        - model_applicability: The model types that this metric can work with.
        - score_direction: How to interpret the scores, i.e. whether higher or lower values are considered better.
        - evaluation_category: What property/explanation quality this metric measures.
    """

    name = "PerturbationMetric"
    data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR}
    model_applicability = {ModelType.TORCH, ModelType.TF}
    score_direction = ScoreDirection.HIGHER
    evaluation_category = EvaluationCategory.NONE

    @asserts.attributes_check
    def __init__(
        self,
        abs: bool,
        normalise: bool,
        normalise_func: Callable,
        normalise_func_kwargs: Optional[Dict[str, Any]],
        perturb_func: Callable,
        perturb_func_kwargs: Optional[Dict[str, Any]],
        return_aggregate: bool,
        aggregate_func: Callable,
        default_plot_func: Optional[Callable],
        disable_warnings: bool,
        display_progressbar: bool,
        **kwargs,
    ):
""" | ||
Initialise the PerturbationMetric base class. | ||
Parameters | ||
---------- | ||
Parameters | ||
---------- | ||
abs: boolean | ||
Indicates whether absolute operation is applied on the attribution. | ||
normalise: boolean | ||
Indicates whether normalise operation is applied on the attribution. | ||
normalise_func: callable | ||
Attribution normalisation function applied in case normalise=True. | ||
normalise_func_kwargs: dict | ||
Keyword arguments to be passed to normalise_func on call. | ||
perturb_func: callable | ||
Input perturbation function. | ||
perturb_func_kwargs: dict, optional | ||
Keyword arguments to be passed to perturb_func. | ||
return_aggregate: boolean | ||
Indicates if an aggregated score should be computed over all instances. | ||
aggregate_func: callable | ||
Callable that aggregates the scores given an evaluation call. | ||
default_plot_func: callable | ||
Callable that plots the metrics result. | ||
disable_warnings: boolean | ||
Indicates whether the warnings are printed. | ||
display_progressbar: boolean | ||
Indicates whether a tqdm-progress-bar is printed. | ||
kwargs: optional | ||
Keyword arguments. | ||
""" | ||

        # Initialize super-class with passed parameters.
        super().__init__(
            abs=abs,
            normalise=normalise,
            normalise_func=normalise_func,
            normalise_func_kwargs=normalise_func_kwargs,
            return_aggregate=return_aggregate,
            aggregate_func=aggregate_func,
            default_plot_func=default_plot_func,
            display_progressbar=display_progressbar,
            disable_warnings=disable_warnings,
            **kwargs,
        )

        # Save perturbation metric attributes.
        self.perturb_func = perturb_func

        if perturb_func_kwargs is None:
            perturb_func_kwargs = {}
        self.perturb_func_kwargs = perturb_func_kwargs

    @abstractmethod
    def evaluate_instance(
        self,
        model: ModelInterface,
        x: np.ndarray,
        y: Optional[np.ndarray],
        a: Optional[np.ndarray],
        s: Optional[np.ndarray],
    ) -> Any:
        """
        Evaluate instance gets the model and data for a single instance as input and returns the evaluation result.
        This method needs to be implemented to use __call__().

        Parameters
        ----------
        model: ModelInterface
            A ModelInterface that is subject to explanation.
        x: np.ndarray
            The input to be evaluated on an instance-basis.
        y: np.ndarray
            The output to be evaluated on an instance-basis.
        a: np.ndarray
            The explanation to be evaluated on an instance-basis.
        s: np.ndarray
            The segmentation to be evaluated on an instance-basis.

        Returns
        -------
        Any
        """
        raise NotImplementedError()
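
For context, a minimal sketch of how a concrete metric might subclass PerturbationMetric is shown below. The subclass name, the scoring logic, and the exact keyword arguments passed to perturb_func are illustrative assumptions and are not part of this commit; real Quantus metrics implement considerably more logic (batching, normalisation, masking by indices, etc.).

# Hypothetical subclass sketch (illustrative only, not part of this commit).
class MeanOutputShift(PerturbationMetric):
    def evaluate_instance(
        self,
        model: ModelInterface,
        x: np.ndarray,
        y: Optional[np.ndarray],
        a: Optional[np.ndarray],
        s: Optional[np.ndarray],
    ) -> float:
        # Perturb the input with the configured perturbation function.
        # The exact keyword arguments depend on the chosen perturb_func.
        x_perturbed = self.perturb_func(arr=x, **self.perturb_func_kwargs)
        # Compare model predictions on the original and perturbed inputs.
        y_pred = model.predict(np.expand_dims(x, 0))
        y_pred_perturbed = model.predict(np.expand_dims(x_perturbed, 0))
        # Score the change in output caused by the perturbation.
        return float(np.abs(y_pred - y_pred_perturbed).mean())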