Skip to content

Commit

Permalink
[Misc][LoRA] Clean up the function interface of Punica (#10917)
Browse files Browse the repository at this point in the history
Signed-off-by: Jee Jee Li <[email protected]>
  • Loading branch information
jeejeelee authored Dec 5, 2024
1 parent 39c89e7 commit 571da8f
Show file tree
Hide file tree
Showing 5 changed files with 497 additions and 631 deletions.
42 changes: 32 additions & 10 deletions tests/lora/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,15 +565,18 @@ def _pretest():
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_replicated(dist_init, num_loras, device, stage,
bias_enabled) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
lora_dtype=torch.float16,
bias_enabled=bias_enabled)

def create_random_linear_replicated_layer():

Expand All @@ -585,7 +588,12 @@ def create_random_linear_replicated_layer():
lora_linear = ReplicatedLinearWithLoRA(linear)

lora_linear.create_lora_weights(max_loras, lora_config)

assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
lora_linear.lora_b_stacked) == 1)
if bias_enabled:
assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
else:
assert lora_linear.lora_bias_stacked is None
return linear, lora_linear

for i in range(10):
Expand Down Expand Up @@ -669,8 +677,9 @@ def create_random_linear_replicated_layer():
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
device, stage) -> None:
device, stage, bias_enabled) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
Expand All @@ -679,7 +688,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
fully_sharded_loras=fully_shard,
lora_dtype=torch.float16)
lora_dtype=torch.float16,
bias_enabled=bias_enabled)

def create_random_linear_parallel_layer():
if orientation == "row":
Expand All @@ -700,7 +710,12 @@ def create_random_linear_parallel_layer():
if not fully_shard else
ColumnParallelLinearWithShardedLoRA(linear))
lora_linear.create_lora_weights(max_loras, lora_config)

assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
lora_linear.lora_b_stacked) == 1)
if bias_enabled:
assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
else:
assert lora_linear.lora_bias_stacked is None
return linear, lora_linear

for i in range(10):
Expand Down Expand Up @@ -784,8 +799,9 @@ def create_random_linear_parallel_layer():
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
device, stage) -> None:
device, stage, bias_enabled) -> None:

torch.cuda.set_device(device)
torch.set_default_device(device)
Expand All @@ -794,7 +810,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
fully_sharded_loras=fully_shard,
lora_dtype=torch.float16)
lora_dtype=torch.float16,
bias_enabled=bias_enabled)

def create_column_parallel_packed_layer():
if repeats == 2:
Expand Down Expand Up @@ -832,10 +849,16 @@ class FakeConfig:
num_key_value_heads = 32
num_attention_heads = 32

n_slices = repeats
lora_linear.create_lora_weights(max_loras,
lora_config,
model_config=FakeConfig())

assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
lora_linear.lora_b_stacked) == n_slices)
if bias_enabled:
assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
else:
assert lora_linear.lora_bias_stacked is None
return linear, lora_linear

for i in range(10):
Expand Down Expand Up @@ -911,7 +934,6 @@ class FakeConfig:
512,
lora_config.lora_extra_vocab_size,
)
# lora_linear.set_mapping(*mapping_info)

lora_result = lora_linear(torch.cat(inputs))[0]
expected_result = linear(torch.cat(inputs))[0]
Expand Down
175 changes: 75 additions & 100 deletions vllm/lora/fully_sharded_layers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast

import torch
import torch.nn as nn
Expand Down Expand Up @@ -32,6 +32,44 @@ def dec(*args, **kwargs):
return dec


def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA):
"""
For `ColumnParallelLinearWithLoRA` or classes that inherit from
`ColumnParallelLinearWithLoRA`, they share the same `apply` logic.
"""
assert (layer.n_slices == len(layer.lora_a_stacked) == len(
layer.lora_b_stacked) == len(layer.output_slices))
if layer.lora_bias_stacked is not None:
assert layer.n_slices == len(layer.lora_bias_stacked)

output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape

