diff --git a/tests/test_align/test_template/test_llm.py b/tests/test_align/test_template/test_llm.py index 86182c6fd..b997ec4fb 100644 --- a/tests/test_align/test_template/test_llm.py +++ b/tests/test_align/test_template/test_llm.py @@ -96,10 +96,11 @@ def test_codegeex4(): def test_telechat(): pt_engine = PtEngine('TeleAI/TeleChat2-7B', torch_dtype=torch.float16) - messages = [{'role': 'user', 'content': '你是一个乐于助人的智能助手,请使用用户提问的语言进行有帮助的问答'}, {'role': 'user', 'content': '你好'}] - _infer_model(pt_engine, messages=messages) + messages = [{'role': 'system', 'content': '你是一个乐于助人的智能助手,请使用用户提问的语言进行有帮助的问答'}, {'role': 'user', 'content': '你好'}] + response = _infer_model(pt_engine, messages=messages) pt_engine.default_template.template_backend = 'jinja' - _infer_model(pt_engine, messages=messages) + response2 = _infer_model(pt_engine, messages=messages) + assert response == response2 def test_glm_edge():