From 4f201e8eb3a7375a6e94f4db3fcb1f9cad1a4629 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Thu, 7 Dec 2023 13:05:09 +0100
Subject: [PATCH] correct the tests

---
 tests/models/chatglm/test_modeling_chatglm.py | 28 +++++++++++++++++++---------
 1 file changed, 19 insertions(+), 9 deletions(-)

diff --git a/tests/models/chatglm/test_modeling_chatglm.py b/tests/models/chatglm/test_modeling_chatglm.py
index 0f0c0443f558c6..07914b86b94a6e 100644
--- a/tests/models/chatglm/test_modeling_chatglm.py
+++ b/tests/models/chatglm/test_modeling_chatglm.py
@@ -24,7 +24,6 @@
     require_bitsandbytes,
     require_flash_attn,
     require_torch,
-    require_torch_accelerator,
     require_torch_gpu,
     slow,
     torch_device,
@@ -43,11 +42,11 @@
         ChatGlmForCausalLM,
         ChatGlmForSequenceClassification,
         ChatGlmModel,
-        CodeLlamaTokenizer,
-        LlamaTokenizer,
+        ChatGlmTokenizer,
     )


+# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->ChatGlm
 class ChatGlmModelTester:
     def __init__(
         self,
@@ -275,11 +274,22 @@ def prepare_config_and_inputs_for_common(self):
         return config, inputs_dict


 @require_torch
+# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest with Llama->ChatGlm
 class ChatGlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
         (ChatGlmModel, ChatGlmForCausalLM, ChatGlmForSequenceClassification) if is_torch_available() else ()
     )
     all_generative_model_classes = (ChatGlmForCausalLM,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": ChatGlmModel,
+            "text-classification": ChatGlmForSequenceClassification,
+            "text-generation": ChatGlmForCausalLM,
+            "zero-shot": ChatGlmForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
+    )
     test_headmasking = False
     test_pruning = False
@@ -300,32 +310,32 @@ def test_model_various_embeddings(self):
             config_and_inputs[0].position_embedding_type = type
             self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_llama_sequence_classification_model(self):
+    def test_chatglm_sequence_classification_model(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.num_labels = 3
         input_ids = input_dict["input_ids"]
         attention_mask = input_ids.ne(1).to(torch_device)
         sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
         model = ChatGlmForSequenceClassification(config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

-    def test_llama_sequence_classification_model_for_single_label(self):
+    def test_chatglm_sequence_classification_model_for_single_label(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.num_labels = 3
         config.problem_type = "single_label_classification"
         input_ids = input_dict["input_ids"]
         attention_mask = input_ids.ne(1).to(torch_device)
         sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
         model = ChatGlmForSequenceClassification(config)
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

-    def test_llama_sequence_classification_model_for_multi_label(self):
+    def test_chatglm_sequence_classification_model_for_multi_label(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.num_labels = 3
         config.problem_type = "multi_label_classification"
@@ -385,12 +395,12 @@ def test_flash_attn_2_generate_padding_right(self):
         """
         Overwritting the common test as the test is flaky on tiny models
         """
         model = ChatGlmForCausalLM.from_pretrained(
-            "meta-llama/ChatGlm-2-7b-hf",
+            "meta-chatglm/ChatGlm-2-7b-hf",
             load_in_4bit=True,
             device_map={"": 0},
         )

-        tokenizer = LlamaTokenizer.from_pretrained("meta-llama/ChatGlm-2-7b-hf")
+        tokenizer = ChatGlmTokenizer.from_pretrained("meta-chatglm/ChatGlm-2-7b-hf")

         texts = ["hi", "Hello this is a very long sentence"]
@@ -403,7 +413,7 @@ def test_flash_attn_2_generate_padding_right(self):
         output_native = tokenizer.batch_decode(output_native)

         model = ChatGlmForCausalLM.from_pretrained(
-            "meta-llama/ChatGlm-2-7b-hf", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
+            "meta-chatglm/ChatGlm-2-7b-hf", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
         )

         output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)