
Commit

merge fixes
annahedstroem committed Mar 25, 2024
1 parent 76d9fdc commit bf7ab44
Showing 13 changed files with 7,051 additions and 1 deletion.
2 changes: 1 addition & 1 deletion quantus/functions/explanation_func.py
@@ -671,7 +671,7 @@ def f_reduce_axes(a):
        inputs = inputs.cpu()

        inputs_numpy = inputs.detach().numpy()

        for i in range(len(explanation)):
            explanation[i] = torch.Tensor(
                np.clip(scipy.ndimage.sobel(inputs_numpy[i]), 0, 1)
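The one-line change in this hunk wraps the Sobel-filter attribution in np.clip so the values are bounded to [0, 1]. A minimal standalone sketch of that step (the random input below is an illustrative placeholder, not from the repository):

import numpy as np
import scipy.ndimage
import torch

# Placeholder batch: one 3-channel 8x8 image.
inputs_numpy = np.random.rand(1, 3, 8, 8).astype(np.float32)

explanation = [None] * len(inputs_numpy)
for i in range(len(explanation)):
    # Sobel edge filter as the attribution, clipped into [0, 1] as in the diff.
    explanation[i] = torch.Tensor(
        np.clip(scipy.ndimage.sobel(inputs_numpy[i]), 0, 1)
    )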
32 changes: 32 additions & 0 deletions quantus/helpers/model/pytorch_model.py
@@ -362,6 +362,38 @@ def sample(
        )
        return model_copy

    def perturb_layer_weights(self, layer_idx: int, noise: float):
        """
        Perturb the weights of a specific layer in the wrapped PyTorch model.

        Parameters
        ----------
        layer_idx : int
            The index of the layer to perturb.
        noise : float
            The standard deviation of the Gaussian noise to add to the weights.

        Returns
        -------
        torch.nn.Module
            A deep copy of the wrapped model whose selected layer carries the
            perturbed weights; the original model is left unchanged.
        """
        original_parameters = self.state_dict()
        model_copy = deepcopy(self.model)
        model_copy.load_state_dict(original_parameters)

        # Get the specific layer.
        layer = list(model_copy.modules())[layer_idx]

        # Generate Gaussian noise with standard deviation `noise`.
        noise_tensor = torch.randn_like(layer.weight) * noise

        # Add the noise to the copied layer's weights in place.
        layer.weight.data.add_(noise_tensor)

        return model_copy

    def add_mean_shift_to_first_layer(
        self,
        input_shift: Union[int, float],
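A minimal usage sketch for the new perturb_layer_weights method, assuming the PyTorchModel wrapper accepts the raw torch module as its model argument (the toy network and layer index below are illustrative, not from the commit):

import torch
from quantus.helpers.model.pytorch_model import PyTorchModel

# Illustrative toy network; any torch.nn.Module would do.
net = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
)
wrapped = PyTorchModel(model=net)

# list(net.modules()) puts the Sequential itself at index 0,
# so layer_idx=1 targets the first Linear layer.
perturbed_net = wrapped.perturb_layer_weights(layer_idx=1, noise=0.1)

# The original network is untouched; only the returned copy is noisy.
assert not torch.equal(
    list(net.modules())[1].weight,
    list(perturbed_net.modules())[1].weight,
)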
162 changes: 162 additions & 0 deletions quantus/metrics/base_perturbed.py
@@ -0,0 +1,162 @@
"""This module implements the base class for creating evaluation metrics."""
# This file is part of Quantus.
# Quantus is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see <https://www.gnu.org/licenses/>.
# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.

import inspect
import re
from abc import abstractmethod
from collections.abc import Sequence
from typing import (
    Any,
    Callable,
    Dict,
    Optional,
    Tuple,
    Union,
    Collection,
    List,
)
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm

from quantus.helpers import asserts
from quantus.helpers import utils
from quantus.helpers import warn
from quantus.helpers.model.model_interface import ModelInterface
from quantus.metrics.base import Metric
from quantus.helpers.enums import (
    ModelType,
    DataType,
    ScoreDirection,
    EvaluationCategory,
)


class PerturbationMetric(Metric):
    """
    Base implementation of the PerturbationMetric class.

    Metric categories such as Faithfulness and Robustness share certain characteristics when it comes to
    perturbations. This metric class therefore extends Metric with additional attributes for perturbations.

    Attributes:
        - name: The name of the metric.
        - data_applicability: The data types that the metric implementation currently supports.
        - model_applicability: The model types that this metric can work with.
        - score_direction: How to interpret the scores, whether higher/ lower values are considered better.
        - evaluation_category: What property/ explanation quality that this metric measures.
    """

name = "PerturbationMetric"
data_applicability = {DataType.IMAGE, DataType.TIMESERIES, DataType.TABULAR}
model_applicability = {ModelType.TORCH, ModelType.TF}
score_direction = ScoreDirection.HIGHER
evaluation_category = EvaluationCategory.NONE

    @asserts.attributes_check
    def __init__(
        self,
        abs: bool,
        normalise: bool,
        normalise_func: Callable,
        normalise_func_kwargs: Optional[Dict[str, Any]],
        perturb_func: Callable,
        perturb_func_kwargs: Optional[Dict[str, Any]],
        return_aggregate: bool,
        aggregate_func: Callable,
        default_plot_func: Optional[Callable],
        disable_warnings: bool,
        display_progressbar: bool,
        **kwargs,
    ):
"""
Initialise the PerturbationMetric base class.
Parameters
----------
Parameters
----------
abs: boolean
Indicates whether absolute operation is applied on the attribution.
normalise: boolean
Indicates whether normalise operation is applied on the attribution.
normalise_func: callable
Attribution normalisation function applied in case normalise=True.
normalise_func_kwargs: dict
Keyword arguments to be passed to normalise_func on call.
perturb_func: callable
Input perturbation function.
perturb_func_kwargs: dict, optional
Keyword arguments to be passed to perturb_func.
return_aggregate: boolean
Indicates if an aggregated score should be computed over all instances.
aggregate_func: callable
Callable that aggregates the scores given an evaluation call.
default_plot_func: callable
Callable that plots the metrics result.
disable_warnings: boolean
Indicates whether the warnings are printed.
display_progressbar: boolean
Indicates whether a tqdm-progress-bar is printed.
kwargs: optional
Keyword arguments.
"""

        # Initialize super-class with passed parameters.
        super().__init__(
            abs=abs,
            normalise=normalise,
            normalise_func=normalise_func,
            normalise_func_kwargs=normalise_func_kwargs,
            return_aggregate=return_aggregate,
            aggregate_func=aggregate_func,
            default_plot_func=default_plot_func,
            display_progressbar=display_progressbar,
            disable_warnings=disable_warnings,
            **kwargs,
        )

        # Save perturbation metric attributes.
        self.perturb_func = perturb_func

        if perturb_func_kwargs is None:
            perturb_func_kwargs = {}
        self.perturb_func_kwargs = perturb_func_kwargs

    @abstractmethod
    def evaluate_instance(
        self,
        model: ModelInterface,
        x: np.ndarray,
        y: Optional[np.ndarray],
        a: Optional[np.ndarray],
        s: Optional[np.ndarray],
    ) -> Any:
        """
        Evaluate instance gets model and data for a single instance as input and returns the evaluation result.

        This method needs to be implemented to use __call__().

        Parameters
        ----------
        model: ModelInterface
            A ModelInterface that is subject to explanation.
        x: np.ndarray
            The input to be evaluated on an instance-basis.
        y: np.ndarray
            The output to be evaluated on an instance-basis.
        a: np.ndarray
            The explanation to be evaluated on an instance-basis.
        s: np.ndarray
            The segmentation to be evaluated on an instance-basis.

        Returns
        -------
        Any
        """
        raise NotImplementedError()
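A minimal sketch of how a concrete metric might subclass the new PerturbationMetric base. The TopPixelDrop class, the zero_out perturbation, and the model.predict call pattern are illustrative assumptions, not part of this commit:

import numpy as np

from quantus.metrics.base_perturbed import PerturbationMetric


def zero_out(arr, indices, **kwargs):
    """Illustrative perturbation: zero out the selected flat indices."""
    arr = arr.copy()
    arr[indices] = 0.0
    return arr


class TopPixelDrop(PerturbationMetric):
    """Toy metric: prediction drop after zeroing the highest-attributed entry."""

    def evaluate_instance(self, model, x, y, a, s):
        # Index of the strongest attribution, in flattened coordinates.
        idx = int(np.argmax(np.abs(a)))
        x_perturbed = self.perturb_func(
            x.flatten(), np.array([idx]), **self.perturb_func_kwargs
        ).reshape(x.shape)
        # Drop in the predicted score for class y after the perturbation.
        y_orig = model.predict(np.expand_dims(x, 0))[0, y]
        y_pert = model.predict(np.expand_dims(x_perturbed, 0))[0, y]
        return float(y_orig - y_pert)


metric = TopPixelDrop(
    abs=True,
    normalise=False,
    normalise_func=None,
    normalise_func_kwargs=None,
    perturb_func=zero_out,
    perturb_func_kwargs={},
    return_aggregate=False,
    aggregate_func=np.mean,
    default_plot_func=None,
    disable_warnings=True,
    display_progressbar=False,
)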
