From 8af890a865bf9f744d7d9bd5515558b42224c744 Mon Sep 17 00:00:00 2001 From: Jee Li Date: Tue, 26 Mar 2024 09:09:31 +0800 Subject: [PATCH] Enable more models to inference based on LoRA (#3382) Co-authored-by: Antoni Baum --- csrc/punica/bgmv/bgmv_config.h | 6 + tests/lora/conftest.py | 10 ++ tests/lora/test_baichuan.py | 108 ++++++++++++++++ tests/lora/test_chatglm3.py | 57 +++++++++ tests/lora/test_layers.py | 13 +- tests/lora/test_punica.py | 7 +- vllm/lora/layers.py | 168 +++++++++++++++++++++---- vllm/lora/models.py | 11 +- vllm/model_executor/models/baichuan.py | 52 ++++++-- vllm/model_executor/models/chatglm.py | 15 +++ 10 files changed, 402 insertions(+), 45 deletions(-) create mode 100644 tests/lora/test_baichuan.py create mode 100644 tests/lora/test_chatglm3.py diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index a7415dfc91369..2219d960ae62f 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -16,10 +16,13 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 512) \ f(in_T, out_T, W_T, narrow, 768) \ f(in_T, out_T, W_T, narrow, 1024) \ + f(in_T, out_T, W_T, narrow, 1152) \ f(in_T, out_T, W_T, narrow, 1280) \ + f(in_T, out_T, W_T, narrow, 1536) \ f(in_T, out_T, W_T, narrow, 1728) \ f(in_T, out_T, W_T, narrow, 1792) \ f(in_T, out_T, W_T, narrow, 2048) \ + f(in_T, out_T, W_T, narrow, 2304) \ f(in_T, out_T, W_T, narrow, 2560) \ f(in_T, out_T, W_T, narrow, 2752) \ f(in_T, out_T, W_T, narrow, 2816) \ @@ -27,10 +30,12 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 3456) \ f(in_T, out_T, W_T, narrow, 3584) \ f(in_T, out_T, W_T, narrow, 4096) \ + f(in_T, out_T, W_T, narrow, 4608) \ f(in_T, out_T, W_T, narrow, 5120) \ f(in_T, out_T, W_T, narrow, 5504) \ f(in_T, out_T, W_T, narrow, 5632) \ f(in_T, out_T, W_T, narrow, 6144) \ + f(in_T, out_T, W_T, narrow, 6848) \ f(in_T, out_T, W_T, narrow, 6912) \ f(in_T, out_T, W_T, narrow, 7168) \ f(in_T, out_T, W_T, narrow, 8192) \ @@ -45,6 +50,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 20480) \ f(in_T, out_T, W_T, narrow, 22016) \ f(in_T, out_T, W_T, narrow, 24576) \ + f(in_T, out_T, W_T, narrow, 27392) \ f(in_T, out_T, W_T, narrow, 28672) \ f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32256) \ diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 0705a51ca2cff..acb5fa91e2012 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -134,6 +134,16 @@ def gemma_lora_files(): return snapshot_download(repo_id="wskwon/gemma-7b-test-lora") +@pytest.fixture(scope="session") +def chatglm3_lora_files(): + return snapshot_download(repo_id="jeeejeee/chatglm3-text2sql-spider") + + +@pytest.fixture(scope="session") +def baichuan_lora_files(): + return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider") + + @pytest.fixture def llama_2_7b_engine_extra_embeddings() -> nn.Module: cleanup() diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py new file mode 100644 index 0000000000000..2178266d2e0c8 --- /dev/null +++ b/tests/lora/test_baichuan.py @@ -0,0 +1,108 @@ +import pytest + +import vllm +from vllm.lora.request import LoRARequest + +from .conftest import cleanup + +MODEL_PATH = "baichuan-inc/Baichuan-7B" + +PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a 
task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 + + +def do_sample(llm, lora_path: str, lora_id: int) -> str: + prompts = [ + PROMPT_TEMPLATE.format(query="How many singers do we have?"), + PROMPT_TEMPLATE.format( + query= + "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 + ), + PROMPT_TEMPLATE.format( + query= + "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501 + ), + ] + print(prompts) + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +def test_baichuan_lora(baichuan_lora_files): + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + trust_remote_code=True) + + expected_lora_output = [ + "SELECT count(*) FROM singer", + "SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'", # noqa: E501 + "SELECT name , country , age FROM singer ORDER BY age ASC", + ] + + output1 = do_sample(llm, baichuan_lora_files, lora_id=1) + for i in range(len(expected_lora_output)): + assert output1[i] == expected_lora_output[i] + output2 = do_sample(llm, baichuan_lora_files, lora_id=2) + for i in range(len(expected_lora_output)): + assert output2[i] == expected_lora_output[i] + + +@pytest.mark.skip("Requires multiple GPUs") +def test_llama_tensor_parallel_equality(baichuan_lora_files): + # Cannot use as it will initialize torch.cuda too early... 
+ # if torch.cuda.device_count() < 4: + # pytest.skip(f"Not enough GPUs for tensor parallelism {4}") + + llm_tp1 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=1, + trust_remote_code=True) + output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1) + + del llm_tp1 + cleanup() + + llm_tp2 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=2, + trust_remote_code=True) + output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2) + + del llm_tp2 + cleanup() + + assert output_tp1 == output_tp2 + + llm_tp4 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True) + output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2) + + del llm_tp4 + cleanup() + + assert output_tp1 == output_tp4 \ No newline at end of file diff --git a/tests/lora/test_chatglm3.py b/tests/lora/test_chatglm3.py new file mode 100644 index 0000000000000..bd8cc98ef8ca0 --- /dev/null +++ b/tests/lora/test_chatglm3.py @@ -0,0 +1,57 @@ +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "THUDM/chatglm3-6b" + +PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 + + +def do_sample(llm, lora_path: str, lora_id: int) -> str: + prompts = [ + PROMPT_TEMPLATE.format(query="How many singers do we have?"), + PROMPT_TEMPLATE.format( + query= + "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 + ), + PROMPT_TEMPLATE.format( + query= + "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501 + ), + ] + print(prompts) + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. 
+ generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +def test_chatglm3_lora(chatglm3_lora_files): + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + trust_remote_code=True) + + expected_lora_output = [ + "SELECT count(*) FROM singer", + "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501 + "SELECT name , country , age FROM singer ORDER BY age", + ] + + output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) + for i in range(len(expected_lora_output)): + assert output1[i] == expected_lora_output[i] + output2 = do_sample(llm, chatglm3_lora_files, lora_id=2) + for i in range(len(expected_lora_output)): + assert output2[i] == expected_lora_output[i] diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index caaece883ba21..71ce6f1764832 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -8,12 +8,16 @@ import torch.nn.functional as F from vllm.config import LoRAConfig +# yapf conflicts with isort for this block +# yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, LogitsProcessorWithLoRA, LoRAMapping, MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLora, QKVParallelLinearWithLora, RowParallelLinearWithLoRA, VocabParallelEmbeddingWithLoRA) +# yapf: enable from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights, convert_mapping) from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -93,8 +97,7 @@ def populate_loras( lora_dict: Dict[int, LoRALayerWeights] = dict() # Dictionary that maps the lora ID to the - # corresponding subloras. Only useful when - # repeats > 1. + # corresponding subloras. 
sublora_dict: Dict[int, List[LoRALayerWeights]] = dict() for slot_idx, lora_id in enumerate(id_to_index): @@ -607,7 +610,7 @@ def create_random_linear_parallel_layer(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("repeats", [2, 3]) +@pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("device", CUDA_DEVICES) def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None: @@ -623,6 +626,10 @@ def create_column_parallel_packed_layer(): bias=False) linear.weight.data = torch.rand_like(linear.weight.data) lora_linear = MergedColumnParallelLinearWithLoRA(linear) + elif repeats == 3: + linear = QKVParallelLinear(4096, 64, 32, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = MergedQKVParallelLinearWithLora(linear) else: linear = QKVParallelLinear(4096, 64, 32, bias=False) linear.weight.data = torch.rand_like(linear.weight.data) diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index fd707766c6a30..6e05697f0475f 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -43,9 +43,10 @@ def _lora_ref_impl( H1 = H2 = [ - 128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120, - 5504, 5632, 6144, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, - 22016, 24576, 32000, 32256, 32512, 32768, 33024 + 128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456, + 3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216, + 10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512, + 32768, 33024 ] SEED = [0xabcdabcd987] diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 920523e58ccfc..0505014753951 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1,7 +1,8 @@ # pylint: disable=unused-argument +import inspect import math from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type import torch import torch.nn as nn @@ -114,8 +115,11 @@ def __post_init__(self): class BaseLayerWithLoRA(nn.Module): - def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig, - model_config: PretrainedConfig) -> None: + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: """Initializes lora matrices.""" ... @@ -144,6 +148,13 @@ def set_mapping( """Sets the mapping indices.""" ... 
+ @classmethod + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + raise NotImplementedError + class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): @@ -278,12 +289,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.indices[:self.indices_len[0]], 0, 1.0) return full_output.view_as(full_output_org) + @classmethod + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is VocabParallelEmbedding + class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: ColumnParallelLinear) -> None: super().__init__() self.base_layer = base_layer + self.tp_size = get_tensor_model_parallel_world_size() def create_lora_weights( self, @@ -309,7 +327,7 @@ def create_lora_weights( self.indices: Optional[torch.Tensor] = None self.indices_len: Optional[List[int]] = None - self.output_dim = self.lora_b_stacked.shape[1] + self.output_dim = self.lora_b_stacked.shape[2] def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 @@ -323,7 +341,12 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], ): self.reset_lora(index) - + if self.tp_size > 1: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_dim + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( lora_a.T, non_blocking=True) @@ -383,6 +406,14 @@ def forward(self, input_): def linear_weights(self): return self.base_layer.linear_weights + @classmethod + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is ColumnParallelLinear or ( + type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 1) + class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): """ColumnParallelLinear layer that is composed of 2 sublayers (slices) @@ -485,8 +516,80 @@ def apply_weights(self, x: torch.Tensor, ) return output + @classmethod + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is MergedColumnParallelLinear and len( + packed_modules_list) == 2 + class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): + """ + ColumnParallelLinear layer that is specifically designed for + qkv_proj. Certain models, such as chtglm3 and baichuan-7b, + only contains a single LoRA within their qkv_proj layer. + + During inference with Tensor Parallel, the weights of lora_b + must be accurately partitioned according to the respective ranks. + + Q slice may have different shape than K and V slices (which both have + the same shape). 
+ """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + self.tp_size = get_tensor_model_parallel_world_size() + self.q_proj_total_size = (self.base_layer.total_num_heads * + self.base_layer.head_size) + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * + self.base_layer.head_size) + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + if self.tp_size > 1: + tp_rank = get_tensor_model_parallel_rank() + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + lora_b_q = lora_b[:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size * + self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size * + self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + @classmethod + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is QKVParallelLinear and len( + packed_modules_list) == 1 + + +class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): """ColumnParallelLinear layer that is composed of 3 sublayers (slices) packed together in qkv proj fashion (q_proj + k_proj + v_proj -> qkv_proj). @@ -654,6 +757,13 @@ def apply_weights(self, x: torch.Tensor, ) return output + @classmethod + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is QKVParallelLinear and len( + packed_modules_list) == 3 + class RowParallelLinearWithLoRA(BaseLayerWithLoRA): @@ -780,6 +890,12 @@ def forward(self, input_): def weight(self): return self.base_layer.weight + @classmethod + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is RowParallelLinear + class LogitsProcessorWithLoRA(BaseLayerWithLoRA): @@ -900,7 +1016,7 @@ def _get_logits( hidden_states: torch.Tensor, embedding: torch.Tensor, embedding_bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Optional[torch.Tensor]: # Get the logits for the next tokens. 
logits = torch.matmul(hidden_states, embedding.t()) if embedding_bias is not None: @@ -949,22 +1065,30 @@ def _get_logits( def forward(self, *args, **kwargs): return type(self.base_layer).forward(self, *args, **kwargs) - -def from_layer( - layer: nn.Module, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA: - supported_layer_types = { - VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, - ColumnParallelLinear: ColumnParallelLinearWithLoRA, - QKVParallelLinear: QKVParallelLinearWithLora, - MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, - RowParallelLinear: RowParallelLinearWithLoRA, - } - for src_layer_type, lora_layer_type in supported_layer_types.items(): - if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck - ret = lora_layer_type(layer) + @classmethod + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # Special handling for the LogitsProcessor. + return False + + +_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { + cls + for cls in globals().values() if inspect.isclass(cls) + and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA +} + + +def from_layer(layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig] = None) -> nn.Module: + for lora_cls in _all_lora_classes: + if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list, + model_config): + ret = lora_cls(layer) ret.create_lora_weights(max_loras, lora_config, model_config) return ret return layer diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 97ee1a78b20b7..ddbdd50d0e154 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -413,11 +413,12 @@ def _create_lora_modules(self): for module_name, module in self.model.named_modules(): if not self._match_target_modules(module_name): continue - + parts = module_name.split(".")[-1] + packed_moduled_lst = self.packed_modules_mapping.get(parts, []) new_module = replace_submodule( self.model, module_name, from_layer(module, self.lora_slots, self.lora_config, - self.model.config)) + packed_moduled_lst, self.model.config)) # (yard1): TODO make this more robust if "lm_head" in module_name: logits_processor_module = self.model.get_submodule( @@ -510,8 +511,10 @@ def _match_target_modules(self, module_name: str): def _register_packed_modules(self, module_full_name: str) -> None: parts = module_full_name.split(".") module_name = parts[-1] - replacements = self.packed_modules_mapping.get(module_name) - if not replacements: + replacements = self.packed_modules_mapping.get(module_name, []) + # When replacements is less than or equal to 1, it indicates that this + # module is not a packed module. 
+ if len(replacements) <= 1: return prefix = ".".join(parts[:-1]) self.packed_modules[module_full_name] = [ diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 4ecafa726321d..fa5a27b5a6974 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -26,6 +26,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import LoRAConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -282,11 +283,30 @@ def forward( class BaiChuanBaseForCausalLM(nn.Module): + packed_modules_mapping = { + "W_pack": ["W_pack"], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + "W_pack", + "o_proj", + "gate_up_proj", + "down_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] - def __init__(self, - config, - position_embedding: str, - linear_method: Optional[LinearMethodBase] = None): + def __init__( + self, + config, + position_embedding: str, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ): super().__init__() self.config = config self.linear_method = linear_method @@ -371,19 +391,25 @@ def load_weights(self, class BaichuanForCausalLM(BaiChuanBaseForCausalLM): """Baichuan 13B and Baichuan2 7B/13B.""" - def __init__(self, - config, - linear_method: Optional[LinearMethodBase] = None): + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ): if config.hidden_size == 4096: # baichuan2 7b - super().__init__(config, "ROPE", linear_method) + super().__init__(config, "ROPE", linear_method, lora_config) else: # baichuan 13b, baichuan2 13b - super().__init__(config, "ALIBI", linear_method) + super().__init__(config, "ALIBI", linear_method, lora_config) class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): """Baichuan 7B.""" - def __init__(self, - config, - linear_method: Optional[LinearMethodBase] = None): - super().__init__(config, "ROPE", linear_method) + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ): + super().__init__(config, "ROPE", linear_method, lora_config) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 15905e2250832..4008896e48dd1 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -9,6 +9,7 @@ from torch.nn import LayerNorm from vllm.attention import Attention, AttentionMetadata +from vllm.config import LoRAConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -317,11 +318,25 @@ def forward( class ChatGLMForCausalLM(nn.Module): + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "dense_h_to_4h": ["dense_h_to_4h"] + } + # LoRA specific attributes + supported_lora_modules = [ + "query_key_value", + "dense", + "dense_h_to_4h", + "dense_4h_to_h", + ] + embedding_modules = {} + embedding_padding_modules = [] def __init__( self, config: ChatGLMConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config: ChatGLMConfig = 
config
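
A minimal usage sketch of the capability this patch adds, mirroring tests/lora/test_chatglm3.py above. It is not part of the patch itself; the model and adapter repo IDs are the ones the new tests download, the prompt string is a stand-in for the long Spider text-to-SQL template used in the tests, and max_lora_rank must cover the adapter's rank (64 for these adapters).

# Sketch: offline LoRA inference with one of the newly supported models.
import vllm
from vllm.lora.request import LoRARequest
from huggingface_hub import snapshot_download

# Same adapter the new tests use (see tests/lora/conftest.py).
lora_path = snapshot_download(repo_id="jeeejeee/chatglm3-text2sql-spider")

llm = vllm.LLM(
    "THUDM/chatglm3-6b",
    max_model_len=1024,
    enable_lora=True,      # turns on the LoRA layer wrapping added in vllm/lora
    max_loras=4,           # adapters kept resident at once
    max_lora_rank=64,      # must be >= the rank of the adapter being loaded
    trust_remote_code=True,
)

sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
outputs = llm.generate(
    ["<Spider-style text-to-SQL prompt goes here>"],  # placeholder prompt
    sampling_params,
    # LoRARequest(lora_name, lora_int_id, lora_local_path), as in the tests.
    lora_request=LoRARequest("sql_adapter", 1, lora_path),
)
for output in outputs:
    print(output.outputs[0].text.strip())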
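
The subtlest piece of the patch is how QKVParallelLinearWithLora.set_lora slices a single fused q/k/v LoRA B matrix per tensor-parallel rank. Below is a self-contained sketch of that slicing logic with invented sizes (8 query heads, 4 KV heads, head size 16, TP size 2, no KV-head replication); it is an illustration of the column layout assumption ([Q | K | V] along the output dimension), not code from the patch.

# Standalone sketch of the per-rank slicing done in QKVParallelLinearWithLora.set_lora.
# All sizes are invented for illustration; they do not describe a real model config.
import torch

lora_rank = 8
head_size = 16
total_num_heads = 8        # query heads across all ranks
total_num_kv_heads = 4     # kv heads across all ranks
tp_size = 2
num_heads = total_num_heads // tp_size        # query heads per rank
num_kv_heads = total_num_kv_heads // tp_size  # kv heads per rank
num_kv_head_replicas = 1   # > 1 only when kv heads are replicated across ranks

q_total = total_num_heads * head_size
kv_total = total_num_kv_heads * head_size
q_shard = num_heads * head_size
kv_shard = num_kv_heads * head_size

# Full (unsharded) LoRA B for the fused qkv_proj: columns laid out as [Q | K | V].
lora_b = torch.randn(lora_rank, q_total + 2 * kv_total)

def shard_lora_b(lora_b: torch.Tensor, tp_rank: int) -> torch.Tensor:
    q_id = tp_rank
    kv_id = tp_rank // num_kv_head_replicas
    lora_b_q = lora_b[:, q_shard * q_id:q_shard * (q_id + 1)]
    k_off = q_total
    lora_b_k = lora_b[:, k_off + kv_shard * kv_id:k_off + kv_shard * (kv_id + 1)]
    v_off = k_off + kv_total
    lora_b_v = lora_b[:, v_off + kv_shard * kv_id:v_off + kv_shard * (kv_id + 1)]
    return torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)

for tp_rank in range(tp_size):
    # Each rank keeps (lora_rank, q_shard + 2 * kv_shard) columns of the fused matrix.
    print(tp_rank, shard_lora_b(lora_b, tp_rank).shape)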