
Commit

correct the tests
ArthurZucker committed Dec 7, 2023
1 parent 3bb7687 commit 4f201e8
Showing 1 changed file with 18 additions and 8 deletions.
tests/models/chatglm/test_modeling_chatglm.py (26 changes: 18 additions & 8 deletions)
@@ -24,7 +24,6 @@
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_accelerator,
require_torch_gpu,
slow,
torch_device,
@@ -43,11 +42,11 @@
ChatGlmForCausalLM,
ChatGlmForSequenceClassification,
ChatGlmModel,
CodeLlamaTokenizer,
LlamaTokenizer,
)


# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->ChatGlm
class ChatGlmModelTester:
def __init__(
self,
@@ -275,11 +274,22 @@ def prepare_config_and_inputs_for_common(self):


@require_torch
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest with Llama->ChatGlm
class ChatGlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(ChatGlmModel, ChatGlmForCausalLM, ChatGlmForSequenceClassification) if is_torch_available() else ()
)
all_generative_model_classes = (ChatGlmForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": ChatGlmModel,
"text-classification": ChatGlmForSequenceClassification,
"text-generation": ChatGlmForCausalLM,
"zero-shot": ChatGlmForSequenceClassification,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False

@@ -300,7 +310,7 @@ def test_model_various_embeddings(self):
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)

def test_chatglm_sequence_classification_model(self):
def test_llama_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
@@ -312,7 +322,7 @@ def test_chatglm_sequence_classification_model(self):
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

def test_chatglm_sequence_classification_model_for_single_label(self):
def test_llama_sequence_classification_model_for_single_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
config.problem_type = "single_label_classification"
@@ -325,7 +335,7 @@ def test_chatglm_sequence_classification_model_for_single_label(self):
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

def test_chatglm_sequence_classification_model_for_multi_label(self):
def test_llama_sequence_classification_model_for_multi_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
config.problem_type = "multi_label_classification"
@@ -385,12 +395,12 @@ def test_flash_attn_2_generate_padding_right(self):
Overwriting the common test as the test is flaky on tiny models
"""
model = ChatGlmForCausalLM.from_pretrained(
"meta-chatglm/ChatGlm-2-7b-hf",
"meta-llama/ChatGlm-2-7b-hf",
load_in_4bit=True,
device_map={"": 0},
)

tokenizer = LlamaTokenizer.from_pretrained("meta-chatglm/ChatGlm-2-7b-hf")
tokenizer = ChatGlmTokenizer.from_pretrained("meta-llama/ChatGlm-2-7b-hf")

texts = ["hi", "Hello this is a very long sentence"]

@@ -403,7 +413,7 @@ def test_flash_attn_2_generate_padding_right(self):
output_native = tokenizer.batch_decode(output_native)

model = ChatGlmForCausalLM.from_pretrained(
"meta-chatglm/ChatGlm-2-7b-hf", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
"meta-llama/ChatGlm-2-7b-hf", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
)

output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
