Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/fsdp lora #435

Closed
wants to merge 39 commits into from
Closed
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
93cdae8
attempt to wrfsdp wrap lora modules
danbider Jul 6, 2023
0ec0de1
Merge branch 'mosaicml:main' into feature/fsdp-lora
danbider Jul 6, 2023
20ab8b6
fsdp works by iterating over modulers
danbider Jul 6, 2023
57659c5
merged remote
danbider Jul 6, 2023
d6cf053
cleaned up fsdp loop for peft
danbider Jul 7, 2023
a44b641
robust peft import
danbider Jul 7, 2023
d957d55
fsdp known issue deleted
danbider Jul 7, 2023
e5e012d
more info in tutorial about fsdp
danbider Jul 7, 2023
f7b5e70
conditioning on peft installation for cpu tests
danbider Jul 7, 2023
1cf348c
Merge branch 'main' into feature/fsdp-lora
codestar12 Jul 7, 2023
6a1c172
Merge branch 'mosaicml:main' into feature/fsdp-lora
danbider Jul 9, 2023
a3f370c
moved lora model building to ComposerHFCausalLM
danbider Jul 11, 2023
082f71e
formatting
danbider Jul 11, 2023
058951d
updated tutorial to move lora config under model config
danbider Jul 12, 2023
f57c84f
Merge branch 'mosaicml:main' into feature/fsdp-lora
danbider Jul 15, 2023
cc7a8f9
Merge branch 'mosaicml:main' into feature/fsdp-lora
danbider Jul 24, 2023
433ae51
Merge branch 'main' into feature/fsdp-lora
dakinggg Aug 1, 2023
7c68c19
merged upstream main, fixed conflicts
danbider Aug 15, 2023
4118367
added typecheck for peft model
danbider Aug 15, 2023
5db8c74
more pyright fixes
danbider Aug 15, 2023
3a3342f
more typechecking in training script
danbider Aug 15, 2023
622e51d
Merge branch 'main' into feature/fsdp-lora
danbider Aug 16, 2023
9ec0f69
pyright following main merge
danbider Aug 16, 2023
a4439c5
model_config instead of cfg.model
danbider Aug 16, 2023
0a9e542
Update TUTORIAL.md
danbider Aug 17, 2023
9bc0b50
Update llmfoundry/models/hf/hf_fsdp.py
danbider Aug 17, 2023
f2fd418
DDP tutorial edit
danbider Aug 17, 2023
050267f
edit fsdp stuff
danbider Aug 21, 2023
c0f5148
fixed popping
danbider Aug 30, 2023
1c47c23
eliminated bnb dep
danbider Aug 30, 2023
5b905b0
Merge branch 'feature/fsdp-lora' of https://github.com/danbider/llm-f…
danbider Aug 30, 2023
27d186d
Merge branch 'main' into feature/fsdp-lora
josejg Oct 21, 2023
2f59377
Update accelerate for peft
josejg Oct 23, 2023
5bc5240
Simplify LoRA validation logic
josejg Oct 23, 2023
79cf8d6
Proper import checking
josejg Oct 24, 2023
7f72c25
Fix indent
josejg Oct 30, 2023
02d949c
Prevent FDSP wrapping empty embedding LoRA attributes
josejg Oct 30, 2023
b955696
Merge branch 'main' into feature/fsdp-lora
josejg Oct 31, 2023
4a430bd
Fix bad indent
josejg Oct 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions TUTORIAL.md
Original file line number Diff line number Diff line change
Expand Up @@ -330,16 +330,41 @@ The majority of our training setups use `triton`. -->


### Can I finetune using PEFT / LoRA?
- The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `lora` arguments to the config `.yaml`, like so:
- The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `model.lora` arguments to the config `.yaml`, like so:
<!--pytest.mark.skip-->
```yaml
lora:
args:
r: 16
lora_alpha: 32
lora_dropout: 0.05
target_modules: ['Wqkv']
model:
name: hf_causal_lm
pretrained: true
...
lora:
args:
r: 16
lora_alpha: 32
target_modules: ["Wqkv", "out_proj", "up_proj", "down_proj"]
lora_dropout: 0.05
bias: none
task_type: "CAUSAL_LM"
```
You can train LoRA models either using FSDP for further memory savings. in your `.yaml`, specify:
danbider marked this conversation as resolved.
Show resolved Hide resolved
<!--pytest.mark.skip-->
```yaml
fsdp_config:
use_orig_params: true
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we confirm if this is necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will verify this tomorrow AM, good point

sharding_strategy: FULL_SHARD
mixed_precision: PURE
activation_checkpointing: true
activation_checkpointing_reentrant: false
activation_cpu_offload: false
limit_all_gathers: true
```
or default to DDP, as follows:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think to default DDP just leaving out the FSDP section entirely is a bit cleaner?

<!--pytest.mark.skip-->
```yaml
fsdp:
{}
```

