From 731f8908877a4980dc8e42217d6fd0d4d8f01830 Mon Sep 17 00:00:00 2001 From: Kaixiong Happy <63200713+Lacacy@users.noreply.github.com> Date: Tue, 3 Dec 2024 19:00:28 +0800 Subject: [PATCH] Update tensor_parallel.py Resolve the issue of abnormal conversation performance in the Baichuan large model. --- server/text_generation_server/layers/tensor_parallel.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 13f12ef1ec1..33a5f06fc3e 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -66,6 +66,11 @@ def load(config, prefix: str, weights): weight = weights.get_tensor(f"{prefix}.weight") should_gather = False + if config.model_type == "baichuan": + # Resolve the issue of abnormal conversation performance in the Baichuan large model. + # https://github.com/huggingface/text-generation-inference/issues/2780 + weight = F.normalize(weight) + return TensorParallelHead( get_linear(weight, bias=None), process_group=weights.process_group,