diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 0ac9b3d82fc7b0..3eae429abb206a 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -29,6 +29,7 @@
     BitsAndBytesConfig,
     pipeline,
 )
+from transformers.models.opt.modeling_opt import OPTAttention
 from transformers.testing_utils import (
     apply_skip_if_not_implemented,
     is_bitsandbytes_available,
@@ -565,7 +566,7 @@ def test_training(self):
 
         # Step 2: add adapters
         for _, module in model.named_modules():
-            if "OPTAttention" in repr(type(module)):
+            if isinstance(module, OPTAttention):
                 module.q_proj = LoRALayer(module.q_proj, rank=16)
                 module.k_proj = LoRALayer(module.k_proj, rank=16)
                 module.v_proj = LoRALayer(module.v_proj, rank=16)
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 5a99ab32e42b8c..567aa956271b70 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -29,6 +29,7 @@
     BitsAndBytesConfig,
     pipeline,
 )
+from transformers.models.opt.modeling_opt import OPTAttention
 from transformers.testing_utils import (
     apply_skip_if_not_implemented,
     is_accelerate_available,
@@ -868,7 +869,7 @@ def test_training(self):
 
         # Step 2: add adapters
         for _, module in model.named_modules():
-            if "OPTAttention" in repr(type(module)):
+            if isinstance(module, OPTAttention):
                 module.q_proj = LoRALayer(module.q_proj, rank=16)
                 module.k_proj = LoRALayer(module.k_proj, rank=16)
                 module.v_proj = LoRALayer(module.v_proj, rank=16)
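
For context outside the test files, below is a minimal, self-contained sketch of the pattern the diff switches to: an `isinstance()` check against the imported `OPTAttention` class instead of matching the string `"OPTAttention"` in `repr(type(module))`. The `LoRALayer` here is a simplified stand-in for the helper defined in the test modules, and the `facebook/opt-350m` checkpoint is only an assumed example; both are illustrative, not the tests' exact code.

```python
import torch.nn as nn
from transformers import AutoModelForCausalLM
from transformers.models.opt.modeling_opt import OPTAttention


class LoRALayer(nn.Module):
    """Simplified stand-in for the LoRALayer used in the bnb tests:
    wraps a linear projection and adds a trainable low-rank update."""

    def __init__(self, module: nn.Module, rank: int):
        super().__init__()
        self.module = module
        self.adapter = nn.Sequential(
            nn.Linear(module.in_features, rank, bias=False),
            nn.Linear(rank, module.out_features, bias=False),
        )
        # Zero-init the up-projection so the adapter starts as a no-op.
        nn.init.zeros_(self.adapter[1].weight)

    def forward(self, x):
        return self.module(x) + self.adapter(x)


model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

# isinstance() is robust to subclassing and to changes in a class's repr,
# unlike substring-matching "OPTAttention" against repr(type(module)).
for _, module in model.named_modules():
    if isinstance(module, OPTAttention):
        module.q_proj = LoRALayer(module.q_proj, rank=16)
        module.k_proj = LoRALayer(module.k_proj, rank=16)
        module.v_proj = LoRALayer(module.v_proj, rank=16)
```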