Implement efficient packing without cross-contamination attention #4224

Merged
13 commits merged on Jul 3, 2024
45 changes: 44 additions & 1 deletion src/llamafactory/data/collator.py
@@ -1,4 +1,7 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,6 +22,46 @@
from transformers import DataCollatorForSeq2Seq


def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r"""
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
while handling packed sequences and transforming the mask into lower-triangular form to prevent future peeking.

e.g.
```
[[1, 1, 2, 2, 2, 0]]
```
->
```
[
[
[
[o, x, x, x, x, x],
[o, o, x, x, x, x],
[x, x, o, x, x, x],
[x, x, o, o, x, x],
[x, x, o, o, o, x],
[x, x, x, x, x, x],
]
]
]
```
where `o` equals `0.0` and `x` equals `min_dtype`.
"""
bsz, seq_len = attention_mask_with_indices.size()
min_dtype = torch.finfo(dtype).min
expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
padding_mask = torch.where(expanded_mask != 0, 1, 0)
# Create a block-diagonal mask.
attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
# Use the lower triangular mask to zero out the upper triangular part
attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
# Invert the attention mask.
attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
return attention_mask_4d
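
For reference, a minimal sketch (not part of this diff) that exercises the docstring example above; it assumes the repository is installed so that `llamafactory.data.collator` is importable:

```python
import torch

from llamafactory.data.collator import prepare_4d_attention_mask

# Two packed sequences (indices 1 and 2) followed by one padding token (0).
packed_indices = torch.tensor([[1, 1, 2, 2, 2, 0]])
mask_4d = prepare_4d_attention_mask(packed_indices, torch.float32)

print(mask_4d.shape)  # torch.Size([1, 1, 6, 6])
# Positions equal to 0.0 are attendable; everything else is min_dtype.
print((mask_4d == 0.0).int()[0, 0])
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 0, 0, 0]])
```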


@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
14 changes: 11 additions & 3 deletions src/llamafactory/data/processors/supervised.py
@@ -160,22 +160,30 @@ def preprocess_packed_supervised_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_labels = [], []
for length in knapsack:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
if data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
packed_attention_masks += [1] * len(batch_input_ids[index])

if len(packed_input_ids) < data_args.cutoff_len:
pad_length = data_args.cutoff_len - len(packed_input_ids)
packed_input_ids += [tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if data_args.neat_packing:
packed_attention_masks += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn

if len(packed_input_ids) != data_args.cutoff_len:
raise ValueError("The length of packed example should be identical to the cutoff length.")

model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels)

return model_inputs
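
As an aside (not part of the diff), here is a sketch of the two mask layouts the loop above produces for a knapsack holding sequences of lengths 2 and 3 with `cutoff_len` = 6:

```python
# Illustrative only: the two attention-mask layouts built above.
seq_lengths = [2, 3]
cutoff_len = 6

# neat_packing=True: each sequence receives its own index (starting from 1) and
# padding receives 0, so sequence boundaries survive into the collator.
neat_mask = []
for i, length in enumerate(seq_lengths):
    neat_mask += [i + 1] * length
neat_mask += [0] * (cutoff_len - len(neat_mask))
print(neat_mask)   # [1, 1, 2, 2, 2, 0]

# neat_packing=False: a flat mask of ones, so attention may cross sequence boundaries.
plain_mask = [1] * cutoff_len
print(plain_mask)  # [1, 1, 1, 1, 1, 1]
```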
12 changes: 12 additions & 0 deletions src/llamafactory/extras/constants.py
@@ -78,6 +78,18 @@

STAGES_USE_PAIR_DATA = {"rm", "dpo"}

SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
"falcon",
"gemma",
"gemma2",
"llama",
"mistral",
"phi",
"phi3",
"qwen2",
"starcoder2",
}

SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

V_HEAD_WEIGHTS_NAME = "value_head.bin"
15 changes: 9 additions & 6 deletions src/llamafactory/hparams/data_args.py
@@ -83,19 +83,19 @@ class DataArguments:
)
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={
"help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
},
metadata={"help": "Whether or not to ignore the tokens corresponding to the pad label in loss computation."},
)
val_size: float = field(
default=0.0,
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
)
packing: Optional[bool] = field(
default=None,
metadata={
"help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
},
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
)
neat_packing: bool = field(
default=False,
metadata={"help": "Enable sequence packing without cross-attention."},
)
tool_format: Optional[str] = field(
default=None,
@@ -112,3 +112,6 @@ def __post_init__(self):

if self.streaming and self.max_samples is not None:
raise ValueError("`max_samples` is incompatible with `streaming`.")

if self.neat_packing and not self.packing:
raise ValueError("`neat_packing` requires `packing` is True.")
2 changes: 2 additions & 0 deletions src/llamafactory/hparams/model_args.py
@@ -226,6 +226,7 @@ def __post_init__(self):
self.compute_dtype: Optional["torch.dtype"] = None
self.device_map: Optional[Union[str, Dict[str, Any]]] = None
self.model_max_length: Optional[int] = None
self.block_diag_attn: bool = False

if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
@@ -253,4 +254,5 @@ def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
new_arg.compute_dtype = old_arg.compute_dtype
new_arg.device_map = old_arg.device_map
new_arg.model_max_length = old_arg.model_max_length
new_arg.block_diag_attn = old_arg.block_diag_attn
return new_arg
4 changes: 4 additions & 0 deletions src/llamafactory/hparams/parser.py
@@ -158,6 +158,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")

if finetuning_args.stage != "sft" and data_args.neat_packing:
raise ValueError("`neat_packing` cannot be set as True except SFT.")

if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")

@@ -311,6 +314,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:

model_args.device_map = {"": get_current_device()}
model_args.model_max_length = data_args.cutoff_len
model_args.block_diag_attn = data_args.neat_packing
data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"

# Log on each process the small summary
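
Illustratively (not part of the diff), the defaulting rule for `packing` in this hunk behaves as follows, shown with a hypothetical helper:

```python
# Hypothetical helper mirroring `data_args.packing if ... is not None else stage == "pt"`.
def resolve_packing(packing, stage):
    return packing if packing is not None else stage == "pt"

print(resolve_packing(None, "pt"))   # True: pre-training packs by default
print(resolve_packing(None, "sft"))  # False
print(resolve_packing(True, "sft"))  # True: explicitly enabled, e.g. alongside neat_packing
```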
147 changes: 147 additions & 0 deletions src/llamafactory/model/model_utils/packing.py
@@ -0,0 +1,147 @@
# Copyright 2024 Musab Gultekin and the LlamaFactory team.
#
# This code is based on Musab Gultekin's functionary library.
# https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2023 Musab Gultekin
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import TYPE_CHECKING, Tuple

import torch
import torch.nn.functional as F
import transformers.models

from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.logging import get_logger


if TYPE_CHECKING:
from transformers import PretrainedConfig

from ...hparams import ModelArguments


logger = get_logger(__name__)


def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
r"""
Gets the sequence lengths in the current batch.

e.g.
```
[
[1, 1, 2, 2, 2, 0],
[1, 2, 2, 3, 3, 3],
]
```
->
```
[2, 3, 1, 2, 3]
```
"""
bsz = attention_mask.size(0)
dtype, device = attention_mask.dtype, attention_mask.device
max_num = torch.max(attention_mask)
counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
for i in range(max_num):
counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)

counts = counts.flatten()
seqlens = counts[counts.nonzero().squeeze()]
return seqlens


def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
r"""
Prepares the indices and seqlens for the flash-attn varlen function.

Returns:
indices: indices of non-masked tokens from the flattened sequence.
cu_seqlens: the cumulative sequence lengths in the current batch, always starting from 0.
max_seqlen_in_batch: the largest seqlen in the current batch.

e.g.
```
[
[1, 1, 2, 2, 2, 0],
[1, 2, 2, 3, 3, 3],
]
```
->
```
[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
[0, 2, 5, 6, 8, 11]
3
```
"""
seqlens_in_batch = get_seqlens_in_batch(attention_mask)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return indices, cu_seqlens, max_seqlen_in_batch
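
A minimal sketch (not part of the diff) reproducing the two docstring examples above; it assumes the module is importable as `llamafactory.model.model_utils.packing`:

```python
import torch

from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data

attention_mask = torch.tensor(
    [
        [1, 1, 2, 2, 2, 0],
        [1, 2, 2, 3, 3, 3],
    ]
)

print(get_seqlens_in_batch(attention_mask))  # tensor([2, 3, 1, 2, 3])

indices, cu_seqlens, max_seqlen = get_unpad_data(attention_mask)
print(indices.tolist())     # [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
print(cu_seqlens.tolist())  # [0, 2, 5, 6, 8, 11]
print(max_seqlen)           # 3
```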


def patch_for_block_diag_attn(model_type: str) -> None:
if model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
elif model_type == "gemma2":
transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data
elif model_type == "llama":
transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
elif model_type == "mistral":
transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data
elif model_type == "phi3":
transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data


def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.block_diag_attn:
return

model_type = getattr(config, "model_type", None)
if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
patch_for_block_diag_attn(model_type)
logger.info("Using block diagonal attention for sequence packing without cross-attention.")
else:
raise ValueError("Current model does not support block diagonal attention.")
2 changes: 2 additions & 0 deletions src/llamafactory/model/patcher.py
@@ -29,6 +29,7 @@
from .model_utils.embedding import resize_embedding_layer
from .model_utils.longlora import configure_longlora
from .model_utils.moe import add_z3_leaf_module, configure_moe
from .model_utils.packing import configure_packing
from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model
@@ -73,6 +74,7 @@ def patch_config(
configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_moe(config, model_args, is_trainable)
configure_visual_model(config)
configure_packing(config, model_args, is_trainable)

if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True)
12 changes: 6 additions & 6 deletions src/llamafactory/webui/components/train.py
@@ -95,11 +95,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:

with gr.Row():
with gr.Column():
resize_vocab = gr.Checkbox()
packing = gr.Checkbox()
neat_packing = gr.Checkbox()

with gr.Column():
upcast_layernorm = gr.Checkbox()
resize_vocab = gr.Checkbox()
use_llama_pro = gr.Checkbox()

with gr.Column():
@@ -113,9 +113,9 @@
warmup_steps,
neftune_alpha,
optim,
resize_vocab,
packing,
upcast_layernorm,
neat_packing,
resize_vocab,
use_llama_pro,
shift_attn,
report_to,
@@ -129,9 +129,9 @@
warmup_steps=warmup_steps,
neftune_alpha=neftune_alpha,
optim=optim,
resize_vocab=resize_vocab,
packing=packing,
upcast_layernorm=upcast_layernorm,
neat_packing=neat_packing,
resize_vocab=resize_vocab,
use_llama_pro=use_llama_pro,
shift_attn=shift_attn,
report_to=report_to,