diff --git a/foolbox/models/base.py b/foolbox/models/base.py index 6a207bb2..3336aef3 100644 --- a/foolbox/models/base.py +++ b/foolbox/models/base.py @@ -66,7 +66,7 @@ def _preprocess(self, inputs: ep.TensorType) -> ep.TensorType: @property def data_format(self) -> Any: - return getattr(self._model, "data_format", None) + return self._model.data_format # type: ignore ModelType = TypeVar("ModelType", bound="ModelWithPreprocessing") diff --git a/tests/test_attacks_base.py b/tests/test_attacks_base.py index 41bc085a..16e6fead 100644 --- a/tests/test_attacks_base.py +++ b/tests/test_attacks_base.py @@ -45,21 +45,3 @@ class Model: model.data_format = "invalid" # type: ignore with pytest.raises(ValueError): assert fbn.attacks.base.get_channel_axis(model, 3) # type: ignore - - -def test_transform_bounds_wrapper_data_format() -> None: - class Model(fbn.models.Model): - data_format = "channels_first" - - @property - def bounds(self) -> fbn.types.Bounds: - return fbn.types.Bounds(0, 1) - - def __call__(self, inputs: fbn.models.base.T) -> fbn.models.base.T: - return inputs - - model = Model() - wrapped_model = fbn.models.TransformBoundsWrapper(model, (0, 1)) - assert fbn.attacks.base.get_channel_axis( - model, 3 - ) == fbn.attacks.base.get_channel_axis(wrapped_model, 3) diff --git a/tests/test_models.py b/tests/test_models.py index 49f56736..e41a0b91 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -234,3 +234,38 @@ def test_preprocessing(fmodel_and_data: ModelAndData) -> None: fmodel = fbn.models.base.ModelWithPreprocessing( fmodel._model, fmodel.bounds, fmodel.dummy, preprocessing ) + + +def test_transform_bounds_wrapper_data_format() -> None: + class Model(fbn.models.Model): + data_format = "channels_first" + + @property + def bounds(self) -> fbn.types.Bounds: + return fbn.types.Bounds(0, 1) + + def __call__(self, inputs: fbn.models.base.T) -> fbn.models.base.T: + return inputs + + model = Model() + wrapped_model = fbn.models.TransformBoundsWrapper(model, (0, 1)) + assert fbn.attacks.base.get_channel_axis( + model, 3 + ) == fbn.attacks.base.get_channel_axis(wrapped_model, 3) + assert hasattr(wrapped_model, "data_format") + assert not hasattr(wrapped_model, "not_data_format") + + +def test_transform_bounds_wrapper_missing_data_format() -> None: + class Model(fbn.models.Model): + @property + def bounds(self) -> fbn.types.Bounds: + return fbn.types.Bounds(0, 1) + + def __call__(self, inputs: fbn.models.base.T) -> fbn.models.base.T: + return inputs + + model = Model() + wrapped_model = fbn.models.TransformBoundsWrapper(model, (0, 1)) + assert not hasattr(wrapped_model, "data_format") + assert not hasattr(wrapped_model, "not_data_format")