From 9333b586bcc7aada0b326eeb9a7f81d785e2da41 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Thu, 25 Jul 2024 11:04:24 +0200
Subject: [PATCH] fix bt bark test

---
 tests/bettertransformer/testing_utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/bettertransformer/testing_utils.py b/tests/bettertransformer/testing_utils.py
index 6e7ff71ddd9..f1e3ca16361 100644
--- a/tests/bettertransformer/testing_utils.py
+++ b/tests/bettertransformer/testing_utils.py
@@ -27,7 +27,7 @@
 
 MODELS_DICT = {
     "albert": "hf-internal-testing/tiny-random-AlbertModel",
-    "bark": "ylacombe/bark-small",  # TODO: put a smaller model, this one is 1.7GB...
+    "bark": "hf-internal-testing/tiny-random-BarkModel",
     "bart": "hf-internal-testing/tiny-random-bart",
     "bert": "hf-internal-testing/tiny-random-BertModel",
     "bert-generation": "ybelkada/random-tiny-BertGenerationModel",
@@ -359,7 +359,8 @@ def _test_save_load_invertible(self, model_id, keep_original_model=True):
             for name, param in bt_model.named_parameters():
                 self.assertFalse(param.device.type == "meta", f"Parameter {name} is on the meta device.")
 
-            bt_model.save_pretrained(tmpdirname)
+            # saving a normal transformers bark model fails because of shared tensors
+            bt_model.save_pretrained(tmpdirname, safe_serialization=hf_model.config.model_type != "bark")
 
             bt_model_from_load = AutoModel.from_pretrained(tmpdirname)
 
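
Note (not part of the patch): a minimal sketch of the shared-tensor failure the safe_serialization fallback works around. safetensors refuses to serialize a state dict in which several entries alias the same storage, which is the situation hit when saving the bark model here, so the test falls back to plain torch pickling for that model type. The tensor names below are hypothetical and chosen only for illustration.

    # Standalone illustration, not applied by the patch; tensor names are made up.
    import torch
    from safetensors.torch import save_file

    weight = torch.randn(4, 4)
    # Two state-dict entries pointing at the same underlying storage, as with tied weights.
    state_dict = {"encoder.weight": weight, "decoder.weight": weight}

    try:
        save_file(state_dict, "shared.safetensors")
    except RuntimeError as err:
        # safetensors rejects aliased tensors; saving with safe_serialization=False
        # (regular torch pickling) avoids this, which is what the patched test does for bark.
        print(f"safetensors refused shared tensors: {err}")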