From d68ad5b6843235d9799b9b913d15fe6aa755f341 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Wed, 10 Jan 2024 11:25:50 +0000
Subject: [PATCH] format

---
 tests/models/chatglm/test_modeling_chatglm.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/tests/models/chatglm/test_modeling_chatglm.py b/tests/models/chatglm/test_modeling_chatglm.py
index 7ddaa85c887f51..5ab1ccee3a62cc 100644
--- a/tests/models/chatglm/test_modeling_chatglm.py
+++ b/tests/models/chatglm/test_modeling_chatglm.py
@@ -21,8 +21,8 @@
 from transformers import ChatGlmConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
     require_torch,
-    slow,
     require_torch_gpu,
+    slow,
     torch_device,
 )
 
@@ -36,10 +36,10 @@
     import torch
 
     from transformers import (
+        AutoTokenizer,
         ChatGlmForCausalLM,
         ChatGlmForSequenceClassification,
         ChatGlmModel,
-        AutoTokenizer,
     )
 
 
@@ -407,8 +407,8 @@ def test_chat_glm_generation_6b(self):
         model_id = "ybelkada/chatglm3-6b-hf"
 
         model = ChatGlmForCausalLM.from_pretrained(
-            model_id, 
-            torch_dtype=torch.float16, 
+            model_id,
+            torch_dtype=torch.float16,
             low_cpu_mem_usage=True,
         ).to(torch_device)
 
@@ -420,14 +420,13 @@ def test_chat_glm_generation_6b(self):
 
         self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_TEXT)
 
-
     def test_chat_glm_generation_6b_batched(self):
         # TODO: change to THUDM/chatglm3-6b in the future
         model_id = "ybelkada/chatglm3-6b-hf"
 
         model = ChatGlmForCausalLM.from_pretrained(
-            model_id, 
-            torch_dtype=torch.float16, 
+            model_id,
+            torch_dtype=torch.float16,
             low_cpu_mem_usage=True,
         ).to(torch_device)
 
@@ -436,5 +435,5 @@ def test_chat_glm_generation_6b_batched(self):
         EXPECTED_TEXT = "[gMASK]sop 你好,我是人工智能助手。很高兴认识你叫什么"
         inputs = tokenizer(["你好", "你"], return_tensors="pt").to(torch_device)
         output = model.generate(**inputs, max_new_tokens=10)
-
-        self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_TEXT)
\ No newline at end of file
+
+        self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_TEXT)