diff --git a/quantus/helpers/model/pytorch_model.py b/quantus/helpers/model/pytorch_model.py
index ee1f8fdc..cb004c6b 100644
--- a/quantus/helpers/model/pytorch_model.py
+++ b/quantus/helpers/model/pytorch_model.py
@@ -108,7 +108,7 @@ def _get_model_with_linear_top(self) -> torch.nn:
     def _obtain_predictions(self, x, model_predict_kwargs):
         pred = None
 
-        if isinstance(self.model, PreTrainedModel):
+        if PreTrainedModel is not None and isinstance(self.model, PreTrainedModel):
             # BatchEncoding is the default output from Tokenizers which contains
             # necessary keys such as `input_ids` and `attention_mask`.
             # It is also possible to pass a Dict with those keys.