-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into loadams/switch-hf-home
- Loading branch information
Showing
22 changed files
with
639 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
from .optimized_linear import OptimizedLinear | ||
from .config import LoRAConfig, QuantizationConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass | ||
class LoRAConfig: | ||
""" | ||
Configuration settings for LoRAOptimizedLinear. | ||
Attributes: | ||
lora_r (int): LoRA attention dimension, also know as the rank. Defaults is 64. | ||
lora_alpha (float): LoRA scaling factor, default is 16. | ||
base_weight_sharding (int): The degree to which the base weights are sharded, | ||
should typically be set to the data-parallel world size to maximize the memory | ||
reduction benefits. Defaults to 1, which means this feature is disabled. | ||
""" | ||
lora_r: int = 64 | ||
lora_alpha: float = 16. | ||
base_weight_sharding: int = 1 | ||
|
||
|
||
@dataclass | ||
class QuantizationConfig: | ||
""" | ||
Configuration settings for quantization for LoRAOptimizedLinear, QuantizedLinear, | ||
and QuantizedParameter | ||
Attributes: | ||
q_bits (int): The number of bits used for quantization. Default is 8. | ||
mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point quantization. Default is 3. | ||
group_size (int): The size of the group used for quantization. Default is 512. | ||
""" | ||
q_bits: int = 8 | ||
mantissa_bits: int = 3 | ||
group_size: int = 512 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
import torch | ||
import math | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from dataclasses import is_dataclass | ||
from deepspeed.accelerator import get_accelerator | ||
import deepspeed.comm as dist | ||
|
||
from .config import LoRAConfig, QuantizationConfig | ||
from .quantization import QuantizedParameter, QuantizedLinear | ||
|
||
|
||
class OptimizedLinear(nn.Module): | ||
""" | ||
Optimized version of nn.Linear that adds features such as: | ||
* LoRA w. base weight sharding | ||
* FP [6,8,12] quantization | ||
Arguments: | ||
input_dim: Required: size of each input sample | ||
output_dim: Required: size of each output sample | ||
bias: Optional: If set to False, the layer will not learn an additive bias. Default: False | ||
lora_config: Optional: LoRAConfig defining lora features and base-weight-sharding degree | ||
quantization_config: Optional: QuantizationConfig defining quantization features | ||
dtype: Optional: parameter dtype, only supports bfloat16 currently | ||
Returns: | ||
Returns a new nn.Module depending on the input config. Either native | ||
torch.nn.Linear, QuantizedLinear, or the full-featured DSOptimizedLinear. | ||
""" | ||
|
||
def __new__(self, | ||
input_dim: int, | ||
output_dim: int, | ||
bias: bool = False, | ||
lora_config: LoRAConfig = None, | ||
quantization_config: QuantizationConfig = None, | ||
dtype=torch.bfloat16): | ||
|
||
if quantization_config is not None and not is_dataclass(quantization_config): | ||
raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}") | ||
if lora_config is not None and not is_dataclass(lora_config): | ||
raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}") | ||
if lora_config is None and quantization_config is None: | ||
# Everything disabled, fall back to normal nn.Linear | ||
self = nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype) | ||
|
||
elif lora_config: | ||
# lora enabled, quantization may or may not be | ||
self = LoRAOptimizedLinear(input_dim=input_dim, | ||
output_dim=output_dim, | ||
bias=bias, | ||
lora_config=lora_config, | ||
quantization_config=quantization_config, | ||
dtype=dtype) | ||
|
||
elif quantization_config: | ||
# only quantization enabled, no lora | ||
self = QuantizedLinear(input_dim=input_dim, | ||
output_dim=output_dim, | ||
bias=bias, | ||
quantization_config=quantization_config, | ||
dtype=dtype) | ||
return self | ||
|
||
|
||
class LoRAOptimizedLinear(nn.Module): | ||
|
||
def __init__(self, | ||
input_dim: int, | ||
output_dim: int, | ||
bias: bool = False, | ||
lora_config: LoRAConfig = None, | ||
quantization_config: QuantizationConfig = None, | ||
device=None, | ||
dtype=torch.bfloat16): | ||
super().__init__() | ||
self.input_dim = input_dim | ||
self.output_dim = output_dim | ||
self.bias = bias | ||
self.lora_config = lora_config | ||
self.quantization_config = quantization_config | ||
device = get_accelerator().current_device() if device is None else device | ||
assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config" | ||
|
||
self.zero_shards = self.lora_config.base_weight_sharding | ||
self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards) | ||
w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype)) | ||
torch.nn.init.xavier_uniform_(w) | ||
|
||
if self.quantization_config is not None: | ||
assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization" | ||
self.base_weight = QuantizedParameter(w, quantization_config=quantization_config) | ||
else: | ||
self.base_weight = w | ||
|
||
self.base_weight.requires_grad = False | ||
|
||
# Use RS lora for now. | ||
self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r) | ||
# Keeping lora weights in bf16 precision for ease of training. | ||
self.lora_weight_1 = nn.Linear(self.input_dim, | ||
self.lora_config.lora_r, | ||
bias=self.bias, | ||
device=device, | ||
dtype=dtype) | ||
self.lora_weight_2 = nn.Linear(self.lora_config.lora_r, | ||
self.output_dim, | ||
bias=self.bias, | ||
device=device, | ||
dtype=dtype) | ||
self.lora_weight_1.weight.requires_grad = True | ||
self.lora_weight_2.weight.requires_grad = True | ||
|
||
def full_weight(self): | ||
# This assumes weights are evenly sharded across gpus. which might not be correct. | ||
# in that case, we should flatten before all_gather. | ||
local_weight = self.base_weight.dequantized() if isinstance(self.base_weight, | ||
QuantizedParameter) else self.base_weight | ||
tensor_list = [ | ||
torch.zeros_like(local_weight, device=local_weight.device, dtype=local_weight.dtype) | ||
for _ in range(self.zero_shards) | ||
] | ||
dist.all_gather(tensor_list, local_weight) | ||
weight = nn.Parameter(torch.cat([tensor for tensor in tensor_list], dim=1)) | ||
return weight | ||
|
||
def linear_without_F_linear(self, input, weight): | ||
output = torch.mm(input.reshape(-1, input.shape[-1]), weight) | ||
output = output.view(*input.shape[:-1], weight.shape[1]) | ||
return output | ||
|
||
def forward(self, input_tensor): | ||
# Gather the sharded base weight | ||
if self.zero_shards > 1: | ||
with torch.no_grad(): | ||
base_weight = self.full_weight() | ||
elif self.quantization_config: | ||
base_weight = self.base_weight.dequantized() | ||
else: | ||
base_weight = self.base_weight | ||
|
||
base_weight_output = F.linear(input_tensor, base_weight) | ||
lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor)) | ||
return base_weight_output + self.lora_scaling_factor * lora_output |
Oops, something went wrong.