# Since communication is needed, the buffer is directly initialized as a
# tensor rather than a tuple of tensor.
buffers = torch.zeros(
(layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)

layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0)
buffers = tensor_model_parallel_all_gather(buffers)
layer.punica_wrapper.add_expand(output,
buffers,
layer.lora_b_stacked,
layer.lora_bias_stacked,
layer.output_slices,
offset_start=0,
add_input=True)

output = output.view(*out_orig_shape)
# now have column partitioned and packed output
return output


# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
Expand All @@ -51,34 +89,15 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
# gather operation.
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked.shape[2]
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a

def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros(
(x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device,
)
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
buffer = tensor_model_parallel_all_gather(buffer)
self.punica_wrapper.add_expand(output,
buffer,
self.lora_b_stacked,
self.bias_stacked,
add_input=True)
# now have column partitioned output

output = output.view(*out_orig_shape)
return output
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)

@classmethod
@_fully_sharded_can_replace
Expand All @@ -99,46 +118,6 @@ def can_replace_layer(
)


def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
"""
MergedColumnParallelLinearWithShardedLoRA and
MergedQKVParallelLinearWithShardedLora share the same
LoRa weight application method.
The main difference is the step by shard_size for lora_b which can
vary for MergedQKVParallelLinearWithShardedLora but is constant for
MergedColumnParallelLinearWithShardedLoRA.
"""
# expecting 2 for column parallel and 3 for qkv
n = len(layer.lora_a_stacked)
output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
buffers = torch.zeros(
(n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)
for idx in range(n):
layer.punica_wrapper.add_shrink(buffers[idx], x,
layer.lora_a_stacked[idx], 1.0)

buffers = tensor_model_parallel_all_gather(buffers)
layer.punica_wrapper.add_expand_packed_nslice(
output,
buffers,
layer.lora_b_stacked,
layer.bias_stacked,
1.0,
layer.output_slices,
)

output = output.view(*out_orig_shape)
# now have column partitioned and packed output
return output


class MergedColumnParallelLinearWithShardedLoRA(
MergedColumnParallelLinearWithLoRA):
"""
Expand All @@ -162,8 +141,9 @@ def slice_lora_a(
]
return lora_a

def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)

@classmethod
Expand Down Expand Up @@ -195,31 +175,15 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):

def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked.shape[2]
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a

def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device)
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
buffer = tensor_model_parallel_all_gather(buffer)
self.punica_wrapper.add_expand(output,
buffer,
self.lora_b_stacked,
self.bias_stacked,
add_input=True)
# now have column partitioned output
output = output.view(*out_orig_shape)
return output
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)

@classmethod
@_fully_sharded_can_replace
Expand Down Expand Up @@ -260,8 +224,9 @@ def slice_lora_a(
]
return lora_a

def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)

@classmethod
Expand Down Expand Up @@ -294,7 +259,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
"""

def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
shard_size = self.lora_b_stacked.shape[2]
shard_size = self.lora_b_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
Expand All @@ -303,20 +268,24 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
if bias is None:
return bias
shard_size = self.bias_stacked.shape[2]
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
self.lora_bias_stacked)
shard_size = self.lora_bias_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
bias = bias[start_idx:end_idx]
return bias

def apply(self, x: torch.Tensor) -> torch.Tensor:
def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)

x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros(
(x.shape[0], self.lora_a_stacked.shape[2]),
(self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)
Expand All @@ -330,12 +299,18 @@ def apply(self, x: torch.Tensor) -> torch.Tensor:
# remains is a standard all_reduce. User should be aware though that
# the output is not the same as a normal row_parallel, it should be
# reduced before being used
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size
self.punica_wrapper.add_expand_slice(output, buffer,
self.lora_b_stacked,
self.bias_stacked, start_idx,
shard_size)
# NOTE offset are based on the rank.
shard_size = self.lora_b_stacked[0].shape[2]
offset_start = self.tp_rank * shard_size
self.punica_wrapper.add_expand(
output,
buffer,
self.lora_b_stacked,
self.lora_bias_stacked,
self.output_slices,
offset_start=offset_start,
add_input=True,
)
output = output.view(*out_orig_shape)
return output

Expand Down
Loading

0 comments on commit 571da8f

Please sign in to comment.