
Commit

🚧
aaarrti committed Apr 28, 2024
1 parent d0a38bc commit 3022a28
Showing 1 changed file with 17 additions and 15 deletions.
quantus/helpers/model/pytorch_model.py (32 changes: 17 additions & 15 deletions)
@@ -20,7 +20,6 @@
     Mapping,
     Optional,
     Tuple,
-    Type,
     Union,
     TypeGuard,
     TypedDict,
@@ -31,8 +30,7 @@
 import numpy.typing as npt
 import torch
 import sys
-from torch import clamp, nn
-from torch._C import device
+from torch import nn
 
 from quantus.helpers import utils
 from quantus.helpers.model.model_interface import ModelInterface
@@ -87,11 +85,8 @@ def _get_last_softmax_layer_index(self) -> Optional[int]:
                 return i
         return None
 
-        last_layer = list(self.model.children())[-1]
-        return isinstance(last_layer, torch.nn.Softmax)
-
     @lru_cache(maxsize=None)
-    def _get_model_with_linear_top(self) -> torch.nn:
+    def _get_model_with_linear_top(self) -> torch.nn.Module:
         """
         In a case model has a softmax module, the last torch.nn.Softmax module in the self.model.modules() list is
         replaced with torch.nn.Identity().
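
Note: the docstring in this hunk describes swapping the trailing torch.nn.Softmax for torch.nn.Identity(). A minimal sketch of that idea, not the repository's actual _get_model_with_linear_top implementation (the helper name strip_final_softmax is invented for illustration):

import copy
from torch import nn

def strip_final_softmax(model: nn.Module) -> nn.Module:
    # Work on a copy so the caller's model keeps its softmax head.
    model = copy.deepcopy(model)
    softmax_paths = [name for name, m in model.named_modules() if isinstance(m, nn.Softmax)]
    if softmax_paths:
        *parents, leaf = softmax_paths[-1].split(".")
        parent = model
        for part in parents:
            parent = getattr(parent, part)
        # Replace only the last registered softmax with a pass-through module.
        setattr(parent, leaf, nn.Identity())
    return model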
@@ -116,7 +111,11 @@ def _get_model_with_linear_top(self) -> torch.nn:
 
     def _obtain_predictions(
         self,
-        x: torch.Tensor | Mapping[str, torch.Tensor | npt.ArrayLike],
+        x: Union[
+            torch.Tensor,
+            npt.ArrayLike,
+            Mapping[str, Union[torch.Tensor, npt.ArrayLike]],
+        ],
         model_predict_kwargs: dict[str, Any],
     ) -> torch.Tensor:
         if safe_isinstance(self.model, "transformers.modeling_utils.PreTrainedModel"):
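
Note: the commit does not state why the PEP 604 unions were rewritten; a likely reason (an assumption, not confirmed here) is interpreter compatibility. An annotation like X | Y in a signature is evaluated when the function is defined and raises TypeError on Python older than 3.10 unless postponed annotation evaluation is enabled, whereas typing.Union works on earlier versions. A small illustration:

from typing import Union

def ok(x: Union[int, str]) -> None:      # accepted on any supported Python 3 version
    ...

# def broken(x: int | str) -> None:      # TypeError at import time on Python 3.8/3.9,
#     ...                                # unless `from __future__ import annotations` is active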
@@ -133,10 +132,13 @@ def _obtain_predictions(
                 return torch.softmax(pred, dim=-1)
             return pred
 
-        if isinstance(self.model, nn.Module):
+        elif isinstance(self.model, nn.Module):
             pred_model = self.get_softmax_arg_model()
-            return pred_model(torch.Tensor(x).to(self.device), **model_predict_kwargs)
-        raise ValueError("Predictions cant be null")
+            return pred_model(
+                torch.Tensor(x).to(self.device), **model_predict_kwargs
+            )
+        else:
+            raise ValueError("Predictions cant be null")
 
     def get_softmax_arg_model(self) -> torch.nn.Module:
         """
@@ -237,11 +239,11 @@ def predict(
 
     def shape_input(
         self,
-        x: np.array,
+        x: np.ndarray,
         shape: Tuple[int, ...],
         channel_first: Optional[bool] = None,
         batched: bool = False,
-    ) -> np.array:
+    ) -> np.ndarray:
         """
         Reshape input into model expected input.
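
Note: np.array is a factory function rather than a type, so annotating with it tells a type checker nothing useful; np.ndarray is the class the returned arrays actually belong to. For illustration:

import numpy as np

a = np.array([1, 2, 3])           # np.array builds an array ...
assert isinstance(a, np.ndarray)  # ... and np.ndarray is its type
assert callable(np.array)         # np.array itself is just a function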
@@ -569,8 +571,8 @@ def safe_isinstance(obj: Any, class_path_str: Iterable[str] | str) -> bool:
 
 
 class BatchEncodingLike(TypedDict):
-    input_ids: torch.Tensor | npt.ArrayLike
-    attention_mask: torch.Tensor | npt.ArrayLike
+    input_ids: Union[torch.Tensor, npt.ArrayLike]
+    attention_mask: Union[torch.Tensor, npt.ArrayLike]
 
 
 def is_batch_encoding_like(x: Any) -> TypeGuard[BatchEncodingLike]:
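
Note: rewriting these annotations with typing.Union keeps the TypedDict usable on Python versions without PEP 604 support. As a hedged sketch of how a TypeGuard-based check such as is_batch_encoding_like can work, with an assumed function body that may differ from the real one in pytorch_model.py:

from collections.abc import Mapping
from typing import Any, TypedDict, Union

import numpy.typing as npt
import torch
from typing_extensions import TypeGuard  # part of typing itself on Python >= 3.10

class BatchEncodingLike(TypedDict):
    input_ids: Union[torch.Tensor, npt.ArrayLike]
    attention_mask: Union[torch.Tensor, npt.ArrayLike]

def is_batch_encoding_like(x: Any) -> TypeGuard[BatchEncodingLike]:
    # Assumed check: a mapping carrying both keys counts as "batch-encoding-like".
    return isinstance(x, Mapping) and "input_ids" in x and "attention_mask" in x

def example(x: Any) -> None:
    if is_batch_encoding_like(x):
        # Static type checkers narrow x to BatchEncodingLike inside this branch.
        ids = x["input_ids"]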
