Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
younesbelkada committed Jan 10, 2024
1 parent 45b5ee7 commit d68ad5b
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions tests/models/chatglm/test_modeling_chatglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from transformers import ChatGlmConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_torch,
slow,
require_torch_gpu,
slow,
torch_device,
)

Expand All @@ -36,10 +36,10 @@
import torch

from transformers import (
AutoTokenizer,
ChatGlmForCausalLM,
ChatGlmForSequenceClassification,
ChatGlmModel,
AutoTokenizer,
)


Expand Down Expand Up @@ -407,8 +407,8 @@ def test_chat_glm_generation_6b(self):
model_id = "ybelkada/chatglm3-6b-hf"

model = ChatGlmForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
model_id,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)

Expand All @@ -420,14 +420,13 @@ def test_chat_glm_generation_6b(self):

self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_TEXT)


def test_chat_glm_generation_6b_batched(self):
# TODO: change to THUDM/chatglm3-6b in the future
model_id = "ybelkada/chatglm3-6b-hf"

model = ChatGlmForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
model_id,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)

Expand All @@ -436,5 +435,5 @@ def test_chat_glm_generation_6b_batched(self):
EXPECTED_TEXT = "[gMASK]sop 你好,我是人工智能助手。很高兴认识你叫什么"
inputs = tokenizer(["你好", "你"], return_tensors="pt").to(torch_device)
output = model.generate(**inputs, max_new_tokens=10)
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_TEXT)

self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_TEXT)

0 comments on commit d68ad5b

Please sign in to comment.