diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py
index 13f12ef1ec1..33a5f06fc3e 100644
--- a/server/text_generation_server/layers/tensor_parallel.py
+++ b/server/text_generation_server/layers/tensor_parallel.py
@@ -66,6 +66,11 @@ def load(config, prefix: str, weights):
         weight = weights.get_tensor(f"{prefix}.weight")
         should_gather = False
 
+        if config.model_type == "baichuan":
+            # Fix abnormal conversation output for Baichuan models by normalizing the lm_head weight.
+            # https://github.com/huggingface/text-generation-inference/issues/2780
+            weight = F.normalize(weight)
+
         return TensorParallelHead(
             get_linear(weight, bias=None),
             process_group=weights.process_group,
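
For context, `F.normalize` with its defaults applies row-wise L2 normalization (`p=2`, `dim=1`), so each vocabulary row of the lm_head weight ends up with unit norm, which is what Baichuan's normalized output head expects. Below is a minimal sketch of the effect, assuming `F` refers to `torch.nn.functional` (whether that alias already exists in `tensor_parallel.py` or has to be imported); the tensor shape is illustrative, not the real Baichuan config:

```python
import torch
import torch.nn.functional as F

vocab_size, hidden_size = 1000, 64            # hypothetical dimensions
weight = torch.randn(vocab_size, hidden_size)

# F.normalize defaults to p=2, dim=1: each row (one vocab entry of the
# lm_head) is rescaled to unit L2 norm.
normalized = F.normalize(weight)

print(normalized.norm(dim=1))                 # ~1.0 for every row
```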