- In the current release, these features have Beta support.
- For efficiency, The MPT model concatenates the `Q`, `K`, and `V` matrices in each attention block into a single `Wqkv` matrix that is three times wider. Currently, LoRA supports a low-rank approximation to this `Wqkv` matrix.
- When evaluating with PEFT / LoRA seperated weight, just set `pretrained_lora_id_or_path` in `model`(Find an example [here](scripts/eval/yamls/hf_lora_eval.yml#L19)).
Expand Down
228 changes: 119 additions & 109 deletions llmfoundry/models/hf/hf_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
"""Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""

import os
from typing import Mapping, Union
from typing import Mapping

# required for loading a python model into composer
import transformers
from composer.metrics.nlp import (InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
Expand All @@ -16,7 +15,7 @@
LanguageCrossEntropy, LanguagePerplexity)
from composer.utils import dist
from omegaconf import DictConfig
from transformers import (AutoConfig, AutoModelForCausalLM,
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
PreTrainedTokenizerBase)

from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
Expand All @@ -26,24 +25,40 @@
from llmfoundry.models.utils import init_empty_weights

try:
from peft.peft_model import PeftModel
model_types = PeftModel, transformers.PreTrainedModel
_om_model_config_type = Union[DictConfig, PeftModel,
transformers.PreTrainedModel]
from peft import LoraConfig, PeftModel, get_peft_model
_peft_installed = True
_model_type = PeftModel

except ImportError:
model_types = transformers.PreTrainedModel
_om_model_config_type = Union[DictConfig, transformers.PreTrainedModel]
# raising warnings below only if users try to use PEFT
_peft_installed = False
_model_type = None

__all__ = ['ComposerHFCausalLM']


def print_trainable_parameters(model: AutoModel) -> None:
# Prints the number of trainable parameters in the model.
if _model_type is None:
raise ImportError(
"PEFT not installed. Run pip install -e \".[gpu,peft]\"")
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
print(
f'trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}'
)


class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
"""Configures a :class:`.HuggingFaceModel` around a Causal LM.

Args:
om_model_config (DictConfig | PeftModel | transformers.PreTrainedModel): either an omegaconf dictionary used to configure the model, or an instantiated model object from the peft or transformers library.
if DictConfig, the following keys are required:
om_model_config (DictConfig): an omegaconf dictionary used to configure the model.
the following keys are required:
cfg.pretrained_model_name_or_path (str): The name of or local path to
the HF Causal LM (e.g., `gpt2` to instantiate a GPT2LMHeadModel).
cfg.config_overrides (dict, optional): An optional dictionary of keyword
Expand All @@ -58,10 +73,8 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
"""

def __init__(
self,
om_model_config: _om_model_config_type, # type: ignore
tokenizer: PreTrainedTokenizerBase):
def __init__(self, om_model_config: DictConfig,
tokenizer: PreTrainedTokenizerBase):

# set up training and eval metrics
train_metrics = [
Expand All @@ -78,107 +91,104 @@ def __init__(
InContextLearningMCExpectedCalibrationError()
]

# if we are passed a DictConfig, we need to instantiate the model
if isinstance(om_model_config, DictConfig):

# load the model config
trust_remote_code = om_model_config.get('trust_remote_code', True)
use_auth_token = om_model_config.get('use_auth_token', False)
config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
)
# load the model config
trust_remote_code = om_model_config.get('trust_remote_code', True)
use_auth_token = om_model_config.get('use_auth_token', False)
config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
)

# set config overrides
for k, v in om_model_config.get('config_overrides', {}).items():
if not hasattr(config, k):
raise ValueError(
f'config does not have attribute "{k}" to override ({k}: {v}).'
)

# set config overrides
for k, v in om_model_config.get('config_overrides', {}).items():
if not hasattr(config, k):
raise ValueError(
f'config does not have attribute "{k}" to override ({k}: {v}).'
)

attr = getattr(config, k)
if isinstance(attr, Mapping):
extra_keys = [
_k for _k in v.keys() if _k not in attr.keys()
]
if extra_keys:
raise ValueError(
f'Config dict override got unknown keys. ' +
f'Extra keys: {extra_keys}. ' +
f'Expected (a subset of) keys: {list(attr.keys())}.'
)
getattr(config, k).update(v)
else:
setattr(config, k, v)

# below we set up the device to initialize the model on
init_device = om_model_config.get('init_device', 'cpu')

# Get the device we want to initialize, and use the
# reolved version to initialize the HF model
resolved_init_device = hf_get_init_device(init_device)

# We need to have all non-zero local ranks be not-pretrained
# Rank 0 will still be pretrained, and distribute the weights appropriately
if dist.get_local_rank() != 0 and init_device == 'mixed':
om_model_config.pretrained = False

# initialize the model on the correct device
if resolved_init_device == 'cpu':
if om_model_config.pretrained:
model = AutoModelForCausalLM.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
config=config)
else:
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=trust_remote_code,
)
elif resolved_init_device == 'meta':
if om_model_config.pretrained:
attr = getattr(config, k)
if isinstance(attr, Mapping):
extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
if extra_keys:
raise ValueError(
'Setting cfg.pretrained=True is not supported when init_device="meta".'
)
with init_empty_weights(include_buffers=False):
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=trust_remote_code,
)
f'Config dict override got unknown keys. ' +
f'Extra keys: {extra_keys}. ' +
f'Expected (a subset of) keys: {list(attr.keys())}.')
getattr(config, k).update(v)
else:
setattr(config, k, v)

# below we set up the device to initialize the model on
init_device = om_model_config.get('init_device', 'cpu')

# Get the device we want to initialize, and use the
# reolved version to initialize the HF model
resolved_init_device = hf_get_init_device(init_device)

# We need to have all non-zero local ranks be not-pretrained
# Rank 0 will still be pretrained, and distribute the weights appropriately
if dist.get_local_rank() != 0 and init_device == 'mixed':
om_model_config.pretrained = False

# initialize the model on the correct device
if resolved_init_device == 'cpu':
if om_model_config.pretrained:
model = AutoModelForCausalLM.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
config=config)
else:
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=trust_remote_code,
)
elif resolved_init_device == 'meta':
if om_model_config.pretrained:
raise ValueError(
f'init_device="{init_device}" must be either "cpu" or "meta".'
'Setting cfg.pretrained=True is not supported when init_device="meta".'
)
with init_empty_weights(include_buffers=False):
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=trust_remote_code,
)

signal_file_path = '.local_rank0_completed_autoresume'
if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_download')

# Avoid the collective call until the local rank zero has finished trying to download the checkpoint
# so that we don't timeout for large downloads. This syncs all processes on the node
with dist.local_rank_zero_download_and_wait(signal_file_path):
# Then, wait to ensure every node has finished downloading the checkpoint
dist.barrier()

if dist.get_local_rank() == 0:
os.remove(signal_file_path)

z_loss = om_model_config.get('z_loss', 0.0)

# elif the model is either a PeftModel or a PreTrainedModel
elif isinstance(om_model_config, model_types):
model = om_model_config
init_device = 'cpu'
z_loss = 0.0

# else, unsupported type
else:
raise ValueError(
f'om_model_config must be either a DictConfig, PeftModel, or PreTrainedModel, but got {type(om_model_config)}'
)
f'init_device="{init_device}" must be either "cpu" or "meta".')

signal_file_path = '.local_rank0_completed_autoresume'
if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_download')

# Avoid the collective call until the local rank zero has finished trying to download the checkpoint
# so that we don't timeout for large downloads. This syncs all processes on the node
with dist.local_rank_zero_download_and_wait(signal_file_path):
# Then, wait to ensure every node has finished downloading the checkpoint
dist.barrier()

if dist.get_local_rank() == 0:
os.remove(signal_file_path)

z_loss = om_model_config.get('z_loss', 0.0)

# if om_model_config includes lora and peft is installed, add lora modules
lora_cfg = om_model_config.get('lora', None)
if lora_cfg is not None:
if _peft_installed == True:
print('Building Lora config...')
lora_cfg = LoraConfig(**lora_cfg.args)
print('Lora config built.')
print('Adding Lora modules...')
model = get_peft_model(model, lora_cfg)
print('Lora modules added.')
print_trainable_parameters(model)
else:
raise ImportError(
"cfg.model.lora is given but PEFT not installed. Run pip install -e \".[gpu,peft]\""
)

attention_patch_type = om_model_config.get('attention_patch_type', None)
if attention_patch_type is not None:
Expand Down
19 changes: 19 additions & 0 deletions llmfoundry/models/hf/hf_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@
# which is MIT licensed

import functools
import warnings
from typing import Any, Iterable, List, Optional

import torch
from transformers import PreTrainedModel
from transformers.models.opt.modeling_opt import OPTDecoder

try:
from peft import LoraModel
lora_model_type = LoraModel
except ImportError:
lora_model_type = None
warnings.warn('peft is not installed, LoraModel will not be available')


# helper functions
def rhasattr(obj: Any, attr: str):
Expand Down Expand Up @@ -190,6 +198,17 @@ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel,
tied_embeddings._fsdp_wrap = False # type: ignore
lm_head._fsdp_wrap = False # type: ignore

# applying ._fsdp_wrap = True for the LoRA modules
# this is needed because added LoRA modules have requires_grad=True,
# while the rest of the modules have requires_grad=False
if lora_model_type is not None: # peft is installed
if isinstance(model.base_model,
lora_model_type): # we have builR a LoraModel
danbider marked this conversation as resolved.
Show resolved Hide resolved
if model_block is not None: # for pyright
for name, module in model_block.named_modules():
if 'lora' in name: # peft adds modules named with lora
module._fsdp_wrap = True

# FSDP Wrap and Activation Checkpoint every model block
model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
model.activation_checkpointing_fn = lambda module: isinstance(
Expand Down
Loading