From 5b4d0275bb45c2fb55a1bf7562bdafcd822a6203 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 16 Dec 2024 15:46:31 -0500 Subject: [PATCH 01/25] Basic evaluate CLI command / codepath (#2188) * basic evaluate CLI command / codepath * tests for evaluate CLI command * fixes and cleanup * review comments; slightly DRYing up things --------- Co-authored-by: Dan Saunders --- outputs | 1 + src/axolotl/train.py | 2 +- src/axolotl/utils/trainer.py | 11 +++++++++++ 3 files changed, 13 insertions(+), 1 deletion(-) create mode 120000 outputs diff --git a/outputs b/outputs new file mode 120000 index 0000000000..be3c4a823f --- /dev/null +++ b/outputs @@ -0,0 +1 @@ +/workspace/data/axolotl-artifacts \ No newline at end of file diff --git a/src/axolotl/train.py b/src/axolotl/train.py index a74ecc2ec3..848831b665 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -27,7 +27,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.models import load_model, load_processor, load_tokenizer -from axolotl.utils.trainer import setup_trainer +from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer try: from optimum.bettertransformer import BetterTransformer diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 32e54c9a86..fd09b3eb67 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -512,6 +512,17 @@ def prepare_opinionated_env(cfg): os.environ["TOKENIZERS_PARALLELISM"] = "false" +def set_pytorch_cuda_alloc_conf(): + """Set up CUDA allocation config if using PyTorch >= 2.2""" + torch_version = torch.__version__.split(".") + torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) + if torch_major == 2 and torch_minor >= 2: + if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None: + os.environ[ + "PYTORCH_CUDA_ALLOC_CONF" + ] = "expandable_segments:True,roundup_power2_divisions:16" + + def setup_trainer( cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps ): From 1e49a88005dd617fe738237bc94fcf38debb5e18 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 11 Dec 2024 14:51:53 -0500 Subject: [PATCH 02/25] initial diff attn layer / model conversion implementation (support for llama arch) --- .../integrations/diff_transformer/__init__.py | 0 .../integrations/diff_transformer/convert.py | 48 ++++ .../diff_transformer/multihead_diffattn.py | 230 ++++++++++++++++++ 3 files changed, 278 insertions(+) create mode 100644 src/axolotl/integrations/diff_transformer/__init__.py create mode 100644 src/axolotl/integrations/diff_transformer/convert.py create mode 100644 src/axolotl/integrations/diff_transformer/multihead_diffattn.py diff --git a/src/axolotl/integrations/diff_transformer/__init__.py b/src/axolotl/integrations/diff_transformer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py new file mode 100644 index 0000000000..93a8df073b --- /dev/null +++ b/src/axolotl/integrations/diff_transformer/convert.py @@ -0,0 +1,48 @@ +"""Differential attention conversion logic for a huggingface pre-trained model.""" +import logging + +from transformers import PreTrainedModel +from transformers.models.llama.modeling_llama import LlamaAttention +from transformers.models.mistral.modeling_mistral import MistralAttention +from transformers.models.mixtral.modeling_mixtral import MixtralAttention + +from 
.multihead_diffattn import DifferentialAttention + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel: + """Convert a pre-trained model's attention layers to differential attention""" + attention_patterns = (LlamaAttention, MistralAttention, MixtralAttention) + layer_idx = 0 + + # Get model dtype from existing weights + model_dtype = next(model.parameters()).dtype + + def convert_module(module): + nonlocal layer_idx + + # Iterate through module children, convert any attn layers to diff attn + for name, child in module.named_children(): + if isinstance(child, attention_patterns): + layer_type = type(child).__name__ + logger.info(f"Converting attention layer {layer_idx}: {layer_type}") + + # Create new diff attn layer + new_attention = DifferentialAttention( + config=module.config if hasattr(module, "config") else model.config, + layer_idx=layer_idx, + dtype=model_dtype, + ) + + # Replace the layer + setattr(module, name, new_attention) + layer_idx += 1 + elif len(list(child.children())) > 0: + convert_module(child) + + convert_module(model) + logger.info(f"Converted {layer_idx} attention layers to differential attention") + + return model diff --git a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py new file mode 100644 index 0000000000..00462475e2 --- /dev/null +++ b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py @@ -0,0 +1,230 @@ +"""Re-implemention of differential attention.""" +# pylint: disable=invalid-name +import logging +import math +from typing import Any, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers.cache_utils import Cache +from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm +from transformers.models.llama.modeling_llama import ( + LlamaRotaryEmbedding, + apply_rotary_pos_emb, +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + batch_size, n_kv_heads, slen, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, None, :, :] + .expand(batch_size, n_kv_heads, n_rep, slen, head_dim) + .reshape(batch_size, n_kv_heads * n_rep, slen, head_dim) + ) + + +def lambda_init_fn(depth): + return 0.8 - 0.6 * math.exp(-0.3 * depth) + + +class DifferentialAttention(nn.Module): + """Differential Attention implementation as described in the Diff Transformer paper. + + This implements a modified attention mechanism that computes the difference between + two attention patterns, scaled by learned lambda parameters. The mechanism helps + reduce noise in the attention weights for irrelevant / less relevant tokens. + + Key components: + - Split head dimension for differential computation + - Learned lambda parameters that control attention scaling + - Sublayer normalization on the attention output + + See: + - https://arxiv.org/abs/2410.05258 + - https://github.com/microsoft/unilm/tree/master/Diff-Transformer + + Args: + config: Model configuration object containing hidden size, number of heads etc. 
+ layer_idx: Index of this layer in the transformer stack + dtype: Data type for the layer parameters + is_causal: Whether to use causal (masked) attention + """ + + def __init__( + self, + config: Any, + layer_idx: int, + dtype: torch.dtype, + is_causal: bool = True, + ): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.is_causal = is_causal + # self.head_dim = self.hidden_size // self.num_heads + self.head_dim = self.hidden_size // self.num_heads // 2 + self.num_key_value_heads = getattr( + config, "num_key_value_heads", self.num_heads + ) + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.scaling = (self.head_dim) ** -0.5 + + # Initialize projections with correct dtype + self.q_proj = nn.Linear( + self.hidden_size, self.hidden_size, bias=False, dtype=dtype + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.hidden_size // self.num_key_value_groups, + bias=False, + dtype=dtype, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.hidden_size // self.num_key_value_groups, + bias=False, + dtype=dtype, + ) + + self.o_proj = nn.Linear( + self.hidden_size, self.hidden_size, bias=False, dtype=dtype + ) + + # Initialize differential attention parameters + self.lambda_init = lambda_init_fn(self.layer_idx) + self.lambda_q1 = nn.Parameter( + torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1) + ) + self.lambda_k1 = nn.Parameter( + torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1) + ) + self.lambda_q2 = nn.Parameter( + torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1) + ) + self.lambda_k2 = nn.Parameter( + torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1) + ) + + self.subln = RMSNorm(2 * self.head_dim, eps=1e-5) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, # pylint: disable=unused-argument + ) -> tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[tuple[torch.Tensor, torch.Tensor]], + ]: + bsz, tgt_len, _ = hidden_states.size() + + # Project queries, keys and values + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Reshape for attention + q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(bsz, tgt_len, 2 * self.num_key_value_heads, self.head_dim).transpose( + 1, 2 + ) + v = v.view(bsz, tgt_len, self.num_key_value_heads, 2 * self.head_dim).transpose( + 1, 2 + ) + + # Generate or unpack cos, sin for rotary positional embeddings + if position_embeddings is None: + if position_ids is None: + position_ids = torch.arange( + 0, tgt_len, dtype=torch.long, device=q.device + ) + cos, sin = self.rotary_emb(q, position_ids) + else: + cos, sin = position_embeddings + + # Need to adjust cos, sin to match the halved head_dim + cos = cos[..., : self.head_dim] + sin = sin[..., : self.head_dim] + q, k = apply_rotary_pos_emb(q, k, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + 
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + + # Update cache and get back concatenated states + k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) + + # Prepare for attention + k = repeat_kv(k, self.num_key_value_groups) + v = repeat_kv(v, self.num_key_value_groups) + + # Scale query + q = q * self.scaling + + # Calculate attention scores + attn_weights = torch.matmul(q, k.transpose(-1, -2)) + + # Apply causal mask + if attention_mask is None: + attention_mask = torch.triu( + torch.full((tgt_len, tgt_len), float("-inf"), device=q.device), + diagonal=1, + ).type_as(attn_weights) + attn_weights = torch.nan_to_num(attn_weights) + attn_weights = attn_weights + attention_mask + + # Apply softmax + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as( + attn_weights + ) + + # Calculate lambda + lambda_1 = torch.exp( + torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float() + ).type_as(q) + lambda_2 = torch.exp( + torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() + ).type_as(q) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + + # Apply differential attention + attn_weights = attn_weights.view( + bsz, self.num_heads, 2, -1, attn_weights.size(-1) + ) + attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1] + + # Apply attention to values + attn = torch.matmul(attn_weights, v) + + # Apply sublayer norm + attn = self.subln(attn).type_as(attn) + attn = attn * (1 - self.lambda_init) + + # Reshape and project output + attn = attn.transpose(1, 2).reshape( + bsz, tgt_len, self.num_heads * 2 * self.head_dim + ) + attn = self.o_proj(attn) + + # Return in exact format expected by LLaMA + if output_attentions: + return attn, attn_weights, past_key_value + return attn, None, past_key_value From 8c4ff51b3f677759069531a0c41f20a3914dedb9 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 11 Dec 2024 21:35:47 -0500 Subject: [PATCH 03/25] Adding script for doing conversion; fixes and updates --- scripts/convert_diff_transformer.py | 127 +++++++ .../integrations/diff_transformer/convert.py | 78 +++- .../diff_transformer/multihead_diffattn.py | 334 +++++++++++++----- 3 files changed, 454 insertions(+), 85 deletions(-) create mode 100644 scripts/convert_diff_transformer.py diff --git a/scripts/convert_diff_transformer.py b/scripts/convert_diff_transformer.py new file mode 100644 index 0000000000..651c0a229c --- /dev/null +++ b/scripts/convert_diff_transformer.py @@ -0,0 +1,127 @@ +"""Test conversion of transformers model attention to differential attention.""" +from typing import Tuple + +import torch +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizer, +) + +from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention + + +def setup_model( + model_name: str, device: str = "cuda" +) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: + """Load model and tokenizer""" + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.float16, + device_map=device, + ) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + return model, tokenizer + + +def convert_model_attention(model: AutoModelForCausalLM) -> AutoModelForCausalLM: + """Convert model to use differential attention""" + try: + model = convert_to_diff_attention(model) + return model + except Exception as exception: + print(f"Error during model conversion: {exception}") + 
raise + + +def test_inference(model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> None: + """Run test inference""" + # Test prompts + test_prompts = [ + "The quick brown fox", + ] + + for prompt in test_prompts: + try: + # Tokenize + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to(model.device) for k, v in inputs.items()} + + # Generate + from time import time + + start = time() + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=20, + num_beams=1, + do_sample=False, + # temperature=0.7, + pad_token_id=tokenizer.pad_token_id, + use_cache=False, + # use_cache=True, + ) + elasped = time() - start + print(f"generation time: {elasped}s") + + # Decode + print(outputs) + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + print(f"\nPrompt: {prompt}") + print(f"Generated: {generated_text}\n") + + except Exception as exception: + print(f"Error during inference: {str(exception)}") + raise + + +def save_converted_model(model: AutoModelForCausalLM, output_dir: str) -> None: + """Save the converted model""" + print(f"Saving converted model to {output_dir}") + model.save_pretrained(output_dir) + + +def main(): + # Configuration + model_name = "HuggingFaceTB/SmolLM2-135M" + # model_name = "openlm-research/open_llama_3b_v2" + output_dir = "./converted_model" + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + + try: + # Load model and tokenizer + model, tokenizer = setup_model(model_name, device) + + # Print original model info + print("Original model config:") + print(f"\t- Hidden size: {model.config.hidden_size}") + print(f"\t- Number of attention heads: {model.config.num_attention_heads}") + + # Test the original model + test_inference(model, tokenizer) + + # Convert to differential attention + model = convert_to_diff_attention(model) + model.to(model.device) + print("Model conversion completed") + + # Test the converted model + test_inference(model, tokenizer) + + # Save converted model + save_converted_model(model, output_dir) + + except Exception as exception: + print(f"Error during test: {str(exception)}") + raise + + +if __name__ == "__main__": + main() diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py index 93a8df073b..36d97037b2 100644 --- a/src/axolotl/integrations/diff_transformer/convert.py +++ b/src/axolotl/integrations/diff_transformer/convert.py @@ -1,20 +1,81 @@ """Differential attention conversion logic for a huggingface pre-trained model.""" import logging +from typing import Union +import torch +from torch import nn from transformers import PreTrainedModel -from transformers.models.llama.modeling_llama import LlamaAttention +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaSdpaAttention from transformers.models.mistral.modeling_mistral import MistralAttention from transformers.models.mixtral.modeling_mixtral import MixtralAttention -from .multihead_diffattn import DifferentialAttention +from .multihead_diffattn import ( + LlamaDifferentialAttention, + LlamaDifferentialSdpaAttention, +) logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +def copy_attention_weights( + old_attn: Union[LlamaAttention, LlamaSdpaAttention], + new_attn: Union[LlamaDifferentialAttention, LlamaDifferentialSdpaAttention], + zero_init: bool = True, +) -> None: + """ + Copy weights from old attention layer to new differential attention layer. 
+ Copies old weights to Q1 and K1, zeros out Q2 and K2 for exact equivalence + to original attention mechanism. + """ + # For Q projection (Q1 and Q2) + new_q = torch.empty_like(new_attn.q_proj.weight.data) + new_q[: new_attn.hidden_size] = old_attn.q_proj.weight.data # Q1 + if zero_init: + new_q[new_attn.hidden_size :] = 0 + else: + nn.init.normal_(new_q[new_attn.hidden_size :], mean=0, std=0.1) + new_attn.q_proj.weight.data.copy_(new_q) + + # For K projection (K1 and K2) + old_kv_size = old_attn.k_proj.weight.data.size(0) # Size for 3 heads + new_k = torch.empty_like(new_attn.k_proj.weight.data) + new_k[:old_kv_size] = old_attn.k_proj.weight.data # K1 + if zero_init: + new_k[old_kv_size:] = 0 + else: + nn.init.normal_(new_k[old_kv_size:], mean=0, std=0.1) + new_attn.k_proj.weight.data.copy_(new_k) + + # For V projection (single V) + new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data) + + # Output projection remains the same + new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data) + + # Zero out lambda parameters for exact equivalence + if zero_init: + nn.init.zeros_(new_attn.lambda_q1) + nn.init.zeros_(new_attn.lambda_k1) + nn.init.zeros_(new_attn.lambda_q2) + nn.init.zeros_(new_attn.lambda_k2) + new_attn.lambda_init = 0.0 + + logger.debug( + "Copied positive attention weights from %s to %s", + type(old_attn).__name__, + type(new_attn).__name__, + ) + + def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel: """Convert a pre-trained model's attention layers to differential attention""" - attention_patterns = (LlamaAttention, MistralAttention, MixtralAttention) + attention_patterns = ( + LlamaAttention, + LlamaSdpaAttention, + MistralAttention, + MixtralAttention, + ) layer_idx = 0 # Get model dtype from existing weights @@ -29,13 +90,22 @@ def convert_module(module): layer_type = type(child).__name__ logger.info(f"Converting attention layer {layer_idx}: {layer_type}") + # Choose appropriate differential attention class + if isinstance(child, LlamaSdpaAttention): + attention_class = LlamaDifferentialSdpaAttention + else: + attention_class = LlamaDifferentialAttention + # Create new diff attn layer - new_attention = DifferentialAttention( + new_attention = attention_class( config=module.config if hasattr(module, "config") else model.config, layer_idx=layer_idx, dtype=model_dtype, ) + # Copy weights from old attention to new attention + copy_attention_weights(child, new_attention) + # Replace the layer setattr(module, name, new_attention) layer_idx += 1 diff --git a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py index 00462475e2..6d3bc75898 100644 --- a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py +++ b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py @@ -6,9 +6,9 @@ import torch import torch.nn.functional as F +import transformers from torch import nn from transformers.cache_utils import Cache -from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm from transformers.models.llama.modeling_llama import ( LlamaRotaryEmbedding, apply_rotary_pos_emb, @@ -34,7 +34,7 @@ def lambda_init_fn(depth): return 0.8 - 0.6 * math.exp(-0.3 * depth) -class DifferentialAttention(nn.Module): +class LlamaDifferentialAttention(nn.Module): """Differential Attention implementation as described in the Diff Transformer paper. 
This implements a modified attention mechanism that computes the difference between @@ -54,7 +54,6 @@ class DifferentialAttention(nn.Module): config: Model configuration object containing hidden size, number of heads etc. layer_idx: Index of this layer in the transformer stack dtype: Data type for the layer parameters - is_causal: Whether to use causal (masked) attention """ def __init__( @@ -62,43 +61,52 @@ def __init__( config: Any, layer_idx: int, dtype: torch.dtype, - is_causal: bool = True, ): super().__init__() - self.config = config - self.layer_idx = layer_idx + # Base model dimensions + self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.is_causal = is_causal - # self.head_dim = self.hidden_size // self.num_heads - self.head_dim = self.hidden_size // self.num_heads // 2 - self.num_key_value_heads = getattr( - config, "num_key_value_heads", self.num_heads - ) - self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.base_num_heads = config.num_attention_heads + self.base_num_kv_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // config.num_attention_heads + + self.scaling = self.head_dim**-0.5 + self.layer_idx = layer_idx self.max_position_embeddings = config.max_position_embeddings - self.scaling = (self.head_dim) ** -0.5 + self.rope_theta = config.rope_theta + self.is_causal = True - # Initialize projections with correct dtype + # For Q1 and Q2 self.q_proj = nn.Linear( - self.hidden_size, self.hidden_size, bias=False, dtype=dtype + self.hidden_size, + self.hidden_size * 2, + bias=False, + dtype=dtype, ) + + # For K1 and K2 self.k_proj = nn.Linear( self.hidden_size, - self.hidden_size // self.num_key_value_groups, + self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2, bias=False, dtype=dtype, ) + + # Single V projection self.v_proj = nn.Linear( self.hidden_size, - self.hidden_size // self.num_key_value_groups, + self.hidden_size // self.base_num_heads * self.base_num_kv_heads, bias=False, dtype=dtype, ) + # Output projection self.o_proj = nn.Linear( - self.hidden_size, self.hidden_size, bias=False, dtype=dtype + self.hidden_size, + self.hidden_size, + bias=False, + dtype=dtype, ) # Initialize differential attention parameters @@ -116,7 +124,6 @@ def __init__( torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1) ) - self.subln = RMSNorm(2 * self.head_dim, eps=1e-5) self.rotary_emb = LlamaRotaryEmbedding(config=config) def forward( @@ -126,6 +133,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, # pylint: disable=unused-argument @@ -134,97 +142,261 @@ def forward( Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]], ]: - bsz, tgt_len, _ = hidden_states.size() + bsz, q_len, _ = hidden_states.size() - # Project queries, keys and values - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) + # Project to Q1,Q2 and K1,K2 + qp = self.q_proj(hidden_states) + kp = self.k_proj(hidden_states) v = self.v_proj(hidden_states) - # Reshape for attention - q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim).transpose(1, 2) - k = k.view(bsz, tgt_len, 2 * self.num_key_value_heads, self.head_dim).transpose( - 1, 2 - ) - v = 
v.view(bsz, tgt_len, self.num_key_value_heads, 2 * self.head_dim).transpose( - 1, 2 - ) + # Split into Q1,Q2 and K1,K2 + q1, q2 = qp.chunk(2, dim=-1) + k1, k2 = kp.chunk(2, dim=-1) + + # Reshape Q1,Q2 for attention + q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - # Generate or unpack cos, sin for rotary positional embeddings + # Reshape K1,K2 for attention + k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + # Reshape V + v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + # Apply rotary embeddings if position_embeddings is None: if position_ids is None: - position_ids = torch.arange( - 0, tgt_len, dtype=torch.long, device=q.device - ) - cos, sin = self.rotary_emb(q, position_ids) + position_ids = torch.arange(q_len, device=q1.device) + cos, sin = self.rotary_emb(q1, position_ids) else: cos, sin = position_embeddings - # Need to adjust cos, sin to match the halved head_dim - cos = cos[..., : self.head_dim] - sin = sin[..., : self.head_dim] - q, k = apply_rotary_pos_emb(q, k, cos, sin) + q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) + q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - - # Update cache and get back concatenated states + k = torch.stack([k1, k2], dim=1) k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) + k1, k2 = k.unbind(dim=1) + + # Repeat KV heads to match Q heads + k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads) + k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads) + v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads) + + # Calculate attention scores for both parts + # NOTE(Dan): the Differential Transformers paper scales by a constant scaling factor + # instead of sqrt(head_dim). This could be set on the class as `self.scaling`. 
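+        # In the paper's notation, the code below computes
+        #   DiffAttn(X) = (softmax(Q1 K1^T / sqrt(d)) - lambda * softmax(Q2 K2^T / sqrt(d))) V
+        # with lambda re-parameterized as
+        #   lambda = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init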
+ attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt( + self.head_dim + ) + attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt( + self.head_dim + ) - # Prepare for attention - k = repeat_kv(k, self.num_key_value_groups) - v = repeat_kv(v, self.num_key_value_groups) - - # Scale query - q = q * self.scaling - - # Calculate attention scores - attn_weights = torch.matmul(q, k.transpose(-1, -2)) - - # Apply causal mask - if attention_mask is None: - attention_mask = torch.triu( - torch.full((tgt_len, tgt_len), float("-inf"), device=q.device), - diagonal=1, - ).type_as(attn_weights) - attn_weights = torch.nan_to_num(attn_weights) - attn_weights = attn_weights + attention_mask + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : k1.shape[-2]] + attn_weights1 = attn_weights1 + causal_mask + attn_weights2 = attn_weights2 + causal_mask - # Apply softmax - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as( - attn_weights + # Apply softmax separately as per paper + attn_weights1 = F.softmax(attn_weights1, dim=-1, dtype=torch.float32).type_as( + attn_weights1 + ) + attn_weights2 = F.softmax(attn_weights2, dim=-1, dtype=torch.float32).type_as( + attn_weights2 + ) + attn_weights1 = F.dropout( + attn_weights1, p=self.attention_dropout, training=self.training + ) + attn_weights2 = F.dropout( + attn_weights2, p=self.attention_dropout, training=self.training ) # Calculate lambda lambda_1 = torch.exp( torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float() - ).type_as(q) + ).type_as(q1) lambda_2 = torch.exp( torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() - ).type_as(q) + ).type_as(q1) lambda_full = lambda_1 - lambda_2 + self.lambda_init - # Apply differential attention - attn_weights = attn_weights.view( - bsz, self.num_heads, 2, -1, attn_weights.size(-1) - ) - attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1] + # Compute differential attention (following paper's formula) + attn_weights = attn_weights1 - lambda_full * attn_weights2 - # Apply attention to values + # Apply attention weights to values attn = torch.matmul(attn_weights, v) - # Apply sublayer norm - attn = self.subln(attn).type_as(attn) + # Apply sublayer norm and scaling + # NOTE(Dan): The differential transformers paper applies sublayer normalization at this + # point, but this is typically done outside of the attention layer. It would look something + # like: `attn = self.subln(attn).type_as(attn)`, using `LlamaRMSNorm` or similar. attn = attn * (1 - self.lambda_init) - # Reshape and project output - attn = attn.transpose(1, 2).reshape( - bsz, tgt_len, self.num_heads * 2 * self.head_dim - ) + # Reshape to output + attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) attn = self.o_proj(attn) - # Return in exact format expected by LLaMA if output_attentions: return attn, attn_weights, past_key_value return attn, None, past_key_value + + +class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention): + """Differential Attention implementation as described in the Diff Transformer paper. + This implements the same logic as `LlamaDifferentialAttention`, but uses + `scaled_dot_product_attention` instead of "manually" computing it under the hood. + + This implements a modified attention mechanism that computes the difference between + two attention patterns, scaled by learned lambda parameters. The mechanism helps + reduce noise in the attention weights for irrelevant / less relevant tokens. 
+ + Key components: + - Split head dimension for differential computation + - Learned lambda parameters that control attention scaling + - Sublayer normalization on the attention output + + See: + - https://arxiv.org/abs/2410.05258 + - https://github.com/microsoft/unilm/tree/master/Diff-Transformer + + Args: + config: Model configuration object containing hidden size, number of heads etc. + layer_idx: Index of this layer in the transformer stack + dtype: Data type for the layer parameters + """ + + def forward( + self, + hidden_states: torch.Tensor, # [bsz, seq_len, hidden_size] + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, # pylint: disable=unused-argument + ) -> tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[tuple[torch.Tensor, torch.Tensor]], + ]: + if output_attentions: + transformers.logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + # Project to Q1,Q2 and K1,K2 + qp = self.q_proj(hidden_states) + kp = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Split into Q1,Q2 and K1,K2 + q1, q2 = qp.chunk(2, dim=-1) + k1, k2 = kp.chunk(2, dim=-1) + + # Reshape Q1,Q2 for attention + q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2) + q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2) + # Reshape K1,K2 for attention + k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) + k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) + # Reshape V + v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) + + # Apply rotary embeddings + if position_embeddings is None: + if position_ids is None: + position_ids = torch.arange(q_len, device=q1.device) + cos, sin = self.rotary_emb(q1, position_ids) + else: + cos, sin = position_embeddings + + q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) + q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k = torch.stack([k1, k2], dim=1) + k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) + k1, k2 = k.unbind(dim=1) + + # Repeat KV heads to match Q heads + k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads) + k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads) + v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads) + + causal_mask = None + if attention_mask is not None: + causal_mask = attention_mask + causal_mask = causal_mask[:, :, :, : k1.shape[-2]] + + # SDPA 
with memory-efficient backend requires contiguous inputs on CUDA + if q1.device.type == "cuda" and causal_mask is not None: + q1, q2 = q1.contiguous(), q2.contiguous() + k1, k2 = k1.contiguous(), k2.contiguous() + v = v.contiguous() + + # Calculate attention using SDPA + is_causal = attention_mask is None and q_len > 1 + + attn_output1 = F.scaled_dot_product_attention( + q1, + k1, + v, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + attn_output2 = F.scaled_dot_product_attention( + q2, + k2, + v, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + # Calculate lambda + lambda_1 = torch.exp( + torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float() + ).type_as(q1) + lambda_2 = torch.exp( + torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() + ).type_as(q1) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + + # Combine the attention outputs + attn = attn_output1 - lambda_full * attn_output2 + + # Apply sublayer norm and scaling + attn = attn * (1 - self.lambda_init) + + # Reshape to output + attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) + attn = self.o_proj(attn) + + if output_attentions: + return ( + attn, + None, + past_key_value, + ) # Note: can't return attn_weights with SDPA + return attn, None, past_key_value From 8264c62a0557d3eaff0c6b4c84091674d102e30e Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 11 Dec 2024 23:11:19 -0500 Subject: [PATCH 04/25] adding CLI command for convert-diff-transformer --- src/axolotl/cli/integrations/__init__.py | 0 .../integrations/convert_diff_transformer.py | 131 ++++++++++++++++++ src/axolotl/cli/main.py | 18 +++ 3 files changed, 149 insertions(+) create mode 100644 src/axolotl/cli/integrations/__init__.py create mode 100644 src/axolotl/cli/integrations/convert_diff_transformer.py diff --git a/src/axolotl/cli/integrations/__init__.py b/src/axolotl/cli/integrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py new file mode 100644 index 0000000000..8886c49463 --- /dev/null +++ b/src/axolotl/cli/integrations/convert_diff_transformer.py @@ -0,0 +1,131 @@ +"""CLI to convert a transformers model's attns to diff attns.""" +import logging +import warnings +from pathlib import Path +from time import time +from typing import Union + +import fire +import torch +from colorama import Fore +from dotenv import load_dotenv +from transformers import HfArgumentParser + +from axolotl.cli import load_cfg, print_axolotl_text_art +from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer +from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention + +LOG = logging.getLogger("axolotl.cli.convert_attention") + + +def test_inference(model, tokenizer, prompt="The quick brown fox"): + """Run test inference and return generation time""" + try: + inputs = tokenizer(prompt, return_tensors="pt") + inputs = { + k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items() + } + + start = time() + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=20, + num_beams=1, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + use_cache=False, + ) + elapsed = time() - start + + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + LOG.info("Prompt: %s", prompt) + 
LOG.info("Generated: %s", generated_text) + LOG.info("Generation time: %.2fs", elapsed) + + return elapsed, generated_text + + except Exception as exc: + LOG.error("Inference failed: %s", str(exc)) + raise + + +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): + print_axolotl_text_art() + + cfg = load_cfg(config, **kwargs) + parser = HfArgumentParser(TrainerCliArgs) + cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) + + try: + # Load model and tokenizer + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) + model.to(cfg.device, dtype=cfg.torch_dtype) + + # Log original model info + LOG.info( + "Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d", + model.config.hidden_size, + model.config.num_attention_heads, + ) + + # Test original model + LOG.info("Testing original model...") + orig_time, orig_text = test_inference(model, tokenizer) + + # Convert attention + LOG.info("Converting to differential attention...") + try: + model = convert_to_diff_attention(model) + model.to(model.device) + except Exception as exc: + LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc)) + raise + + # Test converted model + LOG.info("Testing converted model...") + conv_time, conv_text = test_inference(model, tokenizer) + + # Save if requested + if cfg.output_dir: + LOG.info("Saving converted model to %s", cfg.output_dir) + model.save_pretrained(cfg.output_dir) + + LOG.info( + Fore.GREEN + + "Conversion successful!\n" + + f"Original generation time: {orig_time:.2f}s\n" + + f"Converted generation time: {conv_time:.2f}s" + + Fore.RESET + ) + + if orig_text == conv_text: + LOG.info( + Fore.GREEN + + "Generations match!\n" + + f"Model generation: {orig_text}\n" + + Fore.RESET + ) + else: + LOG.info( + Fore.RED + + "Generations do not match.\n" + + f"Original generation: {orig_text}\n" + + f"Converted generation: {conv_text}\n" + + Fore.RESET + ) + + except Exception as exc: + LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc)) + raise + + +if __name__ == "__main__": + load_dotenv() + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + ) + fire.Fire(do_cli) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 14803e43ba..7743d50175 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -240,6 +240,24 @@ def merge_lora( do_cli(config=config, **kwargs) +@cli.command() +@click.argument("config", type=click.Path(exists=True, path_type=str)) +@click.option( + "--output-dir", + type=click.Path(path_type=str), + help="Directory to save converted model", +) +@add_options_from_dataclass(TrainerCliArgs) +@add_options_from_config(AxolotlInputConfig) +def convert_diff_transformer(config: str, **kwargs): + """Convert model attention layers to differential attention layers.""" + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + from axolotl.cli.integrations.convert_diff_transformer import do_cli + + do_cli(config=config, **kwargs) + + @cli.command() @click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"])) @click.option("--dest", help="Destination directory") From 4bdbb2fd6c6a9ffcd99339cac5d148d3ef7fdf16 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 13 Dec 2024 00:06:22 -0500 Subject: [PATCH 05/25] training fixes, patching, minor cleanup --- .../integrations/convert_diff_transformer.py | 41 ++++++++---- 
src/axolotl/cli/main.py | 5 -- src/axolotl/core/trainer_builder.py | 2 +- .../integrations/diff_transformer/convert.py | 5 +- .../integrations/diff_transformer/patches.py | 46 ++++++++++++++ src/axolotl/train.py | 2 +- .../config/models/input/v0_4_1/__init__.py | 2 + src/axolotl/utils/models.py | 62 ++++++++++++++++--- 8 files changed, 136 insertions(+), 29 deletions(-) create mode 100644 src/axolotl/integrations/diff_transformer/patches.py diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py index 8886c49463..116a60480e 100644 --- a/src/axolotl/cli/integrations/convert_diff_transformer.py +++ b/src/axolotl/cli/integrations/convert_diff_transformer.py @@ -7,6 +7,7 @@ import fire import torch +import yaml from colorama import Fore from dotenv import load_dotenv from transformers import HfArgumentParser @@ -50,13 +51,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"): raise -def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): - print_axolotl_text_art() - - cfg = load_cfg(config, **kwargs) - parser = HfArgumentParser(TrainerCliArgs) - cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) - +def convert_diff_transformer(cfg, cli_args, config_path): try: # Load model and tokenizer with warnings.catch_warnings(): @@ -90,8 +85,26 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): # Save if requested if cfg.output_dir: + # Save model and tokenizer LOG.info("Saving converted model to %s", cfg.output_dir) model.save_pretrained(cfg.output_dir) + tokenizer.save_pretrained(cfg.output_dir) + + # Modify config to reflect new path / differential attention + output_config_path = Path(cfg.output_dir) / "axolotl_config.yml" + LOG.info("Saving updated config to %s", output_config_path) + + with open(config_path, "r", encoding="utf-8") as file: + data = yaml.safe_load(file) or {} + + data["base_model"] = cfg.output_dir + data["diff_attention"] = True + + with open(output_config_path, "w", encoding="utf-8") as file: + yaml.dump(data, file) + else: + LOG.info("Not saving converted model to disk") + LOG.info("Pass --output-dir path/to/save to save model") LOG.info( Fore.GREEN @@ -122,10 +135,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): raise +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): + print_axolotl_text_art() + + cfg = load_cfg(config, **kwargs) + parser = HfArgumentParser(TrainerCliArgs) + cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) + + convert_diff_transformer(cfg, cli_args, config) + + if __name__ == "__main__": load_dotenv() - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s", - ) fire.Fire(do_cli) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 7743d50175..360af4810e 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -242,11 +242,6 @@ def merge_lora( @cli.command() @click.argument("config", type=click.Path(exists=True, path_type=str)) -@click.option( - "--output-dir", - type=click.Path(path_type=str), - help="Directory to save converted model", -) @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) def convert_diff_transformer(config: str, **kwargs): diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index fffddac815..9e2a54d5da 100755 --- a/src/axolotl/core/trainer_builder.py +++ 
b/src/axolotl/core/trainer_builder.py @@ -294,7 +294,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): """ Training arguments for Causal trainer - This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value + This code is duplicated due to HF TrainingArguments not setting output_dir with a default value so it can't be used as a mixin. """ diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py index 36d97037b2..584c19d5fd 100644 --- a/src/axolotl/integrations/diff_transformer/convert.py +++ b/src/axolotl/integrations/diff_transformer/convert.py @@ -88,7 +88,6 @@ def convert_module(module): for name, child in module.named_children(): if isinstance(child, attention_patterns): layer_type = type(child).__name__ - logger.info(f"Converting attention layer {layer_idx}: {layer_type}") # Choose appropriate differential attention class if isinstance(child, LlamaSdpaAttention): @@ -96,6 +95,10 @@ def convert_module(module): else: attention_class = LlamaDifferentialAttention + logger.info( + f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}" + ) + # Create new diff attn layer new_attention = attention_class( config=module.config if hasattr(module, "config") else model.config, diff --git a/src/axolotl/integrations/diff_transformer/patches.py b/src/axolotl/integrations/diff_transformer/patches.py new file mode 100644 index 0000000000..7ff35633cb --- /dev/null +++ b/src/axolotl/integrations/diff_transformer/patches.py @@ -0,0 +1,46 @@ +"""Patches related to differential transformers implementation.""" +from transformers import PreTrainedModel +from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES + +from .multihead_diffattn import ( + LlamaDifferentialAttention, + LlamaDifferentialSdpaAttention, +) + + +def patch_transformers(): + """Patch transformers to support differential attention""" + + # Add our attention class to the registry + LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention + LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention + + # Store original method for use in our patch + # original_autoset = PreTrainedModel._autoset_attn_implementation + + @classmethod + def new_autoset(_, config, **kwargs): # pylint: disable=unused-argument + config._attn_implementation_autoset = True # pylint: disable=protected-access + attn_implementation = getattr(config, "_attn_implementation", None) + + valid_impls = [ + None, + "eager", + "sdpa", + "flash_attention_2", + "differential_eager", + "differential_sdpa", + ] + if attn_implementation not in valid_impls: + message = ( + f"Specified `attn_implementation={attn_implementation}` is not supported. " + f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}" + ) + raise ValueError(message + ".") + + return config + + # Apply patch + PreTrainedModel._autoset_attn_implementation = ( # pylint: disable=protected-access + new_autoset + ) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 848831b665..c0aac93244 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -88,7 +88,7 @@ def train( ) resume_from_checkpoint = cfg.resume_from_checkpoint - # Load the model and tokenizer + # Load the model msg = "loading model" if cfg.adapter: msg += " and peft_config..." 
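As a quick illustration of how the pieces in this patch are meant to fit together: the registry patch above plus a converted checkpoint should allow a model to be reloaded with a differential attention implementation. The sketch below is illustrative only, not part of this diff; the ./converted_model path is a hypothetical output directory from the convert-diff-transformer step, and it assumes patch_transformers() has registered the new classes so transformers accepts the "differential_sdpa" value.

    from transformers import AutoModelForCausalLM

    from axolotl.integrations.diff_transformer.patches import patch_transformers

    # Register "differential_eager" / "differential_sdpa" in LLAMA_ATTENTION_CLASSES
    # and allow them through the attn_implementation validation.
    patch_transformers()

    # Reload a checkpoint previously written by the conversion step; decoder
    # layers are then built with LlamaDifferentialSdpaAttention.
    model = AutoModelForCausalLM.from_pretrained(
        "./converted_model",  # hypothetical output_dir from the conversion step
        attn_implementation="differential_sdpa",
    )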
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 0781c67989..d98cb3cb68 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -726,6 +726,8 @@ class Config: eager_attention: Optional[bool] = None + diff_attention: Optional[bool] = None + unsloth_cross_entropy_loss: Optional[bool] = None unsloth_lora_mlp: Optional[bool] = None unsloth_lora_qkv: Optional[bool] = None diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 523fd76feb..7d8599c3c3 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -710,24 +710,60 @@ def set_attention_config(self) -> None: """ sample packing uses custom FA2 patch """ + print( + self.cfg.flash_attention, + self.cfg.sdp_attention, + self.cfg.eager_attention, + self.cfg.diff_attention, + ) + if self.cfg.flash_attention: if not self.cfg.sample_packing and self.cfg.s2_attention: pass - self.model_kwargs["attn_implementation"] = "flash_attention_2" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) + + if self.cfg.diff_attention: + self.model_kwargs[ + "attn_implementation" + ] = "differential_flash_attention_2" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "differential_flash_attention_2" + ) + else: + self.model_kwargs["attn_implementation"] = "flash_attention_2" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "flash_attention_2" + ) elif self.cfg.sdp_attention: - self.model_kwargs["attn_implementation"] = "sdpa" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "sdpa" - ) + if self.cfg.diff_attention: + self.model_kwargs["attn_implementation"] = "differential_sdpa" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "differential_sdpa" + ) + else: + self.model_kwargs["attn_implementation"] = "sdpa" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "sdpa" + ) elif self.cfg.eager_attention: - self.model_kwargs["attn_implementation"] = "eager" + if self.cfg.diff_attention: + self.model_kwargs["attn_implementation"] = "differential_eager" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "differential_eager" + ) + else: + self.model_kwargs["attn_implementation"] = "eager" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "eager" + ) + elif self.cfg.diff_attention: + self.model_kwargs["attn_implementation"] = "differential_eager" self.model_config._attn_implementation = ( # pylint: disable=protected-access - "eager" + "differential_eager" ) + if "attn_implementation" in self.model_kwargs: + print(self.model_kwargs["attn_implementation"]) + if self.cfg.low_cpu_mem_usage: self.model_kwargs["low_cpu_mem_usage"] = True @@ -816,6 +852,8 @@ def _configure_zero3_memory_efficient_loading(): if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config + + # self.model._attn_implementation_autoset = False self.model = self.AutoModelLoader.from_pretrained( self.base_model, config=self.model_config, @@ -1030,6 +1068,10 @@ def apply_lora_patch(self) -> None: integrate_rope_embeddings() def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: + from axolotl.integrations.diff_transformer.patches import patch_transformers + + patch_transformers() + 
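+        # NOTE: the patch above registers the "differential_eager" /
+        # "differential_sdpa" attention classes and relaxes the
+        # attn_implementation check; it has to run before the model below is
+        # instantiated so those values can be resolved.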
self.apply_patches() self.set_auto_model_loader() self.set_device_map_config() From 60a16687429f80b4a4c0e1ca4f0b0c967bdfffeb Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 13 Dec 2024 15:03:45 -0500 Subject: [PATCH 06/25] various improvemnents --- .../integrations/convert_diff_transformer.py | 89 +++++++++++++------ src/axolotl/cli/main.py | 4 +- src/axolotl/common/cli.py | 16 +++- .../integrations/diff_transformer/convert.py | 12 ++- .../diff_transformer/multihead_diffattn.py | 5 +- .../integrations/diff_transformer/patches.py | 1 + src/axolotl/utils/models.py | 10 --- 7 files changed, 87 insertions(+), 50 deletions(-) diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py index 116a60480e..a8c7e5942b 100644 --- a/src/axolotl/cli/integrations/convert_diff_transformer.py +++ b/src/axolotl/cli/integrations/convert_diff_transformer.py @@ -13,7 +13,7 @@ from transformers import HfArgumentParser from axolotl.cli import load_cfg, print_axolotl_text_art -from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer +from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention LOG = logging.getLogger("axolotl.cli.convert_attention") @@ -67,21 +67,23 @@ def convert_diff_transformer(cfg, cli_args, config_path): ) # Test original model - LOG.info("Testing original model...") - orig_time, orig_text = test_inference(model, tokenizer) + if cli_args.debug: + LOG.info("Testing original model...") + orig_time, orig_text = test_inference(model, tokenizer) # Convert attention LOG.info("Converting to differential attention...") try: - model = convert_to_diff_attention(model) + model = convert_to_diff_attention(model, cli_args.zero_init) model.to(model.device) except Exception as exc: LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc)) raise # Test converted model - LOG.info("Testing converted model...") - conv_time, conv_text = test_inference(model, tokenizer) + if cli_args.debug: + LOG.info("Testing converted model...") + conv_time, conv_text = test_inference(model, tokenizer) # Save if requested if cfg.output_dir: @@ -106,30 +108,65 @@ def convert_diff_transformer(cfg, cli_args, config_path): LOG.info("Not saving converted model to disk") LOG.info("Pass --output-dir path/to/save to save model") - LOG.info( - Fore.GREEN - + "Conversion successful!\n" - + f"Original generation time: {orig_time:.2f}s\n" - + f"Converted generation time: {conv_time:.2f}s" - + Fore.RESET - ) - - if orig_text == conv_text: + if cli_args.debug: LOG.info( Fore.GREEN - + "Generations match!\n" - + f"Model generation: {orig_text}\n" - + Fore.RESET - ) - else: - LOG.info( - Fore.RED - + "Generations do not match.\n" - + f"Original generation: {orig_text}\n" - + f"Converted generation: {conv_text}\n" + + "Conversion successful!\n" + + f"Original generation time: {orig_time:.2f}s\n" + + f"Converted generation time: {conv_time:.2f}s" + Fore.RESET ) + if orig_text == conv_text: + LOG.info( + Fore.GREEN + + "Generations match!\n" + + "Model generation:\n" + + "*" * 50 + + "\n" + + f"{orig_text}\n" + + "*" * 50 + + "\n" + + Fore.RESET + ) + else: + if cli_args.zero_init: + LOG.info( + Fore.RED + + "Generations do not match.\n" + + "Original generation:\n" + + "*" * 50 + + "\n" + + f"{orig_text}\n" + + "*" * 50 + + "\n" + + "Converted generation:\n" + + "*" * 50 + + "\n" + + f"{conv_text}\n" + + "*" * 50 + + "\n" + + 
Fore.RESET + ) + else: + LOG.info( + Fore.YELLOW + + "Generations do not match.\n" + + "Original generation:\n" + + "*" * 50 + + "\n" + + f"{orig_text}\n" + + "*" * 50 + + "\n" + + "Converted generation:\n" + + "*" * 50 + + "\n" + + f"{conv_text}\n" + + "*" * 50 + + "\n" + + "However, this is expected since --zero-init was not passed." + + Fore.RESET + ) except Exception as exc: LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc)) raise @@ -139,7 +176,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): print_axolotl_text_art() cfg = load_cfg(config, **kwargs) - parser = HfArgumentParser(TrainerCliArgs) + parser = HfArgumentParser(ConvertDiffTransformerCliArgs) cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) convert_diff_transformer(cfg, cli_args, config) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 360af4810e..c37aa5484a 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -12,7 +12,7 @@ build_command, fetch_from_github, ) -from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs +from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig @@ -242,7 +242,7 @@ def merge_lora( @cli.command() @click.argument("config", type=click.Path(exists=True, path_type=str)) -@add_options_from_dataclass(TrainerCliArgs) +@add_options_from_dataclass(ConvertDiffTransformerCliArgs) @add_options_from_config(AxolotlInputConfig) def convert_diff_transformer(config: str, **kwargs): """Convert model attention layers to differential attention layers.""" diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index 02ad9201b8..bdab7c272e 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -18,7 +18,7 @@ @dataclass class PreprocessCliArgs: """ - dataclass representing arguments for preprocessing only + dataclass with arguments for preprocessing only """ debug: bool = field(default=False) @@ -31,7 +31,7 @@ class PreprocessCliArgs: @dataclass class TrainerCliArgs: """ - dataclass representing the various non-training arguments + dataclass with various non-training arguments """ debug: bool = field(default=False) @@ -46,7 +46,7 @@ class TrainerCliArgs: @dataclass class EvaluateCliArgs: """ - dataclass representing the various evaluation arguments + dataclass with various evaluation arguments """ debug: bool = field(default=False) @@ -54,6 +54,16 @@ class EvaluateCliArgs: debug_num_examples: int = field(default=0) +@dataclass +class ConvertDiffTransformerCliArgs: + """ + dataclass with arguments for convert-diff-transformer CLI + """ + + debug: bool = field(default=False) + zero_init: bool = field(default=False) + + def load_model_and_tokenizer( *, cfg: DictDefault, diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py index 584c19d5fd..24bc07cf77 100644 --- a/src/axolotl/integrations/diff_transformer/convert.py +++ b/src/axolotl/integrations/diff_transformer/convert.py @@ -21,7 +21,7 @@ def copy_attention_weights( old_attn: Union[LlamaAttention, LlamaSdpaAttention], new_attn: Union[LlamaDifferentialAttention, LlamaDifferentialSdpaAttention], - zero_init: bool = True, + zero_init: bool = False, ) -> None: """ Copy weights from old attention layer to new differential attention layer. 
@@ -68,7 +68,9 @@ def copy_attention_weights( ) -def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel: +def convert_to_diff_attention( + model: PreTrainedModel, zero_init: bool +) -> PreTrainedModel: """Convert a pre-trained model's attention layers to differential attention""" attention_patterns = ( LlamaAttention, @@ -78,9 +80,6 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel: ) layer_idx = 0 - # Get model dtype from existing weights - model_dtype = next(model.parameters()).dtype - def convert_module(module): nonlocal layer_idx @@ -103,11 +102,10 @@ def convert_module(module): new_attention = attention_class( config=module.config if hasattr(module, "config") else model.config, layer_idx=layer_idx, - dtype=model_dtype, ) # Copy weights from old attention to new attention - copy_attention_weights(child, new_attention) + copy_attention_weights(child, new_attention, zero_init=zero_init) # Replace the layer setattr(module, name, new_attention) diff --git a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py index 6d3bc75898..ace9c58de6 100644 --- a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py +++ b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py @@ -60,11 +60,10 @@ def __init__( self, config: Any, layer_idx: int, - dtype: torch.dtype, ): super().__init__() - # Base model dimensions + # Base model config self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.base_num_heads = config.num_attention_heads @@ -77,6 +76,8 @@ def __init__( self.rope_theta = config.rope_theta self.is_causal = True + dtype = getattr(config, "torch_dtype", torch.float32) + # For Q1 and Q2 self.q_proj = nn.Linear( self.hidden_size, diff --git a/src/axolotl/integrations/diff_transformer/patches.py b/src/axolotl/integrations/diff_transformer/patches.py index 7ff35633cb..14117bf637 100644 --- a/src/axolotl/integrations/diff_transformer/patches.py +++ b/src/axolotl/integrations/diff_transformer/patches.py @@ -1,4 +1,5 @@ """Patches related to differential transformers implementation.""" + from transformers import PreTrainedModel from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 7d8599c3c3..c8e08468fe 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -710,13 +710,6 @@ def set_attention_config(self) -> None: """ sample packing uses custom FA2 patch """ - print( - self.cfg.flash_attention, - self.cfg.sdp_attention, - self.cfg.eager_attention, - self.cfg.diff_attention, - ) - if self.cfg.flash_attention: if not self.cfg.sample_packing and self.cfg.s2_attention: pass @@ -761,9 +754,6 @@ def set_attention_config(self) -> None: "differential_eager" ) - if "attn_implementation" in self.model_kwargs: - print(self.model_kwargs["attn_implementation"]) - if self.cfg.low_cpu_mem_usage: self.model_kwargs["low_cpu_mem_usage"] = True From 32f1b3ffe65ec78271929d8d80daff10aece1404 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 13 Dec 2024 15:17:52 -0500 Subject: [PATCH 07/25] various improvemnents --- scripts/convert_diff_transformer.py | 127 ------------------ .../diff_transformer/multihead_diffattn.py | 4 +- 2 files changed, 2 insertions(+), 129 deletions(-) delete mode 100644 scripts/convert_diff_transformer.py diff --git a/scripts/convert_diff_transformer.py b/scripts/convert_diff_transformer.py deleted 
file mode 100644 index 651c0a229c..0000000000 --- a/scripts/convert_diff_transformer.py +++ /dev/null @@ -1,127 +0,0 @@ -"""Test conversion of transformers model attention to differential attention.""" -from typing import Tuple - -import torch -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - PreTrainedModel, - PreTrainedTokenizer, -) - -from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention - - -def setup_model( - model_name: str, device: str = "cuda" -) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: - """Load model and tokenizer""" - model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.float16, - device_map=device, - ) - tokenizer = AutoTokenizer.from_pretrained(model_name) - - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - return model, tokenizer - - -def convert_model_attention(model: AutoModelForCausalLM) -> AutoModelForCausalLM: - """Convert model to use differential attention""" - try: - model = convert_to_diff_attention(model) - return model - except Exception as exception: - print(f"Error during model conversion: {exception}") - raise - - -def test_inference(model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> None: - """Run test inference""" - # Test prompts - test_prompts = [ - "The quick brown fox", - ] - - for prompt in test_prompts: - try: - # Tokenize - inputs = tokenizer(prompt, return_tensors="pt") - inputs = {k: v.to(model.device) for k, v in inputs.items()} - - # Generate - from time import time - - start = time() - with torch.no_grad(): - outputs = model.generate( - **inputs, - max_new_tokens=20, - num_beams=1, - do_sample=False, - # temperature=0.7, - pad_token_id=tokenizer.pad_token_id, - use_cache=False, - # use_cache=True, - ) - elasped = time() - start - print(f"generation time: {elasped}s") - - # Decode - print(outputs) - generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) - print(f"\nPrompt: {prompt}") - print(f"Generated: {generated_text}\n") - - except Exception as exception: - print(f"Error during inference: {str(exception)}") - raise - - -def save_converted_model(model: AutoModelForCausalLM, output_dir: str) -> None: - """Save the converted model""" - print(f"Saving converted model to {output_dir}") - model.save_pretrained(output_dir) - - -def main(): - # Configuration - model_name = "HuggingFaceTB/SmolLM2-135M" - # model_name = "openlm-research/open_llama_3b_v2" - output_dir = "./converted_model" - device = "cuda" if torch.cuda.is_available() else "cpu" - print(f"Using device: {device}") - - try: - # Load model and tokenizer - model, tokenizer = setup_model(model_name, device) - - # Print original model info - print("Original model config:") - print(f"\t- Hidden size: {model.config.hidden_size}") - print(f"\t- Number of attention heads: {model.config.num_attention_heads}") - - # Test the original model - test_inference(model, tokenizer) - - # Convert to differential attention - model = convert_to_diff_attention(model) - model.to(model.device) - print("Model conversion completed") - - # Test the converted model - test_inference(model, tokenizer) - - # Save converted model - save_converted_model(model, output_dir) - - except Exception as exception: - print(f"Error during test: {str(exception)}") - raise - - -if __name__ == "__main__": - main() diff --git a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py index ace9c58de6..7b3db19ab5 100644 --- 
a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py +++ b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py @@ -292,8 +292,8 @@ def forward( 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, + hidden_states=hidden_states, # pylint: disable=duplicate-code + attention_mask=attention_mask, # pylint: disable=duplicate-code position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, From c1968eddf00675e8179e28d4c7d720ddfa92a625 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 17 Dec 2024 04:43:08 +0000 Subject: [PATCH 08/25] fix model save / load logic --- .../integrations/convert_diff_transformer.py | 2 +- src/axolotl/common/cli.py | 4 +-- src/axolotl/evaluate.py | 25 +++++++------------ .../integrations/diff_transformer/convert.py | 3 ++- .../diff_transformer/multihead_diffattn.py | 16 +++++++++--- .../integrations/diff_transformer/patches.py | 5 +--- src/axolotl/utils/models.py | 11 +++++--- 7 files changed, 35 insertions(+), 31 deletions(-) diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py index a8c7e5942b..1cbf619c82 100644 --- a/src/axolotl/cli/integrations/convert_diff_transformer.py +++ b/src/axolotl/cli/integrations/convert_diff_transformer.py @@ -75,7 +75,7 @@ def convert_diff_transformer(cfg, cli_args, config_path): LOG.info("Converting to differential attention...") try: model = convert_to_diff_attention(model, cli_args.zero_init) - model.to(model.device) + model.to(cfg.device, dtype=cfg.torch_dtype) except Exception as exc: LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc)) raise diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index bdab7c272e..2b25b7f395 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -4,7 +4,7 @@ import logging from dataclasses import dataclass, field -from typing import Optional +from typing import Optional, Union import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 from axolotl.logging_config import configure_logging @@ -67,7 +67,7 @@ class ConvertDiffTransformerCliArgs: def load_model_and_tokenizer( *, cfg: DictDefault, - cli_args: TrainerCliArgs, + cli_args: Union[TrainerCliArgs, EvaluateCliArgs, ConvertDiffTransformerCliArgs], ): LOG.info(f"loading tokenizer... 
{cfg.tokenizer_config or cfg.base_model_config}") tokenizer = load_tokenizer(cfg) diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py index acf15e3fc7..bc1799960c 100644 --- a/src/axolotl/evaluate.py +++ b/src/axolotl/evaluate.py @@ -9,13 +9,13 @@ import torch from accelerate.logging import get_logger -from axolotl.common.cli import TrainerCliArgs +from axolotl.common.cli import EvaluateCliArgs, load_model_and_tokenizer from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model, load_processor, load_tokenizer -from axolotl.utils.trainer import setup_trainer +from axolotl.utils.models import load_processor +from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") @@ -63,7 +63,7 @@ def evaluate_dataset( def evaluate( - *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta + *, cfg: DictDefault, cli_args: EvaluateCliArgs, dataset_meta: TrainDatasetMeta ) -> Dict[str, float]: """ Evaluate a model on training and validation datasets @@ -83,12 +83,11 @@ def evaluate( # Enable expandable segments for cuda allocation to improve VRAM usage set_pytorch_cuda_alloc_conf() - # Load tokenizer - LOG.debug( - f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", - main_process_only=True, - ) - tokenizer = load_tokenizer(cfg) + # Load model + LOG.debug("loading model for evaluation...") + + model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) + model = model.to(cfg.device, dtype=cfg.torch_dtype) # Load processor for multimodal models if needed processor = None @@ -100,12 +99,6 @@ def evaluate( eval_dataset = dataset_meta.eval_dataset total_num_steps = dataset_meta.total_num_steps - # Load model - LOG.debug("loading model for evaluation...") - model, _ = load_model( - cfg, tokenizer, processor=processor, inference=cli_args.inference - ) - # Set up trainer trainer = setup_trainer( cfg, diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py index 24bc07cf77..bd688fadbe 100644 --- a/src/axolotl/integrations/diff_transformer/convert.py +++ b/src/axolotl/integrations/diff_transformer/convert.py @@ -59,7 +59,7 @@ def copy_attention_weights( nn.init.zeros_(new_attn.lambda_k1) nn.init.zeros_(new_attn.lambda_q2) nn.init.zeros_(new_attn.lambda_k2) - new_attn.lambda_init = 0.0 + nn.init.zeros_(new_attn.lambda_init) logger.debug( "Copied positive attention weights from %s to %s", @@ -105,6 +105,7 @@ def convert_module(module): ) # Copy weights from old attention to new attention + new_attention.to(child.q_proj.weight.device) copy_attention_weights(child, new_attention, zero_init=zero_init) # Replace the layer diff --git a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py index 7b3db19ab5..4735564452 100644 --- a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py +++ b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py @@ -70,13 +70,12 @@ def __init__( self.base_num_kv_heads = config.num_key_value_heads self.head_dim = config.hidden_size // config.num_attention_heads - self.scaling = self.head_dim**-0.5 self.layer_idx = layer_idx self.max_position_embeddings = 
config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True - dtype = getattr(config, "torch_dtype", torch.float32) + dtype = torch.float32 # For Q1 and Q2 self.q_proj = nn.Linear( @@ -111,7 +110,10 @@ def __init__( ) # Initialize differential attention parameters - self.lambda_init = lambda_init_fn(self.layer_idx) + self.lambda_init = nn.Parameter( + torch.full((), lambda_init_fn(self.layer_idx), dtype=dtype), + requires_grad=False, + ) self.lambda_q1 = nn.Parameter( torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1) ) @@ -197,6 +199,14 @@ def forward( self.head_dim ) + # Add this debug step right after computing attention weights in the forward pass + attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt( + self.head_dim + ) + attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt( + self.head_dim + ) + if attention_mask is not None: causal_mask = attention_mask[:, :, :, : k1.shape[-2]] attn_weights1 = attn_weights1 + causal_mask diff --git a/src/axolotl/integrations/diff_transformer/patches.py b/src/axolotl/integrations/diff_transformer/patches.py index 14117bf637..37ad0a981b 100644 --- a/src/axolotl/integrations/diff_transformer/patches.py +++ b/src/axolotl/integrations/diff_transformer/patches.py @@ -9,16 +9,13 @@ ) -def patch_transformers(): +def patch_llama_attention_classes(): """Patch transformers to support differential attention""" # Add our attention class to the registry LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention - # Store original method for use in our patch - # original_autoset = PreTrainedModel._autoset_attn_implementation - @classmethod def new_autoset(_, config, **kwargs): # pylint: disable=unused-argument config._attn_implementation_autoset = True # pylint: disable=protected-access diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index c8e08468fe..3b0dcbc2ba 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -444,6 +444,13 @@ def apply_patches(self) -> None: patch_mistral_cross_entropy() + if self.cfg.diff_attention: + from axolotl.integrations.diff_transformer.patches import ( + patch_llama_attention_classes, + ) + + patch_llama_attention_classes() + def patch_attention(self) -> None: if hasattr(self.model_config, "model_type"): if self.model_config.model_type == "mllama" and self.cfg.flash_attention: @@ -1058,10 +1065,6 @@ def apply_lora_patch(self) -> None: integrate_rope_embeddings() def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: - from axolotl.integrations.diff_transformer.patches import patch_transformers - - patch_transformers() - self.apply_patches() self.set_auto_model_loader() self.set_device_map_config() From dbeea75155c635ecd7b64e6dbda479e8f5992039 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 17 Dec 2024 13:52:34 +0000 Subject: [PATCH 09/25] pre-commit fix --- src/axolotl/cli/main.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index c37aa5484a..e922d37e99 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -12,7 +12,12 @@ build_command, fetch_from_github, ) -from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs +from axolotl.common.cli import ( + ConvertDiffTransformerCliArgs, + EvaluateCliArgs, + PreprocessCliArgs, + TrainerCliArgs, +) from 
axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig From 81d9ff4a8f13e0e1943a52df5d166fdaa1138755 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 17 Dec 2024 14:12:03 +0000 Subject: [PATCH 10/25] moving monkeypatch --- .../patches.py => monkeypatch/attention/differential.py} | 2 +- src/axolotl/utils/models.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename src/axolotl/{integrations/diff_transformer/patches.py => monkeypatch/attention/differential.py} (95%) diff --git a/src/axolotl/integrations/diff_transformer/patches.py b/src/axolotl/monkeypatch/attention/differential.py similarity index 95% rename from src/axolotl/integrations/diff_transformer/patches.py rename to src/axolotl/monkeypatch/attention/differential.py index 37ad0a981b..037a6f0bd2 100644 --- a/src/axolotl/integrations/diff_transformer/patches.py +++ b/src/axolotl/monkeypatch/attention/differential.py @@ -3,7 +3,7 @@ from transformers import PreTrainedModel from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES -from .multihead_diffattn import ( +from axolotl.integrations.diff_transformer.multihead_diffattn import ( LlamaDifferentialAttention, LlamaDifferentialSdpaAttention, ) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 3b0dcbc2ba..e98e9f31b4 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -445,7 +445,7 @@ def apply_patches(self) -> None: patch_mistral_cross_entropy() if self.cfg.diff_attention: - from axolotl.integrations.diff_transformer.patches import ( + from axolotl.monkeypatch.attention.differential import ( patch_llama_attention_classes, ) From c74a290e4f8ea3926ca465bc231efbf2f7f60a10 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 17 Dec 2024 18:44:47 +0000 Subject: [PATCH 11/25] differential flash attention 2; cleanup --- model-out/eval_summary.csv | 6 + .../integrations/convert_diff_transformer.py | 66 +++-- src/axolotl/cli/utils.py | 12 +- src/axolotl/common/cli.py | 1 + .../__init__.py | 0 .../convert.py | 46 ++-- .../differential_attention.py} | 236 ++++++++++++++---- .../monkeypatch/attention/differential.py | 7 +- 8 files changed, 268 insertions(+), 106 deletions(-) create mode 100644 model-out/eval_summary.csv rename src/axolotl/integrations/{diff_transformer => differential_transformer}/__init__.py (100%) rename src/axolotl/integrations/{diff_transformer => differential_transformer}/convert.py (78%) rename src/axolotl/integrations/{diff_transformer/multihead_diffattn.py => differential_transformer/differential_attention.py} (64%) diff --git a/model-out/eval_summary.csv b/model-out/eval_summary.csv new file mode 100644 index 0000000000..ccbe73358c --- /dev/null +++ b/model-out/eval_summary.csv @@ -0,0 +1,6 @@ +metric,training,validation +loss,1.8773103952407837,1.915901780128479 +model_preparation_time,0.0051,0.0051 +runtime,89.7635,8.9565 +samples_per_second,20.053,22.33 +steps_per_second,20.053,22.33 diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py index 1cbf619c82..6eb00452b1 100644 --- a/src/axolotl/cli/integrations/convert_diff_transformer.py +++ b/src/axolotl/cli/integrations/convert_diff_transformer.py @@ -14,7 +14,9 @@ from axolotl.cli import load_cfg, print_axolotl_text_art from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer -from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention 
+from axolotl.integrations.differential_transformer.convert import ( + convert_to_diff_attention, +) LOG = logging.getLogger("axolotl.cli.convert_attention") @@ -74,7 +76,11 @@ def convert_diff_transformer(cfg, cli_args, config_path): # Convert attention LOG.info("Converting to differential attention...") try: - model = convert_to_diff_attention(model, cli_args.zero_init) + model = convert_to_diff_attention( + model=model, + zero_init=cli_args.zero_init, + sublayer_norm=cli_args.sublayer_norm, + ) model.to(cfg.device, dtype=cfg.torch_dtype) except Exception as exc: LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc)) @@ -130,43 +136,35 @@ def convert_diff_transformer(cfg, cli_args, config_path): + Fore.RESET ) else: - if cli_args.zero_init: - LOG.info( - Fore.RED - + "Generations do not match.\n" - + "Original generation:\n" - + "*" * 50 - + "\n" - + f"{orig_text}\n" - + "*" * 50 - + "\n" - + "Converted generation:\n" - + "*" * 50 - + "\n" - + f"{conv_text}\n" - + "*" * 50 - + "\n" - + Fore.RESET - ) + message = ( + "Generations do not match.\n" + + "Original generation:\n" + + "*" * 50 + + "\n" + + f"{orig_text}\n" + + "*" * 50 + + "\n" + + "Converted generation:\n" + + "*" * 50 + + "\n" + + f"{conv_text}\n" + + "*" * 50 + + "\n" + ) + + if cli_args.zero_init and not cli_args.sublayer_norm: + LOG.info(Fore.RED + message + Fore.RESET) else: LOG.info( Fore.YELLOW - + "Generations do not match.\n" - + "Original generation:\n" - + "*" * 50 - + "\n" - + f"{orig_text}\n" - + "*" * 50 - + "\n" - + "Converted generation:\n" - + "*" * 50 - + "\n" - + f"{conv_text}\n" - + "*" * 50 - + "\n" - + "However, this is expected since --zero-init was not passed." + + message + + "However, this is expected since --zero-init" + + " and --no-sublayer-norm were not passed." + Fore.RESET ) + + return model + except Exception as exc: LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc)) raise diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index f0e2573f72..a228ee92a3 100644 --- a/src/axolotl/cli/utils.py +++ b/src/axolotl/cli/utils.py @@ -22,7 +22,6 @@ def decorator(function): # Process dataclass fields in reverse order for correct option ordering for field in reversed(dataclasses.fields(config_class)): field_type = field.type - if get_origin(field_type) is Union and type(None) in get_args(field_type): field_type = next( t for t in get_args(field_type) if not isinstance(t, NoneType) @@ -44,6 +43,7 @@ def decorator(function): default=field.default, help=field.metadata.get("description"), )(function) + return function return decorator @@ -55,7 +55,14 @@ def add_options_from_config(config_class: Type[BaseModel]): def decorator(function): # Process model fields in reverse order for correct option ordering for name, field in reversed(config_class.model_fields.items()): - if field.annotation == bool: + field_type = field.annotation + if get_origin(field_type) is Union and type(None) in get_args(field_type): + field_type = next( + t for t in get_args(field_type) if not isinstance(t, NoneType) + ) + + # NOTE: defaults are handled by the pydantic model config classes. 
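# Hedged illustration (not part of this hunk): a boolean field handled here
# becomes the paired flag f"--{field_name}/--no-{field_name}", so a field such
# as `sublayer_norm: bool = field(default=True)` on the CLI args dataclass can
# be toggled off on the command line, e.g. (command and flag spelling inferred
# from this diff, not verified):
#
#   axolotl convert-diff-transformer config.yml --zero-init --no-sublayer-norm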
+ if field_type == bool: field_name = name.replace("_", "-") option_name = f"--{field_name}/--no-{field_name}" function = click.option( @@ -66,6 +73,7 @@ def decorator(function): function = click.option( option_name, default=None, help=field.description )(function) + return function return decorator diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index 2b25b7f395..9c921e5640 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -62,6 +62,7 @@ class ConvertDiffTransformerCliArgs: debug: bool = field(default=False) zero_init: bool = field(default=False) + sublayer_norm: bool = field(default=True) def load_model_and_tokenizer( diff --git a/src/axolotl/integrations/diff_transformer/__init__.py b/src/axolotl/integrations/differential_transformer/__init__.py similarity index 100% rename from src/axolotl/integrations/diff_transformer/__init__.py rename to src/axolotl/integrations/differential_transformer/__init__.py diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/differential_transformer/convert.py similarity index 78% rename from src/axolotl/integrations/diff_transformer/convert.py rename to src/axolotl/integrations/differential_transformer/convert.py index bd688fadbe..5620ad1995 100644 --- a/src/axolotl/integrations/diff_transformer/convert.py +++ b/src/axolotl/integrations/differential_transformer/convert.py @@ -5,22 +5,35 @@ import torch from torch import nn from transformers import PreTrainedModel -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaSdpaAttention -from transformers.models.mistral.modeling_mistral import MistralAttention -from transformers.models.mixtral.modeling_mixtral import MixtralAttention +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaFlashAttention2, + LlamaSdpaAttention, +) -from .multihead_diffattn import ( +from .differential_attention import ( LlamaDifferentialAttention, + LlamaDifferentialFlashAttention2, LlamaDifferentialSdpaAttention, ) logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +ATTENTION_MAPPING = { + LlamaAttention: LlamaDifferentialAttention, + LlamaSdpaAttention: LlamaDifferentialSdpaAttention, + LlamaFlashAttention2: LlamaDifferentialFlashAttention2, +} + def copy_attention_weights( - old_attn: Union[LlamaAttention, LlamaSdpaAttention], - new_attn: Union[LlamaDifferentialAttention, LlamaDifferentialSdpaAttention], + old_attn: Union[LlamaAttention, LlamaSdpaAttention, LlamaFlashAttention2], + new_attn: Union[ + LlamaDifferentialAttention, + LlamaDifferentialSdpaAttention, + LlamaDifferentialFlashAttention2, + ], zero_init: bool = False, ) -> None: """ @@ -69,31 +82,24 @@ def copy_attention_weights( def convert_to_diff_attention( - model: PreTrainedModel, zero_init: bool + model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True ) -> PreTrainedModel: """Convert a pre-trained model's attention layers to differential attention""" - attention_patterns = ( - LlamaAttention, - LlamaSdpaAttention, - MistralAttention, - MixtralAttention, - ) layer_idx = 0 + # Set sublayer norm as config on the model. 
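# Hedged sketch (annotation, not a hunk of this patch) of how the flag set here
# is read back inside the differential attention layers added elsewhere in this
# commit:
#
#   sublayer_norm = getattr(config, "sublayer_norm", True)
#   self.subln = (
#       LlamaRMSNorm(hidden_size=self.head_dim, eps=1e-5)
#       if sublayer_norm
#       else nn.Identity()
#   )
#
# Passing sublayer_norm=False therefore swaps the per-head RMSNorm for an
# identity op, which the --no-sublayer-norm reproduction path relies on.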
+ model.config.sublayer_norm = sublayer_norm + def convert_module(module): nonlocal layer_idx # Iterate through module children, convert any attn layers to diff attn for name, child in module.named_children(): - if isinstance(child, attention_patterns): - layer_type = type(child).__name__ - + if isinstance(child, tuple(ATTENTION_MAPPING.keys())): # Choose appropriate differential attention class - if isinstance(child, LlamaSdpaAttention): - attention_class = LlamaDifferentialSdpaAttention - else: - attention_class = LlamaDifferentialAttention + attention_class = ATTENTION_MAPPING[type(child)] + layer_type = type(child).__name__ logger.info( f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}" ) diff --git a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py b/src/axolotl/integrations/differential_transformer/differential_attention.py similarity index 64% rename from src/axolotl/integrations/diff_transformer/multihead_diffattn.py rename to src/axolotl/integrations/differential_transformer/differential_attention.py index 4735564452..2046f08bcf 100644 --- a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py +++ b/src/axolotl/integrations/differential_transformer/differential_attention.py @@ -7,9 +7,11 @@ import torch import torch.nn.functional as F import transformers +from flash_attn.flash_attn_interface import flash_attn_func from torch import nn from transformers.cache_utils import Cache from transformers.models.llama.modeling_llama import ( + LlamaRMSNorm, LlamaRotaryEmbedding, apply_rotary_pos_emb, ) @@ -75,14 +77,11 @@ def __init__( self.rope_theta = config.rope_theta self.is_causal = True - dtype = torch.float32 - # For Q1 and Q2 self.q_proj = nn.Linear( self.hidden_size, self.hidden_size * 2, bias=False, - dtype=dtype, ) # For K1 and K2 @@ -90,7 +89,6 @@ def __init__( self.hidden_size, self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2, bias=False, - dtype=dtype, ) # Single V projection @@ -98,7 +96,6 @@ def __init__( self.hidden_size, self.hidden_size // self.base_num_heads * self.base_num_kv_heads, bias=False, - dtype=dtype, ) # Output projection @@ -106,28 +103,33 @@ def __init__( self.hidden_size, self.hidden_size, bias=False, - dtype=dtype, ) # Initialize differential attention parameters self.lambda_init = nn.Parameter( - torch.full((), lambda_init_fn(self.layer_idx), dtype=dtype), + torch.full((), lambda_init_fn(self.layer_idx)), requires_grad=False, ) self.lambda_q1 = nn.Parameter( - torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1) + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) ) self.lambda_k1 = nn.Parameter( - torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1) + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) ) self.lambda_q2 = nn.Parameter( - torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1) + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) ) self.lambda_k2 = nn.Parameter( - torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1) + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) ) self.rotary_emb = LlamaRotaryEmbedding(config=config) + sublayer_norm = getattr(config, "sublayer_norm", True) + self.subln = ( + LlamaRMSNorm(hidden_size=self.head_dim, eps=1e-5) + if sublayer_norm + else nn.Identity() + ) def forward( self, @@ -192,39 +194,21 @@ def forward( # Calculate attention scores for both parts # NOTE(Dan): the Differential Transformers paper scales by a constant scaling factor # instead of sqrt(head_dim). 
This could be set on the class as `self.scaling`. - attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt( - self.head_dim - ) - attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt( - self.head_dim - ) - - # Add this debug step right after computing attention weights in the forward pass - attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt( - self.head_dim - ) - attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt( - self.head_dim - ) + attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim) + attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim) if attention_mask is not None: causal_mask = attention_mask[:, :, :, : k1.shape[-2]] - attn_weights1 = attn_weights1 + causal_mask - attn_weights2 = attn_weights2 + causal_mask + attn1 = attn1 + causal_mask + attn2 = attn2 + causal_mask - # Apply softmax separately as per paper - attn_weights1 = F.softmax(attn_weights1, dim=-1, dtype=torch.float32).type_as( - attn_weights1 - ) - attn_weights2 = F.softmax(attn_weights2, dim=-1, dtype=torch.float32).type_as( - attn_weights2 - ) - attn_weights1 = F.dropout( - attn_weights1, p=self.attention_dropout, training=self.training - ) - attn_weights2 = F.dropout( - attn_weights2, p=self.attention_dropout, training=self.training - ) + # Apply softmax + attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1) + attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2) + + # Apply dropout + attn1 = F.dropout(attn1, p=self.attention_dropout, training=self.training) + attn2 = F.dropout(attn2, p=self.attention_dropout, training=self.training) # Calculate lambda lambda_1 = torch.exp( @@ -236,15 +220,13 @@ def forward( lambda_full = lambda_1 - lambda_2 + self.lambda_init # Compute differential attention (following paper's formula) - attn_weights = attn_weights1 - lambda_full * attn_weights2 + attn_weights = attn1 - lambda_full * attn2 # Apply attention weights to values attn = torch.matmul(attn_weights, v) # Apply sublayer norm and scaling - # NOTE(Dan): The differential transformers paper applies sublayer normalization at this - # point, but this is typically done outside of the attention layer. It would look something - # like: `attn = self.subln(attn).type_as(attn)`, using `LlamaRMSNorm` or similar. 
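# Hedged worked example (annotation, not a hunk of this patch): why zero-init
# plus a disabled sublayer norm reproduces the original model. With
# lambda_q1/k1/q2/k2 and lambda_init all zeroed by
# copy_attention_weights(..., zero_init=True):
#
#   lambda_1    = exp(sum(lambda_q1 * lambda_k1)) = exp(0) = 1
#   lambda_2    = exp(sum(lambda_q2 * lambda_k2)) = exp(0) = 1
#   lambda_full = lambda_1 - lambda_2 + lambda_init = 1 - 1 + 0 = 0
#
# so attn = attn1 - lambda_full * attn2 == attn1, self.subln is an identity,
# and attn * (1 - lambda_init) == attn. The converted layer reduces to standard
# attention over the copied Q1/K1/V/O weights, which is what the e2e tests
# later in this series assert via debug_info["generations_match"].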
+ attn = self.subln(attn) attn = attn * (1 - self.lambda_init) # Reshape to output @@ -368,20 +350,21 @@ def forward( # Calculate attention using SDPA is_causal = attention_mask is None and q_len > 1 - attn_output1 = F.scaled_dot_product_attention( + dropout_p = self.attention_dropout if self.training else 0.0 + attn1 = F.scaled_dot_product_attention( q1, k1, v, attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, + dropout_p=dropout_p, is_causal=is_causal, ) - attn_output2 = F.scaled_dot_product_attention( + attn2 = F.scaled_dot_product_attention( q2, k2, v, attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, + dropout_p=dropout_p, is_causal=is_causal, ) @@ -395,9 +378,10 @@ def forward( lambda_full = lambda_1 - lambda_2 + self.lambda_init # Combine the attention outputs - attn = attn_output1 - lambda_full * attn_output2 + attn = attn1 - lambda_full * attn2 # Apply sublayer norm and scaling + attn = self.subln(attn) attn = attn * (1 - self.lambda_init) # Reshape to output @@ -411,3 +395,157 @@ def forward( past_key_value, ) # Note: can't return attn_weights with SDPA return attn, None, past_key_value + + +class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention): + """Differential Attention implementation using Flash Attention 2. + This implements the same logic as `LlamaDifferentialAttention`, but uses + Flash Attention 2 for more efficient computation. + + This implements a modified attention mechanism that computes the difference between + two attention patterns, scaled by learned lambda parameters. The mechanism helps + reduce noise in the attention weights for irrelevant / less relevant tokens. + + Key components: + - Split head dimension for differential computation + - Learned lambda parameters that control attention scaling + - Sublayer normalization on the attention output + - Flash Attention 2 for efficient attention computation + + See: + - https://arxiv.org/abs/2410.05258 + - https://github.com/microsoft/unilm/tree/master/Diff-Transformer + + Args: + config: Model configuration object containing hidden size, number of heads etc. + layer_idx: Index of this layer in the transformer stack + dtype: Data type for the layer parameters + """ + + def forward( + self, + hidden_states: torch.Tensor, # [bsz, seq_len, hidden_size] + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[tuple[torch.Tensor, torch.Tensor]], + ]: + if output_attentions: + transformers.logger.warning_once( + "LlamaModel is using LlamaFlashAttention, but Flash Attention does not support `output_attentions=True`. " + "Falling back to the manual attention implementation." 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + # Project to Q1,Q2 and K1,K2 + qp = self.q_proj(hidden_states) + kp = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Split into Q1,Q2 and K1,K2 + q1, q2 = qp.chunk(2, dim=-1) + k1, k2 = kp.chunk(2, dim=-1) + + # Reshape Q1,Q2 for attention + q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2) + q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2) + # Reshape K1,K2 for attention + k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) + k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) + # Reshape V + v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) + + # Apply rotary embeddings + if position_embeddings is None: + if position_ids is None: + position_ids = torch.arange(q_len, device=q1.device) + cos, sin = self.rotary_emb(q1, position_ids) + else: + cos, sin = position_embeddings + + q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) + q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k = torch.stack([k1, k2], dim=1) + k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) + k1, k2 = k.unbind(dim=1) + + # Repeat KV heads to match Q heads + k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads) + k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads) + v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads) + + q1 = q1.transpose(1, 2) + q2 = q2.transpose(1, 2) + k1 = k1.transpose(1, 2) + k2 = k2.transpose(1, 2) + v = v.transpose(1, 2) + + # Calculate attention using Flash Attention + dropout_p = self.attention_dropout if self.training else 0.0 + attn1 = flash_attn_func( + q1, + k1, + v, + dropout_p=dropout_p, + causal=True, + ) + attn2 = flash_attn_func( + q2, + k2, + v, + dropout_p=dropout_p, + causal=True, + ) + + attn1 = attn1.transpose(1, 2) + attn2 = attn2.transpose(1, 2) + + # Calculate lambda + lambda_1 = torch.exp( + torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float() + ).type_as(q1) + lambda_2 = torch.exp( + torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() + ).type_as(q1) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + + # Combine the attention outputs + attn = attn1 - lambda_full * attn2 + + # Apply sublayer norm and scaling + attn = self.subln(attn) + attn = attn * (1 - self.lambda_init) + + # Reshape to output + attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) + attn = self.o_proj(attn) + + if output_attentions: + return ( + attn, + None, + past_key_value, + ) # Note: can't return attn_weights with Flash Attention + return attn, None, past_key_value diff --git a/src/axolotl/monkeypatch/attention/differential.py b/src/axolotl/monkeypatch/attention/differential.py index 037a6f0bd2..36e3821af6 100644 --- a/src/axolotl/monkeypatch/attention/differential.py +++ b/src/axolotl/monkeypatch/attention/differential.py @@ -3,8 +3,9 @@ from transformers import PreTrainedModel from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES -from axolotl.integrations.diff_transformer.multihead_diffattn import 
( +from axolotl.integrations.differential_transformer.differential_attention import ( LlamaDifferentialAttention, + LlamaDifferentialFlashAttention2, LlamaDifferentialSdpaAttention, ) @@ -15,6 +16,9 @@ def patch_llama_attention_classes(): # Add our attention class to the registry LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention + LLAMA_ATTENTION_CLASSES[ + "differential_flash_attention_2" + ] = LlamaDifferentialFlashAttention2 @classmethod def new_autoset(_, config, **kwargs): # pylint: disable=unused-argument @@ -28,6 +32,7 @@ def new_autoset(_, config, **kwargs): # pylint: disable=unused-argument "flash_attention_2", "differential_eager", "differential_sdpa", + "differential_flash_attention_2", ] if attn_implementation not in valid_impls: message = ( From 12d14cce2769582d604ad8502b1129ee80b37bca Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 17 Dec 2024 18:54:49 +0000 Subject: [PATCH 12/25] duplicate code ignore --- .../differential_transformer/differential_attention.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/axolotl/integrations/differential_transformer/differential_attention.py b/src/axolotl/integrations/differential_transformer/differential_attention.py index 2046f08bcf..1543981ea3 100644 --- a/src/axolotl/integrations/differential_transformer/differential_attention.py +++ b/src/axolotl/integrations/differential_transformer/differential_attention.py @@ -262,6 +262,7 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention): dtype: Data type for the layer parameters """ + # pylint: disable=duplicate-code def forward( self, hidden_states: torch.Tensor, # [bsz, seq_len, hidden_size] @@ -284,8 +285,8 @@ def forward( 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
) return super().forward( - hidden_states=hidden_states, # pylint: disable=duplicate-code - attention_mask=attention_mask, # pylint: disable=duplicate-code + hidden_states=hidden_states, + attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, @@ -422,6 +423,7 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention): dtype: Data type for the layer parameters """ + # pylint: disable=duplicate-code def forward( self, hidden_states: torch.Tensor, # [bsz, seq_len, hidden_size] From a6b5a5e6c9c931a2ff93a5d96a0bccea5411fed3 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 17 Dec 2024 20:46:19 +0000 Subject: [PATCH 13/25] convert-differential-transformer test coverage --- .../integrations/convert_diff_transformer.py | 185 ----------------- .../convert_differential_transformer.py | 190 ++++++++++++++++++ src/axolotl/cli/main.py | 4 +- src/axolotl/common/cli.py | 2 +- .../differential_transformer/convert.py | 1 - tests/cli/integrations/__init__.py | 0 ...st_cli_convert_differential_transformer.py | 48 +++++ .../test_convert_differential_transformer.py | 127 ++++++++++++ 8 files changed, 368 insertions(+), 189 deletions(-) delete mode 100644 src/axolotl/cli/integrations/convert_diff_transformer.py create mode 100644 src/axolotl/cli/integrations/convert_differential_transformer.py create mode 100644 tests/cli/integrations/__init__.py create mode 100644 tests/cli/integrations/test_cli_convert_differential_transformer.py create mode 100644 tests/e2e/integrations/test_convert_differential_transformer.py diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py deleted file mode 100644 index 6eb00452b1..0000000000 --- a/src/axolotl/cli/integrations/convert_diff_transformer.py +++ /dev/null @@ -1,185 +0,0 @@ -"""CLI to convert a transformers model's attns to diff attns.""" -import logging -import warnings -from pathlib import Path -from time import time -from typing import Union - -import fire -import torch -import yaml -from colorama import Fore -from dotenv import load_dotenv -from transformers import HfArgumentParser - -from axolotl.cli import load_cfg, print_axolotl_text_art -from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer -from axolotl.integrations.differential_transformer.convert import ( - convert_to_diff_attention, -) - -LOG = logging.getLogger("axolotl.cli.convert_attention") - - -def test_inference(model, tokenizer, prompt="The quick brown fox"): - """Run test inference and return generation time""" - try: - inputs = tokenizer(prompt, return_tensors="pt") - inputs = { - k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items() - } - - start = time() - with torch.no_grad(): - outputs = model.generate( - **inputs, - max_new_tokens=20, - num_beams=1, - do_sample=False, - pad_token_id=tokenizer.pad_token_id, - use_cache=False, - ) - elapsed = time() - start - - generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) - LOG.info("Prompt: %s", prompt) - LOG.info("Generated: %s", generated_text) - LOG.info("Generation time: %.2fs", elapsed) - - return elapsed, generated_text - - except Exception as exc: - LOG.error("Inference failed: %s", str(exc)) - raise - - -def convert_diff_transformer(cfg, cli_args, config_path): - try: - # Load model and tokenizer - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - model, tokenizer = load_model_and_tokenizer(cfg=cfg, 
cli_args=cli_args) - model.to(cfg.device, dtype=cfg.torch_dtype) - - # Log original model info - LOG.info( - "Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d", - model.config.hidden_size, - model.config.num_attention_heads, - ) - - # Test original model - if cli_args.debug: - LOG.info("Testing original model...") - orig_time, orig_text = test_inference(model, tokenizer) - - # Convert attention - LOG.info("Converting to differential attention...") - try: - model = convert_to_diff_attention( - model=model, - zero_init=cli_args.zero_init, - sublayer_norm=cli_args.sublayer_norm, - ) - model.to(cfg.device, dtype=cfg.torch_dtype) - except Exception as exc: - LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc)) - raise - - # Test converted model - if cli_args.debug: - LOG.info("Testing converted model...") - conv_time, conv_text = test_inference(model, tokenizer) - - # Save if requested - if cfg.output_dir: - # Save model and tokenizer - LOG.info("Saving converted model to %s", cfg.output_dir) - model.save_pretrained(cfg.output_dir) - tokenizer.save_pretrained(cfg.output_dir) - - # Modify config to reflect new path / differential attention - output_config_path = Path(cfg.output_dir) / "axolotl_config.yml" - LOG.info("Saving updated config to %s", output_config_path) - - with open(config_path, "r", encoding="utf-8") as file: - data = yaml.safe_load(file) or {} - - data["base_model"] = cfg.output_dir - data["diff_attention"] = True - - with open(output_config_path, "w", encoding="utf-8") as file: - yaml.dump(data, file) - else: - LOG.info("Not saving converted model to disk") - LOG.info("Pass --output-dir path/to/save to save model") - - if cli_args.debug: - LOG.info( - Fore.GREEN - + "Conversion successful!\n" - + f"Original generation time: {orig_time:.2f}s\n" - + f"Converted generation time: {conv_time:.2f}s" - + Fore.RESET - ) - - if orig_text == conv_text: - LOG.info( - Fore.GREEN - + "Generations match!\n" - + "Model generation:\n" - + "*" * 50 - + "\n" - + f"{orig_text}\n" - + "*" * 50 - + "\n" - + Fore.RESET - ) - else: - message = ( - "Generations do not match.\n" - + "Original generation:\n" - + "*" * 50 - + "\n" - + f"{orig_text}\n" - + "*" * 50 - + "\n" - + "Converted generation:\n" - + "*" * 50 - + "\n" - + f"{conv_text}\n" - + "*" * 50 - + "\n" - ) - - if cli_args.zero_init and not cli_args.sublayer_norm: - LOG.info(Fore.RED + message + Fore.RESET) - else: - LOG.info( - Fore.YELLOW - + message - + "However, this is expected since --zero-init" - + " and --no-sublayer-norm were not passed." 
- + Fore.RESET - ) - - return model - - except Exception as exc: - LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc)) - raise - - -def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): - print_axolotl_text_art() - - cfg = load_cfg(config, **kwargs) - parser = HfArgumentParser(ConvertDiffTransformerCliArgs) - cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) - - convert_diff_transformer(cfg, cli_args, config) - - -if __name__ == "__main__": - load_dotenv() - fire.Fire(do_cli) diff --git a/src/axolotl/cli/integrations/convert_differential_transformer.py b/src/axolotl/cli/integrations/convert_differential_transformer.py new file mode 100644 index 0000000000..8903da6d1e --- /dev/null +++ b/src/axolotl/cli/integrations/convert_differential_transformer.py @@ -0,0 +1,190 @@ +"""CLI to convert a transformers model's attns to diff attns.""" +import logging +import warnings +from pathlib import Path +from time import time +from typing import Union + +import fire +import torch +import yaml +from colorama import Fore +from dotenv import load_dotenv +from transformers import HfArgumentParser + +from axolotl.cli import load_cfg, print_axolotl_text_art +from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer +from axolotl.integrations.differential_transformer.convert import ( + convert_to_diff_attention, +) + +LOG = logging.getLogger(__name__) + + +def test_inference(model, tokenizer, prompt="The quick brown fox"): + """Run test inference and return generation time""" + try: + inputs = tokenizer(prompt, return_tensors="pt") + inputs = { + k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items() + } + + start = time() + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=20, + num_beams=1, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + use_cache=False, + ) + elapsed = time() - start + + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + LOG.info("Prompt: %s", prompt) + LOG.info("Generated: %s", generated_text) + LOG.info("Generation time: %.2fs", elapsed) + + return elapsed, generated_text + + except Exception as exc: + LOG.error("Inference failed: %s", str(exc)) + raise + + +def convert_differential_transformer(cfg, cli_args, config_path): + debug_info = {} + + # Load model and tokenizer + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) + model.to(cfg.device, dtype=cfg.torch_dtype) + + # Log original model info + LOG.info( + "Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d", + model.config.hidden_size, + model.config.num_attention_heads, + ) + + # Test original model + if cli_args.debug: + LOG.info("Testing original model...") + debug_info["orig_time"], debug_info["orig_text"] = test_inference( + model, tokenizer + ) + + # Convert attention + LOG.info("Converting to differential attention...") + try: + model = convert_to_diff_attention( + model=model, + zero_init=cli_args.zero_init, + sublayer_norm=cli_args.sublayer_norm, + ) + model.to(cfg.device, dtype=cfg.torch_dtype) + except Exception as exc: + LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc)) + raise + + # Test converted model + if cli_args.debug: + LOG.info("Testing converted model...") + debug_info["conv_time"], debug_info["conv_text"] = test_inference( + model, tokenizer + ) + + # Save if requested + if cfg.output_dir: + # Save model and 
tokenizer + LOG.info("Saving converted model to %s", cfg.output_dir) + model.save_pretrained(cfg.output_dir) + tokenizer.save_pretrained(cfg.output_dir) + + # Modify config to reflect new path / differential attention + output_config_path = Path(cfg.output_dir) / "axolotl_config.yml" + LOG.info("Saving updated config to %s", output_config_path) + + with open(config_path, "r", encoding="utf-8") as file: + data = yaml.safe_load(file) or {} + + data["base_model"] = cfg.output_dir + data["diff_attention"] = True + + with open(output_config_path, "w", encoding="utf-8") as file: + yaml.dump(data, file) + else: + LOG.info("Not saving converted model to disk") + LOG.info("Pass --output-dir path/to/save to save model") + + if cli_args.debug: + LOG.info( + Fore.GREEN + + "Conversion successful!\n" + + f"Original generation time: {debug_info['orig_time']:.2f}s\n" + + f"Converted generation time: {debug_info['conv_time']:.2f}s" + + Fore.RESET + ) + + if debug_info["orig_text"] == debug_info["conv_text"]: + LOG.info( + Fore.GREEN + + "Generations match!\n" + + "Model generation:\n" + + "*" * 50 + + "\n" + + f"{debug_info['orig_text']}\n" + + "*" * 50 + + "\n" + + Fore.RESET + ) + debug_info["generations_match"] = True + else: + message = ( + "Generations do not match.\n" + + "Original generation:\n" + + "*" * 50 + + "\n" + + f"{debug_info['orig_text']}\n" + + "*" * 50 + + "\n" + + "Converted generation:\n" + + "*" * 50 + + "\n" + + f"{debug_info['conv_text']}\n" + + "*" * 50 + + "\n" + ) + debug_info["generations_match"] = False + + if cli_args.zero_init and not cli_args.sublayer_norm: + LOG.info(Fore.RED + message + Fore.RESET) + debug_info["match_expected"] = True + else: + LOG.info( + Fore.YELLOW + + message + + "However, this is expected since --zero-init" + + " and --no-sublayer-norm were not passed." 
+ + Fore.RESET + ) + debug_info["match_expected"] = False + + return model, debug_info + + +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): + print_axolotl_text_art() + + cfg = load_cfg(config, **kwargs) + parser = HfArgumentParser(ConvertDiffTransformerCliArgs) + cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) + + convert_differential_transformer(cfg, cli_args, config) + + +if __name__ == "__main__": + load_dotenv() + fire.Fire(do_cli) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index e922d37e99..c3f55dc026 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -249,11 +249,11 @@ def merge_lora( @click.argument("config", type=click.Path(exists=True, path_type=str)) @add_options_from_dataclass(ConvertDiffTransformerCliArgs) @add_options_from_config(AxolotlInputConfig) -def convert_diff_transformer(config: str, **kwargs): +def convert_differential_transformer(config: str, **kwargs): """Convert model attention layers to differential attention layers.""" kwargs = {k: v for k, v in kwargs.items() if v is not None} - from axolotl.cli.integrations.convert_diff_transformer import do_cli + from axolotl.cli.integrations.convert_differential_transformer import do_cli do_cli(config=config, **kwargs) diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index 9c921e5640..2d6a5bb31d 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -57,7 +57,7 @@ class EvaluateCliArgs: @dataclass class ConvertDiffTransformerCliArgs: """ - dataclass with arguments for convert-diff-transformer CLI + dataclass with arguments for convert-differential-transformer CLI """ debug: bool = field(default=False) diff --git a/src/axolotl/integrations/differential_transformer/convert.py b/src/axolotl/integrations/differential_transformer/convert.py index 5620ad1995..ce3773037a 100644 --- a/src/axolotl/integrations/differential_transformer/convert.py +++ b/src/axolotl/integrations/differential_transformer/convert.py @@ -17,7 +17,6 @@ LlamaDifferentialSdpaAttention, ) -logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) ATTENTION_MAPPING = { diff --git a/tests/cli/integrations/__init__.py b/tests/cli/integrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/cli/integrations/test_cli_convert_differential_transformer.py b/tests/cli/integrations/test_cli_convert_differential_transformer.py new file mode 100644 index 0000000000..cd2a464c60 --- /dev/null +++ b/tests/cli/integrations/test_cli_convert_differential_transformer.py @@ -0,0 +1,48 @@ +"""Tests for convert-differential-transformer CLI command.""" + +from pathlib import Path +from unittest.mock import patch + +from axolotl.cli.main import cli + + +def test_cli_validation(cli_runner): + """Test CLI validation for a command. + + Args: + cli_runner: CLI runner fixture + """ + # Test missing config file + result = cli_runner.invoke(cli, ["convert-differential-transformer"]) + assert result.exit_code != 0 + assert "Error: Missing argument 'CONFIG'." in result.output + + # Test non-existent config file + result = cli_runner.invoke( + cli, ["convert-differential-transformer", "nonexistent.yml"] + ) + assert result.exit_code != 0 + assert "Error: Invalid value for 'CONFIG'" in result.output + + +def test_basic_execution(cli_runner, tmp_path: Path, valid_test_config: str): + """Test basic execution. 
+ + Args: + cli_runner: CLI runner fixture + tmp_path: Temporary path fixture + valid_test_config: Valid config fixture + """ + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch( + "axolotl.cli.integrations.convert_differential_transformer.do_cli" + ) as mock_do_cli: + result = cli_runner.invoke( + cli, ["convert-differential-transformer", str(config_path)] + ) + assert result.exit_code == 0 + + mock_do_cli.assert_called_once() + assert mock_do_cli.call_args.kwargs["config"] == str(config_path) diff --git a/tests/e2e/integrations/test_convert_differential_transformer.py b/tests/e2e/integrations/test_convert_differential_transformer.py new file mode 100644 index 0000000000..da3aac11ad --- /dev/null +++ b/tests/e2e/integrations/test_convert_differential_transformer.py @@ -0,0 +1,127 @@ +"""End-to-end tests for differential transformer conversion.""" +# pylint: disable=redefined-outer-name + +from pathlib import Path +from typing import Optional + +import pytest +import yaml + +from axolotl.cli import load_cfg +from axolotl.cli.integrations.convert_differential_transformer import ( + convert_differential_transformer, +) +from axolotl.common.cli import ConvertDiffTransformerCliArgs + + +@pytest.fixture() +def base_config(): + """Basic config for testing.""" + return { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "gradient_accumulation_steps": 1, + "learning_rate": 1e-4, + "val_set_size": 0.1, + "micro_batch_size": 1, + "sequence_len": 2048, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + } + + +def test_conversion_cli_basic(tmp_path: Path, base_config): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + # Load config the same way do_cli does + cfg = load_cfg(str(config_path)) + + # Create CLI args + cli_args = ConvertDiffTransformerCliArgs() + + # Call convert_differential_transformer directly + _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + + assert not debug_info + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() + + +def test_conversion_cli_debug(tmp_path: Path, base_config): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + # Load config the same way do_cli does + cfg = load_cfg(str(config_path)) + + # Create CLI args + cli_args = ConvertDiffTransformerCliArgs(debug=True) + + # Call convert_differential_transformer directly + _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + + assert not debug_info["generations_match"] + assert not debug_info["match_expected"] + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() + + +def test_conversion_cli_reproduce(tmp_path: Path, base_config): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = 
ConvertDiffTransformerCliArgs( + debug=True, zero_init=True, sublayer_norm=False + ) + _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + + assert debug_info["generations_match"] is True + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() + + +@pytest.mark.parametrize("attention", ["sdp_attention", "flash_attention"]) +def test_conversion_cli_repoduce_attentions( + tmp_path: Path, base_config, attention: Optional[str] +): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + base_config[attention] = True + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs( + debug=True, zero_init=True, sublayer_norm=False + ) + _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + + assert debug_info["generations_match"] is True + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() From 6a9af88bd0eb582843f9f048b7ed6523604ae789 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 18 Dec 2024 01:26:41 +0000 Subject: [PATCH 14/25] plugin implementation --- .gitignore | 3 ++ model-out/eval_summary.csv | 6 --- outputs | 1 - src/axolotl/cli/evaluate.py | 6 +-- .../convert_differential_transformer.py | 9 ++-- src/axolotl/integrations/config.py | 3 ++ .../differential_transformer/README.md | 10 ++++ .../differential_transformer/__init__.py | 25 +++++++++ .../differential_transformer/args.py | 14 +++++ .../differential_transformer/convert.py | 2 +- .../config/models/input/v0_4_1/__init__.py | 2 - src/axolotl/utils/models.py | 15 ++---- .../test_convert_differential_transformer.py | 52 +++++++++++++++++-- 13 files changed, 118 insertions(+), 30 deletions(-) delete mode 100644 model-out/eval_summary.csv delete mode 120000 outputs create mode 100644 src/axolotl/integrations/differential_transformer/README.md create mode 100644 src/axolotl/integrations/differential_transformer/args.py diff --git a/.gitignore b/.gitignore index 7b604d88c7..4d7ba15a1b 100644 --- a/.gitignore +++ b/.gitignore @@ -186,3 +186,6 @@ out/ # vim *.swp + +# symlinked to axolotl-artifacts in docker containers +outputs diff --git a/model-out/eval_summary.csv b/model-out/eval_summary.csv deleted file mode 100644 index ccbe73358c..0000000000 --- a/model-out/eval_summary.csv +++ /dev/null @@ -1,6 +0,0 @@ -metric,training,validation -loss,1.8773103952407837,1.915901780128479 -model_preparation_time,0.0051,0.0051 -runtime,89.7635,8.9565 -samples_per_second,20.053,22.33 -steps_per_second,20.053,22.33 diff --git a/outputs b/outputs deleted file mode 120000 index be3c4a823f..0000000000 --- a/outputs +++ /dev/null @@ -1 +0,0 @@ -/workspace/data/axolotl-artifacts \ No newline at end of file diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index 8e99d6f4b1..655f3782fd 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -3,7 +3,7 @@ """ import logging from pathlib import Path -from typing import Union +from typing import Dict, Union import fire from dotenv import load_dotenv @@ -23,7 +23,7 @@ LOG = logging.getLogger("axolotl.cli.evaluate") -def do_evaluate(cfg, cli_args) -> None: +def do_evaluate(cfg, cli_args) -> Dict[str, float]: # pylint: disable=duplicate-code 
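+    # Returns the metrics dict produced by `evaluate` below, matching the
+    # Dict[str, float] annotation added to the signature above.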
print_axolotl_text_art() check_accelerate_default_config() @@ -34,7 +34,7 @@ def do_evaluate(cfg, cli_args) -> None: else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + return evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: diff --git a/src/axolotl/cli/integrations/convert_differential_transformer.py b/src/axolotl/cli/integrations/convert_differential_transformer.py index 8903da6d1e..a687a3f7cb 100644 --- a/src/axolotl/cli/integrations/convert_differential_transformer.py +++ b/src/axolotl/cli/integrations/convert_differential_transformer.py @@ -15,7 +15,7 @@ from axolotl.cli import load_cfg, print_axolotl_text_art from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer from axolotl.integrations.differential_transformer.convert import ( - convert_to_diff_attention, + convert_to_differential_attention, ) LOG = logging.getLogger(__name__) @@ -79,7 +79,7 @@ def convert_differential_transformer(cfg, cli_args, config_path): # Convert attention LOG.info("Converting to differential attention...") try: - model = convert_to_diff_attention( + model = convert_to_differential_attention( model=model, zero_init=cli_args.zero_init, sublayer_norm=cli_args.sublayer_norm, @@ -111,7 +111,10 @@ def convert_differential_transformer(cfg, cli_args, config_path): data = yaml.safe_load(file) or {} data["base_model"] = cfg.output_dir - data["diff_attention"] = True + data["differential_attention"] = True + data["plugins"] = [ + "axolotl.integrations.differential_transformer.DifferentialTransformerPlugin" + ] with open(output_config_path, "w", encoding="utf-8") as file: yaml.dump(data, file) diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py index b4ffd6758f..f7d35fcf89 100644 --- a/src/axolotl/integrations/config.py +++ b/src/axolotl/integrations/config.py @@ -43,10 +43,12 @@ def merge_input_args(): input_args: List[str] = plugin_manager.get_input_args() plugin_classes = [] dynamic_input = "" + for plugin_args in input_args: plugin_module, plugin_cls = plugin_args.rsplit(".", 1) dynamic_input += f"from {plugin_module} import {plugin_cls}\n" plugin_classes.append(plugin_cls) + if dynamic_input: dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n" dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n" @@ -62,4 +64,5 @@ def merge_input_args(): "AxolotlConfigWCapabilities" ] return AxolotlConfigWCapabilities, AxolotlInputConfig + return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase diff --git a/src/axolotl/integrations/differential_transformer/README.md b/src/axolotl/integrations/differential_transformer/README.md new file mode 100644 index 0000000000..f7bd74cbdb --- /dev/null +++ b/src/axolotl/integrations/differential_transformer/README.md @@ -0,0 +1,10 @@ +# Differential Transformer + +### Usage + +```yaml +plugins: + - axolotl.integrations.differential_transformer.DifferentialTransformerPlugin + +differential_attention: true +``` diff --git a/src/axolotl/integrations/differential_transformer/__init__.py b/src/axolotl/integrations/differential_transformer/__init__.py index e69de29bb2..63741793c4 100644 --- a/src/axolotl/integrations/differential_transformer/__init__.py +++ b/src/axolotl/integrations/differential_transformer/__init__.py @@ -0,0 +1,25 @@ 
+"""Definition of differential transformer plugin.""" + +import logging + +from axolotl.integrations.base import BasePlugin + +LOG = logging.getLogger(__name__) + + +class DifferentialTransformerPlugin(BasePlugin): + """ + Plugin for differential transformer integration with Axolotl. + """ + + def get_input_args(self): + return "axolotl.integrations.differential_transformer.args.DifferentialTransformerArgs" + + def pre_model_load(self, cfg): + """Apply differential attention patch before model loading if enabled.""" + if cfg.differential_attention: + from axolotl.monkeypatch.attention.differential import ( + patch_llama_attention_classes, + ) + + patch_llama_attention_classes() diff --git a/src/axolotl/integrations/differential_transformer/args.py b/src/axolotl/integrations/differential_transformer/args.py new file mode 100644 index 0000000000..bd6e01520f --- /dev/null +++ b/src/axolotl/integrations/differential_transformer/args.py @@ -0,0 +1,14 @@ +"""Module for handling differential transfomer input arguments.""" + +import logging +from typing import Optional + +from pydantic import BaseModel + +LOG = logging.getLogger(__name__) + + +class DifferentialTransformerArgs(BaseModel): + """Input args for differential transformer.""" + + differential_attention: Optional[bool] = None diff --git a/src/axolotl/integrations/differential_transformer/convert.py b/src/axolotl/integrations/differential_transformer/convert.py index ce3773037a..d516f94768 100644 --- a/src/axolotl/integrations/differential_transformer/convert.py +++ b/src/axolotl/integrations/differential_transformer/convert.py @@ -80,7 +80,7 @@ def copy_attention_weights( ) -def convert_to_diff_attention( +def convert_to_differential_attention( model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True ) -> PreTrainedModel: """Convert a pre-trained model's attention layers to differential attention""" diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index d98cb3cb68..0781c67989 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -726,8 +726,6 @@ class Config: eager_attention: Optional[bool] = None - diff_attention: Optional[bool] = None - unsloth_cross_entropy_loss: Optional[bool] = None unsloth_lora_mlp: Optional[bool] = None unsloth_lora_qkv: Optional[bool] = None diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index e98e9f31b4..8c8bd0e38f 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -444,13 +444,6 @@ def apply_patches(self) -> None: patch_mistral_cross_entropy() - if self.cfg.diff_attention: - from axolotl.monkeypatch.attention.differential import ( - patch_llama_attention_classes, - ) - - patch_llama_attention_classes() - def patch_attention(self) -> None: if hasattr(self.model_config, "model_type"): if self.model_config.model_type == "mllama" and self.cfg.flash_attention: @@ -721,7 +714,7 @@ def set_attention_config(self) -> None: if not self.cfg.sample_packing and self.cfg.s2_attention: pass - if self.cfg.diff_attention: + if self.cfg.differential_attention: self.model_kwargs[ "attn_implementation" ] = "differential_flash_attention_2" @@ -734,7 +727,7 @@ def set_attention_config(self) -> None: "flash_attention_2" ) elif self.cfg.sdp_attention: - if self.cfg.diff_attention: + if self.cfg.differential_attention: self.model_kwargs["attn_implementation"] = "differential_sdpa" 
self.model_config._attn_implementation = ( # pylint: disable=protected-access "differential_sdpa" @@ -745,7 +738,7 @@ def set_attention_config(self) -> None: "sdpa" ) elif self.cfg.eager_attention: - if self.cfg.diff_attention: + if self.cfg.differential_attention: self.model_kwargs["attn_implementation"] = "differential_eager" self.model_config._attn_implementation = ( # pylint: disable=protected-access "differential_eager" @@ -755,7 +748,7 @@ def set_attention_config(self) -> None: self.model_config._attn_implementation = ( # pylint: disable=protected-access "eager" ) - elif self.cfg.diff_attention: + elif self.cfg.differential_attention: self.model_kwargs["attn_implementation"] = "differential_eager" self.model_config._attn_implementation = ( # pylint: disable=protected-access "differential_eager" diff --git a/tests/e2e/integrations/test_convert_differential_transformer.py b/tests/e2e/integrations/test_convert_differential_transformer.py index da3aac11ad..9ddcf57674 100644 --- a/tests/e2e/integrations/test_convert_differential_transformer.py +++ b/tests/e2e/integrations/test_convert_differential_transformer.py @@ -6,12 +6,14 @@ import pytest import yaml +from pytest import approx from axolotl.cli import load_cfg +from axolotl.cli.evaluate import do_evaluate from axolotl.cli.integrations.convert_differential_transformer import ( convert_differential_transformer, ) -from axolotl.common.cli import ConvertDiffTransformerCliArgs +from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs @pytest.fixture() @@ -19,9 +21,12 @@ def base_config(): """Basic config for testing.""" return { "base_model": "HuggingFaceTB/SmolLM2-135M", + "plugins": [ + "axolotl.integrations.differential_transformer.DifferentialTransformerPlugin", + ], "datasets": [ { - "path": "mhenrichsen/alpaca_2k_test", + "path": "axolotl-ai-co/alpaca_100_test", "type": "alpaca", }, ], @@ -103,7 +108,9 @@ def test_conversion_cli_reproduce(tmp_path: Path, base_config): assert (output_dir / "axolotl_config.yml").exists() -@pytest.mark.parametrize("attention", ["sdp_attention", "flash_attention"]) +@pytest.mark.parametrize( + "attention", ["eager_attention", "sdp_attention", "flash_attention"] +) def test_conversion_cli_repoduce_attentions( tmp_path: Path, base_config, attention: Optional[str] ): @@ -125,3 +132,42 @@ def test_conversion_cli_repoduce_attentions( assert (output_dir / "model.safetensors").exists() assert (output_dir / "config.json").exists() assert (output_dir / "axolotl_config.yml").exists() + + +def test_conversion_and_eval_cli(tmp_path: Path, base_config): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs( + debug=True, zero_init=True, sublayer_norm=False + ) + _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + + assert debug_info["generations_match"] is True + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() + + eval_cfg = load_cfg(str(output_dir)) + eval_cli_args = EvaluateCliArgs() + all_metrics = do_evaluate(eval_cfg, eval_cli_args) + + assert list(all_metrics.keys()) == [ + "train_loss", + "train_model_preparation_time", + "train_runtime", + "train_samples_per_second", + "train_steps_per_second", + "eval_loss", + 
"eval_model_preparation_time", + "eval_runtime", + "eval_samples_per_second", + "eval_steps_per_second", + ] + assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4) + assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4) From b7294d43ce58eaa29476cc87229dee66fb8f40f2 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 18 Dec 2024 01:38:56 +0000 Subject: [PATCH 15/25] fixes post-rebase --- src/axolotl/cli/main.py | 3 +++ src/axolotl/evaluate.py | 8 ++------ src/axolotl/train.py | 4 ++-- src/axolotl/utils/trainer.py | 11 ----------- 4 files changed, 7 insertions(+), 19 deletions(-) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index c3f55dc026..c7549bd4a0 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -82,6 +82,9 @@ def evaluate(config: str, accelerate: bool, **kwargs): """Evaluate a model.""" kwargs = {k: v for k, v in kwargs.items() if v is not None} + # Enable expandable segments for cuda allocation to improve VRAM usage + set_pytorch_cuda_alloc_conf() + if accelerate: base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"] if config: diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py index bc1799960c..1c62fc6ab7 100644 --- a/src/axolotl/evaluate.py +++ b/src/axolotl/evaluate.py @@ -12,10 +12,9 @@ from axolotl.common.cli import EvaluateCliArgs, load_model_and_tokenizer from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta -from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_processor -from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer +from axolotl.utils.trainer import setup_trainer project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") @@ -62,6 +61,7 @@ def evaluate_dataset( return metrics +# pylint: disable=duplicate-code def evaluate( *, cfg: DictDefault, cli_args: EvaluateCliArgs, dataset_meta: TrainDatasetMeta ) -> Dict[str, float]: @@ -79,10 +79,6 @@ def evaluate( - The tokenizer - Dictionary of evaluation metrics """ - # pylint: disable=duplicate-code - # Enable expandable segments for cuda allocation to improve VRAM usage - set_pytorch_cuda_alloc_conf() - # Load model LOG.debug("loading model for evaluation...") diff --git a/src/axolotl/train.py b/src/axolotl/train.py index c0aac93244..a74ecc2ec3 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -27,7 +27,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.models import load_model, load_processor, load_tokenizer -from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer +from axolotl.utils.trainer import setup_trainer try: from optimum.bettertransformer import BetterTransformer @@ -88,7 +88,7 @@ def train( ) resume_from_checkpoint = cfg.resume_from_checkpoint - # Load the model + # Load the model and tokenizer msg = "loading model" if cfg.adapter: msg += " and peft_config..." 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index fd09b3eb67..32e54c9a86 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -512,17 +512,6 @@ def prepare_opinionated_env(cfg): os.environ["TOKENIZERS_PARALLELISM"] = "false" -def set_pytorch_cuda_alloc_conf(): - """Set up CUDA allocation config if using PyTorch >= 2.2""" - torch_version = torch.__version__.split(".") - torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) - if torch_major == 2 and torch_minor >= 2: - if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None: - os.environ[ - "PYTORCH_CUDA_ALLOC_CONF" - ] = "expandable_segments:True,roundup_power2_divisions:16" - - def setup_trainer( cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps ): From c57d21e5c96349f4a199d35233d3d12cd44c490c Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 18 Dec 2024 03:30:35 +0000 Subject: [PATCH 16/25] isolating problematic test --- cicd/multigpu.py | 4 +- .../monkeypatch/attention/differential.py | 1 - src/axolotl/utils/models.py | 1 - .../__init__.py | 0 .../conftest.py | 28 ++++++++ .../test_convert_and_evaluate.py | 53 ++++++++++++++ .../test_convert_differential_transformer.py | 69 +------------------ 7 files changed, 85 insertions(+), 71 deletions(-) create mode 100644 tests/e2e/integrations/convert_differential_transformer/__init__.py create mode 100644 tests/e2e/integrations/convert_differential_transformer/conftest.py create mode 100644 tests/e2e/integrations/convert_differential_transformer/test_convert_and_evaluate.py rename tests/e2e/integrations/{ => convert_differential_transformer}/test_convert_differential_transformer.py (62%) diff --git a/cicd/multigpu.py b/cicd/multigpu.py index 0ea4c8cc11..511e31c8e5 100644 --- a/cicd/multigpu.py +++ b/cicd/multigpu.py @@ -1,6 +1,6 @@ """ - modal application to run axolotl gpu tests in Modal - """ +modal application to run axolotl gpu tests in Modal +""" # pylint: disable=duplicate-code import os diff --git a/src/axolotl/monkeypatch/attention/differential.py b/src/axolotl/monkeypatch/attention/differential.py index 36e3821af6..a07b629b6b 100644 --- a/src/axolotl/monkeypatch/attention/differential.py +++ b/src/axolotl/monkeypatch/attention/differential.py @@ -12,7 +12,6 @@ def patch_llama_attention_classes(): """Patch transformers to support differential attention""" - # Add our attention class to the registry LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8c8bd0e38f..6eaa020da0 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -843,7 +843,6 @@ def _configure_zero3_memory_efficient_loading(): if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config - # self.model._attn_implementation_autoset = False self.model = self.AutoModelLoader.from_pretrained( self.base_model, config=self.model_config, diff --git a/tests/e2e/integrations/convert_differential_transformer/__init__.py b/tests/e2e/integrations/convert_differential_transformer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/e2e/integrations/convert_differential_transformer/conftest.py b/tests/e2e/integrations/convert_differential_transformer/conftest.py new file mode 100644 index 0000000000..17a424ddbe --- /dev/null +++ 
b/tests/e2e/integrations/convert_differential_transformer/conftest.py @@ -0,0 +1,28 @@ +"""Shared fixtures for differential transformer conversion tests.""" + +import pytest + + +@pytest.fixture() +def base_config(): + """Basic config for testing.""" + return { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "plugins": [ + "axolotl.integrations.differential_transformer.DifferentialTransformerPlugin", + ], + "datasets": [ + { + "path": "axolotl-ai-co/alpaca_100_test", + "type": "alpaca", + }, + ], + "gradient_accumulation_steps": 1, + "learning_rate": 1e-4, + "val_set_size": 0.1, + "micro_batch_size": 1, + "sequence_len": 2048, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + } diff --git a/tests/e2e/integrations/convert_differential_transformer/test_convert_and_evaluate.py b/tests/e2e/integrations/convert_differential_transformer/test_convert_and_evaluate.py new file mode 100644 index 0000000000..1cf569693c --- /dev/null +++ b/tests/e2e/integrations/convert_differential_transformer/test_convert_and_evaluate.py @@ -0,0 +1,53 @@ +"""End-to-end tests for differential transformer conversion and evaluation.""" +# pylint: disable=duplicate-code + +from pathlib import Path + +import yaml +from pytest import approx + +from axolotl.cli import load_cfg +from axolotl.cli.evaluate import do_evaluate +from axolotl.cli.integrations.convert_differential_transformer import ( + convert_differential_transformer, +) +from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs + + +def test_conversion_and_eval_cli(tmp_path: Path, base_config): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs( + debug=True, zero_init=True, sublayer_norm=False + ) + _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + + assert debug_info["generations_match"] is True + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() + + eval_cfg = load_cfg(str(output_dir)) + eval_cli_args = EvaluateCliArgs() + all_metrics = do_evaluate(eval_cfg, eval_cli_args) + + assert list(all_metrics.keys()) == [ + "train_loss", + "train_model_preparation_time", + "train_runtime", + "train_samples_per_second", + "train_steps_per_second", + "eval_loss", + "eval_model_preparation_time", + "eval_runtime", + "eval_samples_per_second", + "eval_steps_per_second", + ] + assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4) + assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4) diff --git a/tests/e2e/integrations/test_convert_differential_transformer.py b/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py similarity index 62% rename from tests/e2e/integrations/test_convert_differential_transformer.py rename to tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py index 9ddcf57674..4349287bdc 100644 --- a/tests/e2e/integrations/test_convert_differential_transformer.py +++ b/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py @@ -1,44 +1,18 @@ """End-to-end tests for differential transformer conversion.""" # pylint: disable=redefined-outer-name +# pylint: disable=duplicate-code from pathlib import Path from typing import 
Optional import pytest import yaml -from pytest import approx from axolotl.cli import load_cfg -from axolotl.cli.evaluate import do_evaluate from axolotl.cli.integrations.convert_differential_transformer import ( convert_differential_transformer, ) -from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs - - -@pytest.fixture() -def base_config(): - """Basic config for testing.""" - return { - "base_model": "HuggingFaceTB/SmolLM2-135M", - "plugins": [ - "axolotl.integrations.differential_transformer.DifferentialTransformerPlugin", - ], - "datasets": [ - { - "path": "axolotl-ai-co/alpaca_100_test", - "type": "alpaca", - }, - ], - "gradient_accumulation_steps": 1, - "learning_rate": 1e-4, - "val_set_size": 0.1, - "micro_batch_size": 1, - "sequence_len": 2048, - "special_tokens": { - "pad_token": "<|endoftext|>", - }, - } +from axolotl.common.cli import ConvertDiffTransformerCliArgs def test_conversion_cli_basic(tmp_path: Path, base_config): @@ -132,42 +106,3 @@ def test_conversion_cli_repoduce_attentions( assert (output_dir / "model.safetensors").exists() assert (output_dir / "config.json").exists() assert (output_dir / "axolotl_config.yml").exists() - - -def test_conversion_and_eval_cli(tmp_path: Path, base_config): - output_dir = tmp_path / "converted" - base_config["output_dir"] = str(output_dir) - - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) - - cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs( - debug=True, zero_init=True, sublayer_norm=False - ) - _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) - - assert debug_info["generations_match"] is True - assert (output_dir / "model.safetensors").exists() - assert (output_dir / "config.json").exists() - assert (output_dir / "axolotl_config.yml").exists() - - eval_cfg = load_cfg(str(output_dir)) - eval_cli_args = EvaluateCliArgs() - all_metrics = do_evaluate(eval_cfg, eval_cli_args) - - assert list(all_metrics.keys()) == [ - "train_loss", - "train_model_preparation_time", - "train_runtime", - "train_samples_per_second", - "train_steps_per_second", - "eval_loss", - "eval_model_preparation_time", - "eval_runtime", - "eval_samples_per_second", - "eval_steps_per_second", - ] - assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4) - assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4) From 513b262b4f545bcecadce30df06f387422b6d3b5 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 18 Dec 2024 05:56:29 +0000 Subject: [PATCH 17/25] adding split_heads argument for retaining original (Q, K) dimensionanlity --- .../convert_differential_transformer.py | 14 +- src/axolotl/common/cli.py | 1 + .../differential_transformer/convert.py | 11 +- .../differential_attention.py | 174 +++++++++++++----- .../test_convert_differential_transformer.py | 23 +++ 5 files changed, 173 insertions(+), 50 deletions(-) diff --git a/src/axolotl/cli/integrations/convert_differential_transformer.py b/src/axolotl/cli/integrations/convert_differential_transformer.py index a687a3f7cb..b50dd43dd8 100644 --- a/src/axolotl/cli/integrations/convert_differential_transformer.py +++ b/src/axolotl/cli/integrations/convert_differential_transformer.py @@ -14,9 +14,7 @@ from axolotl.cli import load_cfg, print_axolotl_text_art from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer -from axolotl.integrations.differential_transformer.convert import ( - convert_to_differential_attention, -) 
+from axolotl.integrations.differential_transformer.convert import convert_to_diff_attn LOG = logging.getLogger(__name__) @@ -78,11 +76,19 @@ def convert_differential_transformer(cfg, cli_args, config_path): # Convert attention LOG.info("Converting to differential attention...") + if cli_args.split_heads and cli_args.zero_init: + LOG.warning( + Fore.YELLOW + + "Warning: Using split_heads with zero_init is not recommended; " + + "split_heads will preclude the effects of zero_init" + + Fore.RESET + ) try: - model = convert_to_differential_attention( + model = convert_to_diff_attn( model=model, zero_init=cli_args.zero_init, sublayer_norm=cli_args.sublayer_norm, + split_heads=cli_args.split_heads, ) model.to(cfg.device, dtype=cfg.torch_dtype) except Exception as exc: diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index 2d6a5bb31d..c51c4e2ab9 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -63,6 +63,7 @@ class ConvertDiffTransformerCliArgs: debug: bool = field(default=False) zero_init: bool = field(default=False) sublayer_norm: bool = field(default=True) + split_heads: bool = field(default=False) def load_model_and_tokenizer( diff --git a/src/axolotl/integrations/differential_transformer/convert.py b/src/axolotl/integrations/differential_transformer/convert.py index d516f94768..4beaea7ae9 100644 --- a/src/axolotl/integrations/differential_transformer/convert.py +++ b/src/axolotl/integrations/differential_transformer/convert.py @@ -80,14 +80,18 @@ def copy_attention_weights( ) -def convert_to_differential_attention( - model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True +def convert_to_diff_attn( + model: PreTrainedModel, + zero_init: bool = False, + sublayer_norm: bool = True, + split_heads: bool = True, ) -> PreTrainedModel: """Convert a pre-trained model's attention layers to differential attention""" layer_idx = 0 # Set sublayer norm as config on the model. 
model.config.sublayer_norm = sublayer_norm + model.config.split_heads = split_heads def convert_module(module): nonlocal layer_idx @@ -111,7 +115,8 @@ def convert_module(module): # Copy weights from old attention to new attention new_attention.to(child.q_proj.weight.device) - copy_attention_weights(child, new_attention, zero_init=zero_init) + if not split_heads: + copy_attention_weights(child, new_attention, zero_init=zero_init) # Replace the layer setattr(module, name, new_attention) diff --git a/src/axolotl/integrations/differential_transformer/differential_attention.py b/src/axolotl/integrations/differential_transformer/differential_attention.py index 1543981ea3..58d4b94ecb 100644 --- a/src/axolotl/integrations/differential_transformer/differential_attention.py +++ b/src/axolotl/integrations/differential_transformer/differential_attention.py @@ -70,26 +70,51 @@ def __init__( self.hidden_size = config.hidden_size self.base_num_heads = config.num_attention_heads self.base_num_kv_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // config.num_attention_heads + + if config.split_heads: + self.head_dim = config.hidden_size // config.num_attention_heads // 2 + else: + self.head_dim = config.hidden_size // config.num_attention_heads self.layer_idx = layer_idx self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True - - # For Q1 and Q2 - self.q_proj = nn.Linear( - self.hidden_size, - self.hidden_size * 2, - bias=False, - ) - - # For K1 and K2 - self.k_proj = nn.Linear( - self.hidden_size, - self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2, - bias=False, - ) + self.split_heads = config.split_heads + + if config.split_heads: + # Split heads mode + assert ( + self.base_num_heads % 2 == 0 + ), "Number of heads must be even for splitting" + self.heads_per_component = self.base_num_heads // 2 + + # Single projections + self.q_proj = nn.Linear( + self.hidden_size, + self.hidden_size, + bias=False, + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.hidden_size // self.base_num_heads * self.base_num_kv_heads, + bias=False, + ) + else: + # Double projection mode + self.heads_per_component = self.base_num_heads + + # Double-sized projections + self.q_proj = nn.Linear( + self.hidden_size, + self.hidden_size * 2, + bias=False, + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2, + bias=False, + ) # Single V projection self.v_proj = nn.Linear( @@ -125,8 +150,14 @@ def __init__( self.rotary_emb = LlamaRotaryEmbedding(config=config) sublayer_norm = getattr(config, "sublayer_norm", True) + + if self.split_heads: + subln_dim = 2 * self.head_dim + else: + subln_dim = self.head_dim + self.subln = ( - LlamaRMSNorm(hidden_size=self.head_dim, eps=1e-5) + LlamaRMSNorm(hidden_size=subln_dim, eps=1e-5) if sublayer_norm else nn.Identity() ) @@ -167,7 +198,10 @@ def forward( k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) # Reshape V - v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + if self.split_heads: + v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2) + else: + v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) # Apply rotary embeddings if position_embeddings is None: @@ -177,6 +211,10 @@ def forward( else: cos, sin = position_embeddings + if self.split_heads: + cos, _ = cos.chunk(2, dim=2) + sin, _ = sin.chunk(2, dim=2) + q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) q2, k2 = 
apply_rotary_pos_emb(q2, k2, cos, sin) @@ -192,8 +230,6 @@ def forward( v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads) # Calculate attention scores for both parts - # NOTE(Dan): the Differential Transformers paper scales by a constant scaling factor - # instead of sqrt(head_dim). This could be set on the class as `self.scaling`. attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim) attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim) @@ -307,13 +343,18 @@ def forward( k1, k2 = kp.chunk(2, dim=-1) # Reshape Q1,Q2 for attention - q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2) - q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2) + q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + # Reshape K1,K2 for attention - k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) - k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) + k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + # Reshape V - v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) + if self.split_heads: + v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2) + else: + v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) # Apply rotary embeddings if position_embeddings is None: @@ -323,6 +364,10 @@ def forward( else: cos, sin = position_embeddings + if self.split_heads: + cos, _ = cos.chunk(2, dim=2) + sin, _ = sin.chunk(2, dim=2) + q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) @@ -468,13 +513,18 @@ def forward( k1, k2 = kp.chunk(2, dim=-1) # Reshape Q1,Q2 for attention - q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2) - q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2) + q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + # Reshape K1,K2 for attention - k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) - k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) + k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + # Reshape V - v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) + if self.split_heads: + v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2) + else: + v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) # Apply rotary embeddings if position_embeddings is None: @@ -484,6 +534,10 @@ def forward( else: cos, sin = position_embeddings + if self.split_heads: + cos, _ = cos.chunk(2, dim=2) + sin, _ = sin.chunk(2, dim=2) + q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) @@ -506,20 +560,54 @@ def forward( # Calculate attention using Flash Attention dropout_p = self.attention_dropout if self.training else 0.0 - attn1 = flash_attn_func( - q1, - k1, - v, - dropout_p=dropout_p, - causal=True, - ) - attn2 = flash_attn_func( - q2, - k2, - v, - dropout_p=dropout_p, - causal=True, - ) + if self.split_heads: + v1, v2 = v.chunk(2, dim=-1) + attn11 = flash_attn_func( + q1, + k1, + v1, + dropout_p=dropout_p, + causal=True, + ) + attn12 = flash_attn_func( + q1, + k1, + v2, + dropout_p=dropout_p, + causal=True, + ) + 
attn1 = torch.cat([attn11, attn12], dim=-1) + + attn21 = flash_attn_func( + q2, + k2, + v1, + dropout_p=dropout_p, + causal=True, + ) + attn22 = flash_attn_func( + q2, + k2, + v2, + dropout_p=dropout_p, + causal=True, + ) + attn2 = torch.cat([attn21, attn22], dim=-1) + else: + attn1 = flash_attn_func( + q1, + k1, + v, + dropout_p=dropout_p, + causal=True, + ) + attn2 = flash_attn_func( + q2, + k2, + v, + dropout_p=dropout_p, + causal=True, + ) attn1 = attn1.transpose(1, 2) attn2 = attn2.transpose(1, 2) diff --git a/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py b/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py index 4349287bdc..84e5fdaa15 100644 --- a/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py +++ b/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py @@ -106,3 +106,26 @@ def test_conversion_cli_repoduce_attentions( assert (output_dir / "model.safetensors").exists() assert (output_dir / "config.json").exists() assert (output_dir / "axolotl_config.yml").exists() + + +@pytest.mark.parametrize( + "attention", ["eager_attention", "sdp_attention", "flash_attention"] +) +def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str): + output_dir = tmp_path / "converted" + base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B" + base_config["output_dir"] = str(output_dir) + base_config[attention] = True + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True) + _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + + assert debug_info["generations_match"] is False + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() From 313265fef4cba44aabbc061e0a4798b197b44b79 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 18 Dec 2024 19:36:23 +0000 Subject: [PATCH 18/25] moving tests around for flash_attn install --- .../differential_attention.py | 6 +-- tests/cli/conftest.py | 1 + tests/cli/integrations/__init__.py | 0 ...st_cli_convert_differential_transformer.py | 48 ------------------- tests/cli/test_cli_fetch.py | 1 + tests/cli/test_cli_inference.py | 1 + tests/cli/test_cli_interface.py | 1 + tests/cli/test_cli_merge_lora.py | 1 + .../test_cli_merge_sharded_fsdp_weights.py | 1 + tests/cli/test_cli_preprocess.py | 1 + tests/cli/test_cli_shard.py | 1 + tests/cli/test_cli_version.py | 1 + tests/cli/test_utils.py | 1 + .../conftest.py | 6 +++ .../test_convert_differential_transformer.py | 34 ++++++++++++- 15 files changed, 52 insertions(+), 52 deletions(-) delete mode 100644 tests/cli/integrations/__init__.py delete mode 100644 tests/cli/integrations/test_cli_convert_differential_transformer.py diff --git a/src/axolotl/integrations/differential_transformer/differential_attention.py b/src/axolotl/integrations/differential_transformer/differential_attention.py index 58d4b94ecb..af7473436c 100644 --- a/src/axolotl/integrations/differential_transformer/differential_attention.py +++ b/src/axolotl/integrations/differential_transformer/differential_attention.py @@ -84,9 +84,9 @@ def __init__( if config.split_heads: # Split heads mode - assert ( - self.base_num_heads % 2 == 0 - ), "Number 
of heads must be even for splitting" + # assert ( + # self.base_num_heads % 2 == 0 + # ), "Number of heads must be even for splitting" self.heads_per_component = self.base_num_heads // 2 # Single projections diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py index 78b090e19e..d360e29d6b 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -1,4 +1,5 @@ """Shared pytest fixtures for cli module.""" + import pytest from click.testing import CliRunner diff --git a/tests/cli/integrations/__init__.py b/tests/cli/integrations/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/cli/integrations/test_cli_convert_differential_transformer.py b/tests/cli/integrations/test_cli_convert_differential_transformer.py deleted file mode 100644 index cd2a464c60..0000000000 --- a/tests/cli/integrations/test_cli_convert_differential_transformer.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Tests for convert-differential-transformer CLI command.""" - -from pathlib import Path -from unittest.mock import patch - -from axolotl.cli.main import cli - - -def test_cli_validation(cli_runner): - """Test CLI validation for a command. - - Args: - cli_runner: CLI runner fixture - """ - # Test missing config file - result = cli_runner.invoke(cli, ["convert-differential-transformer"]) - assert result.exit_code != 0 - assert "Error: Missing argument 'CONFIG'." in result.output - - # Test non-existent config file - result = cli_runner.invoke( - cli, ["convert-differential-transformer", "nonexistent.yml"] - ) - assert result.exit_code != 0 - assert "Error: Invalid value for 'CONFIG'" in result.output - - -def test_basic_execution(cli_runner, tmp_path: Path, valid_test_config: str): - """Test basic execution. - - Args: - cli_runner: CLI runner fixture - tmp_path: Temporary path fixture - valid_test_config: Valid config fixture - """ - config_path = tmp_path / "config.yml" - config_path.write_text(valid_test_config) - - with patch( - "axolotl.cli.integrations.convert_differential_transformer.do_cli" - ) as mock_do_cli: - result = cli_runner.invoke( - cli, ["convert-differential-transformer", str(config_path)] - ) - assert result.exit_code == 0 - - mock_do_cli.assert_called_once() - assert mock_do_cli.call_args.kwargs["config"] == str(config_path) diff --git a/tests/cli/test_cli_fetch.py b/tests/cli/test_cli_fetch.py index 0df87b0299..f06f067173 100644 --- a/tests/cli/test_cli_fetch.py +++ b/tests/cli/test_cli_fetch.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI fetch command.""" + from unittest.mock import patch from axolotl.cli.main import fetch diff --git a/tests/cli/test_cli_inference.py b/tests/cli/test_cli_inference.py index 7cb163d255..b8effa3d20 100644 --- a/tests/cli/test_cli_inference.py +++ b/tests/cli/test_cli_inference.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI inference command.""" + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_interface.py b/tests/cli/test_cli_interface.py index ed8335b766..8b5fec17f2 100644 --- a/tests/cli/test_cli_interface.py +++ b/tests/cli/test_cli_interface.py @@ -1,4 +1,5 @@ """General pytest tests for axolotl.cli.main interface.""" + from axolotl.cli.main import build_command, cli diff --git a/tests/cli/test_cli_merge_lora.py b/tests/cli/test_cli_merge_lora.py index 165a64e98c..aac0167603 100644 --- a/tests/cli/test_cli_merge_lora.py +++ b/tests/cli/test_cli_merge_lora.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI merge_lora command.""" + from unittest.mock import patch from 
axolotl.cli.main import cli diff --git a/tests/cli/test_cli_merge_sharded_fsdp_weights.py b/tests/cli/test_cli_merge_sharded_fsdp_weights.py index cff0f3b773..420c28b9e8 100644 --- a/tests/cli/test_cli_merge_sharded_fsdp_weights.py +++ b/tests/cli/test_cli_merge_sharded_fsdp_weights.py @@ -1,5 +1,6 @@ """pytest tests for axolotl CLI merge_sharded_fsdp_weights command.""" # pylint: disable=duplicate-code + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_preprocess.py b/tests/cli/test_cli_preprocess.py index 4719461aaf..e2dd3a6c35 100644 --- a/tests/cli/test_cli_preprocess.py +++ b/tests/cli/test_cli_preprocess.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI preprocess command.""" + import shutil from pathlib import Path from unittest.mock import patch diff --git a/tests/cli/test_cli_shard.py b/tests/cli/test_cli_shard.py index 505a2a7372..3176ed27ee 100644 --- a/tests/cli/test_cli_shard.py +++ b/tests/cli/test_cli_shard.py @@ -1,5 +1,6 @@ """pytest tests for axolotl CLI shard command.""" # pylint: disable=duplicate-code + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_version.py b/tests/cli/test_cli_version.py index 819780e945..533dd5c0ec 100644 --- a/tests/cli/test_cli_version.py +++ b/tests/cli/test_cli_version.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI --version""" + from axolotl.cli.main import cli diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index b88e4ac729..ecb0025e44 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -1,5 +1,6 @@ """pytest tests for axolotl CLI utils.""" # pylint: disable=redefined-outer-name + import json from unittest.mock import Mock, patch diff --git a/tests/e2e/integrations/convert_differential_transformer/conftest.py b/tests/e2e/integrations/convert_differential_transformer/conftest.py index 17a424ddbe..ed1eb3f363 100644 --- a/tests/e2e/integrations/convert_differential_transformer/conftest.py +++ b/tests/e2e/integrations/convert_differential_transformer/conftest.py @@ -1,6 +1,7 @@ """Shared fixtures for differential transformer conversion tests.""" import pytest +from click.testing import CliRunner @pytest.fixture() @@ -26,3 +27,8 @@ def base_config(): "pad_token": "<|endoftext|>", }, } + + +@pytest.fixture +def cli_runner(): + return CliRunner() diff --git a/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py b/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py index 84e5fdaa15..42ce3e6127 100644 --- a/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py +++ b/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Optional +from unittest.mock import patch import pytest import yaml @@ -12,9 +13,41 @@ from axolotl.cli.integrations.convert_differential_transformer import ( convert_differential_transformer, ) +from axolotl.cli.main import cli from axolotl.common.cli import ConvertDiffTransformerCliArgs +def test_cli_validation(cli_runner): + # Test missing config file + result = cli_runner.invoke(cli, ["convert-differential-transformer"]) + assert result.exit_code != 0 + assert "Error: Missing argument 'CONFIG'." 
in result.output + + # Test non-existent config file + result = cli_runner.invoke( + cli, ["convert-differential-transformer", "nonexistent.yml"] + ) + assert result.exit_code != 0 + assert "Error: Invalid value for 'CONFIG'" in result.output + + +def test_basic_execution(cli_runner, tmp_path: Path, base_config): + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + with patch( + "axolotl.cli.integrations.convert_differential_transformer.do_cli" + ) as mock_do_cli: + result = cli_runner.invoke( + cli, ["convert-differential-transformer", str(config_path)] + ) + assert result.exit_code == 0 + + mock_do_cli.assert_called_once() + assert mock_do_cli.call_args.kwargs["config"] == str(config_path) + + def test_conversion_cli_basic(tmp_path: Path, base_config): output_dir = tmp_path / "converted" base_config["output_dir"] = str(output_dir) @@ -113,7 +146,6 @@ def test_conversion_cli_repoduce_attentions( ) def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str): output_dir = tmp_path / "converted" - base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B" base_config["output_dir"] = str(output_dir) base_config[attention] = True From 53b4d80e558d92832536c8e675f5ff5fd1cbcd51 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 19 Dec 2024 02:51:39 +0000 Subject: [PATCH 19/25] removing extra pytest xdist args --- cicd/cicd.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/cicd/cicd.sh b/cicd/cicd.sh index 91926127fb..b01846e6e6 100755 --- a/cicd/cicd.sh +++ b/cicd/cicd.sh @@ -4,7 +4,6 @@ set -e python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__" pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ -# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ From 9262124cfb81cc99685a9a3dd0c746fcf5c3bc8c Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 20 Dec 2024 20:39:40 +0000 Subject: [PATCH 20/25] adding yaml dumper preserving input config format --- ...sformer.py => convert_diff_transformer.py} | 30 +- src/axolotl/cli/main.py | 4 +- src/axolotl/common/cli.py | 2 +- .../integrations/diff_transformer/README.md | 10 + .../__init__.py | 4 +- .../args.py | 2 +- .../convert.py | 2 +- .../diff_transformer/diff_attn.py | 375 ++++++++++ .../differential_transformer/README.md | 10 - .../differential_attention.py | 641 ------------------ .../monkeypatch/attention/differential.py | 2 +- src/axolotl/utils/models.py | 8 +- src/axolotl/utils/yaml.py | 151 +++++ .../__init__.py | 0 .../conftest.py | 3 - .../test_convert_and_evaluate.py | 6 +- .../test_convert_diff_transformer.py} | 36 +- 17 files changed, 579 insertions(+), 707 deletions(-) rename src/axolotl/cli/integrations/{convert_differential_transformer.py => convert_diff_transformer.py} (87%) create mode 100644 src/axolotl/integrations/diff_transformer/README.md rename src/axolotl/integrations/{differential_transformer => diff_transformer}/__init__.py (81%) rename src/axolotl/integrations/{differential_transformer => diff_transformer}/args.py (84%) rename src/axolotl/integrations/{differential_transformer => diff_transformer}/convert.py (99%) create mode 100644 
src/axolotl/integrations/diff_transformer/diff_attn.py delete mode 100644 src/axolotl/integrations/differential_transformer/README.md delete mode 100644 src/axolotl/integrations/differential_transformer/differential_attention.py create mode 100644 src/axolotl/utils/yaml.py rename tests/e2e/integrations/{convert_differential_transformer => convert_diff_transformer}/__init__.py (100%) rename tests/e2e/integrations/{convert_differential_transformer => convert_diff_transformer}/conftest.py (85%) rename tests/e2e/integrations/{convert_differential_transformer => convert_diff_transformer}/test_convert_and_evaluate.py (89%) rename tests/e2e/integrations/{convert_differential_transformer/test_convert_differential_transformer.py => convert_diff_transformer/test_convert_diff_transformer.py} (79%) diff --git a/src/axolotl/cli/integrations/convert_differential_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py similarity index 87% rename from src/axolotl/cli/integrations/convert_differential_transformer.py rename to src/axolotl/cli/integrations/convert_diff_transformer.py index b50dd43dd8..d91278fed0 100644 --- a/src/axolotl/cli/integrations/convert_differential_transformer.py +++ b/src/axolotl/cli/integrations/convert_diff_transformer.py @@ -14,7 +14,8 @@ from axolotl.cli import load_cfg, print_axolotl_text_art from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer -from axolotl.integrations.differential_transformer.convert import convert_to_diff_attn +from axolotl.integrations.diff_transformer.convert import convert_to_diff_attn +from axolotl.utils.yaml import dump_yaml_preserved_order LOG = logging.getLogger(__name__) @@ -51,7 +52,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"): raise -def convert_differential_transformer(cfg, cli_args, config_path): +def convert_diff_transformer(cfg, cli_args, config_path): debug_info = {} # Load model and tokenizer @@ -114,16 +115,23 @@ def convert_differential_transformer(cfg, cli_args, config_path): LOG.info("Saving updated config to %s", output_config_path) with open(config_path, "r", encoding="utf-8") as file: - data = yaml.safe_load(file) or {} + modified_cfg = yaml.safe_load(file) or {} - data["base_model"] = cfg.output_dir - data["differential_attention"] = True - data["plugins"] = [ - "axolotl.integrations.differential_transformer.DifferentialTransformerPlugin" - ] + modified_cfg["base_model"] = cfg.output_dir + modified_cfg["diff_attention"] = True + plugin_class = ( + "axolotl.integrations.diff_transformer.DifferentialTransformerPlugin" + ) + if "plugins" in modified_cfg: + modified_cfg["plugins"].append(plugin_class) + else: + modified_cfg["plugins"] = [plugin_class] - with open(output_config_path, "w", encoding="utf-8") as file: - yaml.dump(data, file) + dump_yaml_preserved_order( + data=modified_cfg, + reference_yaml_path=config_path, + output_path=output_config_path, + ) else: LOG.info("Not saving converted model to disk") LOG.info("Pass --output-dir path/to/save to save model") @@ -191,7 +199,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): parser = HfArgumentParser(ConvertDiffTransformerCliArgs) cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) - convert_differential_transformer(cfg, cli_args, config) + convert_diff_transformer(cfg, cli_args, config) if __name__ == "__main__": diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index c7549bd4a0..d9d3a21354 100644 --- a/src/axolotl/cli/main.py +++ 
b/src/axolotl/cli/main.py @@ -252,11 +252,11 @@ def merge_lora( @click.argument("config", type=click.Path(exists=True, path_type=str)) @add_options_from_dataclass(ConvertDiffTransformerCliArgs) @add_options_from_config(AxolotlInputConfig) -def convert_differential_transformer(config: str, **kwargs): +def convert_diff_transformer(config: str, **kwargs): """Convert model attention layers to differential attention layers.""" kwargs = {k: v for k, v in kwargs.items() if v is not None} - from axolotl.cli.integrations.convert_differential_transformer import do_cli + from axolotl.cli.integrations.convert_diff_transformer import do_cli do_cli(config=config, **kwargs) diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index c51c4e2ab9..ea3b91c0c2 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -57,7 +57,7 @@ class EvaluateCliArgs: @dataclass class ConvertDiffTransformerCliArgs: """ - dataclass with arguments for convert-differential-transformer CLI + dataclass with arguments for convert-diff-transformer CLI """ debug: bool = field(default=False) diff --git a/src/axolotl/integrations/diff_transformer/README.md b/src/axolotl/integrations/diff_transformer/README.md new file mode 100644 index 0000000000..14473f7537 --- /dev/null +++ b/src/axolotl/integrations/diff_transformer/README.md @@ -0,0 +1,10 @@ +# Differential Transformer + +### Usage + +```yaml +plugins: + - axolotl.integrations.diff_transformer.DifferentialTransformerPlugin + +diff_attention: true +``` diff --git a/src/axolotl/integrations/differential_transformer/__init__.py b/src/axolotl/integrations/diff_transformer/__init__.py similarity index 81% rename from src/axolotl/integrations/differential_transformer/__init__.py rename to src/axolotl/integrations/diff_transformer/__init__.py index 63741793c4..70459e0266 100644 --- a/src/axolotl/integrations/differential_transformer/__init__.py +++ b/src/axolotl/integrations/diff_transformer/__init__.py @@ -13,11 +13,11 @@ class DifferentialTransformerPlugin(BasePlugin): """ def get_input_args(self): - return "axolotl.integrations.differential_transformer.args.DifferentialTransformerArgs" + return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs" def pre_model_load(self, cfg): """Apply differential attention patch before model loading if enabled.""" - if cfg.differential_attention: + if cfg.diff_attention: from axolotl.monkeypatch.attention.differential import ( patch_llama_attention_classes, ) diff --git a/src/axolotl/integrations/differential_transformer/args.py b/src/axolotl/integrations/diff_transformer/args.py similarity index 84% rename from src/axolotl/integrations/differential_transformer/args.py rename to src/axolotl/integrations/diff_transformer/args.py index bd6e01520f..47c1fe1104 100644 --- a/src/axolotl/integrations/differential_transformer/args.py +++ b/src/axolotl/integrations/diff_transformer/args.py @@ -11,4 +11,4 @@ class DifferentialTransformerArgs(BaseModel): """Input args for differential transformer.""" - differential_attention: Optional[bool] = None + diff_attention: Optional[bool] = None diff --git a/src/axolotl/integrations/differential_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py similarity index 99% rename from src/axolotl/integrations/differential_transformer/convert.py rename to src/axolotl/integrations/diff_transformer/convert.py index 4beaea7ae9..5c10f2137a 100644 --- a/src/axolotl/integrations/differential_transformer/convert.py +++ 
b/src/axolotl/integrations/diff_transformer/convert.py @@ -11,7 +11,7 @@ LlamaSdpaAttention, ) -from .differential_attention import ( +from .diff_attn import ( LlamaDifferentialAttention, LlamaDifferentialFlashAttention2, LlamaDifferentialSdpaAttention, diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py new file mode 100644 index 0000000000..edf532c418 --- /dev/null +++ b/src/axolotl/integrations/diff_transformer/diff_attn.py @@ -0,0 +1,375 @@ +"""Re-implemention of differential attention.""" +# pylint: disable=invalid-name + +import logging +import math +from typing import Any, Optional, Tuple + +import torch +import torch.nn.functional as F +from flash_attn.flash_attn_interface import flash_attn_func +from torch import nn +from transformers.cache_utils import Cache +from transformers.models.llama.modeling_llama import ( + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + batch_size, n_kv_heads, slen, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, None, :, :] + .expand(batch_size, n_kv_heads, n_rep, slen, head_dim) + .reshape(batch_size, n_kv_heads * n_rep, slen, head_dim) + ) + + +def lambda_init_fn(depth): + return 0.8 - 0.6 * math.exp(-0.3 * depth) + + +class DifferentialAttentionBase(nn.Module): + """Base class for differential attention implementations.""" + + def __init__(self, config: Any, layer_idx: int): + super().__init__() + self._init_config(config, layer_idx) + self._init_projections() + self._init_differential_params() + self._init_normalization(config) + + def _init_config(self, config: Any, layer_idx: int): + """Initialize configuration parameters.""" + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.base_num_heads = config.num_attention_heads + self.base_num_kv_heads = config.num_key_value_heads + self.layer_idx = layer_idx + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.split_heads = config.split_heads + + if config.split_heads: + # Split heads mode - single projections + self.head_dim = config.hidden_size // config.num_attention_heads // 2 + # NOTE: This rounds down `base_num_heads / 2` as opposed to the original + # implementation, which asserts `self.base_num_heads` is even. 
+ self.heads_per_component = self.base_num_heads // 2 + self.value_head_dim = 2 * self.head_dim + else: + # Double projection mode + self.head_dim = config.hidden_size // config.num_attention_heads + self.heads_per_component = self.base_num_heads + self.value_head_dim = self.head_dim + + def _init_projections(self): + """Initialize Q, K, V projections.""" + if self.split_heads: + # Split heads mode - single projections + q_out_dim = self.hidden_size + k_out_dim = self.hidden_size // self.base_num_heads * self.base_num_kv_heads + else: + # Double projection mode + q_out_dim = self.hidden_size * 2 + k_out_dim = ( + self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2 + ) + + self.q_proj = nn.Linear(self.hidden_size, q_out_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, k_out_dim, bias=False) + self.v_proj = nn.Linear( + self.hidden_size, + self.hidden_size // self.base_num_heads * self.base_num_kv_heads, + bias=False, + ) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + def _init_differential_params(self): + """Initialize differential attention parameters.""" + self.lambda_init = nn.Parameter( + torch.full((), lambda_init_fn(self.layer_idx)), + requires_grad=False, + ) + self.lambda_q1 = nn.Parameter( + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) + ) + self.lambda_k1 = nn.Parameter( + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) + ) + self.lambda_q2 = nn.Parameter( + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) + ) + self.lambda_k2 = nn.Parameter( + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) + ) + self.rotary_emb = LlamaRotaryEmbedding( + self.max_position_embeddings, self.head_dim, self.rope_theta + ) + + def _init_normalization(self, config): + """Initialize normalization layers.""" + sublayer_norm = getattr(config, "sublayer_norm", True) + self.subln = ( + LlamaRMSNorm(self.value_head_dim, eps=1e-5) + if sublayer_norm + else nn.Identity() + ) + + def _prepare_attention_inputs(self, hidden_states: torch.Tensor): + """Prepare inputs for attention computation.""" + bsz, q_len, _ = hidden_states.size() + + # Project and split + qp = self.q_proj(hidden_states) + kp = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + q1, q2 = qp.chunk(2, dim=-1) + k1, k2 = kp.chunk(2, dim=-1) + + # Reshape + q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + v = v.view(bsz, q_len, -1, self.value_head_dim).transpose(1, 2) + + return q1, q2, k1, k2, v + + def _apply_rotary_embeddings( + self, q1, q2, k1, k2, position_ids, position_embeddings + ): + """Apply rotary embeddings to queries and keys.""" + if position_embeddings is None: + if position_ids is None: + position_ids = torch.arange(q1.size(-2), device=q1.device) + cos, sin = self.rotary_emb(q1, position_ids) + else: + cos, sin = position_embeddings + + if self.split_heads: + cos, _ = cos.chunk(2, dim=2) + sin, _ = sin.chunk(2, dim=2) + + q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) + q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) + + return q1, q2, k1, k2, cos, sin + + def _handle_cache(self, k1, k2, v, past_key_value, cache_kwargs): + """Handle caching for autoregressive generation.""" + if past_key_value is not None: + k = torch.stack([k1, k2], dim=1) + k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) + k1, k2 = k.unbind(dim=1) + + # Repeat KV 
heads + k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads) + k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads) + v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads) + + return k1, k2, v + + def _compute_lambda(self, q1): + """Compute lambda values for differential attention.""" + lambda_1 = torch.exp( + torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float() + ).type_as(q1) + lambda_2 = torch.exp( + torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() + ).type_as(q1) + return lambda_1 - lambda_2 + self.lambda_init + + def _process_attention_output(self, attn, bsz, q_len): + """Process and project attention output.""" + attn = self.subln(attn) + attn = attn * (1 - self.lambda_init) + attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) + return self.o_proj(attn) + + +class LlamaDifferentialAttention(DifferentialAttentionBase): + """Standard implementation of differential attention.""" + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, # pylint: disable=unused-argument + ): + bsz, q_len, _ = hidden_states.size() + q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states) + q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings( + q1, q2, k1, k2, position_ids, position_embeddings + ) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs) + + # Standard attention computation + attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim) + attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : k1.shape[-2]] + attn1 = attn1 + causal_mask + attn2 = attn2 + causal_mask + + attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1) + attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2) + + dropout_p = self.attention_dropout if self.training else 0.0 + attn1 = F.dropout(attn1, p=dropout_p, training=self.training) + attn2 = F.dropout(attn2, p=dropout_p, training=self.training) + + lambda_full = self._compute_lambda(q1) + attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v) + + attn = self._process_attention_output(attn, bsz, q_len) + + if output_attentions: + return attn, attn1 - lambda_full * attn2, past_key_value + return attn, None, past_key_value + + +class LlamaDifferentialSdpaAttention(DifferentialAttentionBase): + """SDPA-based implementation of differential attention.""" + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, # pylint: disable=unused-argument + ): + if output_attentions: + return LlamaDifferentialAttention.forward( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + 
position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states) + q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings( + q1, q2, k1, k2, position_ids, position_embeddings + ) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs) + + # SDPA-specific attention computation + causal_mask = ( + None if attention_mask is None else attention_mask[:, :, :, : k1.shape[-2]] + ) + is_causal = attention_mask is None and q_len > 1 + dropout_p = self.attention_dropout if self.training else 0.0 + + if q1.device.type == "cuda" and causal_mask is not None: + q1, q2 = q1.contiguous(), q2.contiguous() + k1, k2 = k1.contiguous(), k2.contiguous() + v = v.contiguous() + + attn1 = F.scaled_dot_product_attention( + q1, k1, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal + ) + attn2 = F.scaled_dot_product_attention( + q2, k2, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal + ) + + lambda_full = self._compute_lambda(q1) + attn = attn1 - lambda_full * attn2 + + attn = self._process_attention_output(attn, bsz, q_len) + return attn, None, past_key_value + + +class LlamaDifferentialFlashAttention2(DifferentialAttentionBase): + """Flash Attention 2-based implementation of differential attention.""" + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, # pylint: disable=unused-argument + ): + if output_attentions: + return LlamaDifferentialAttention.forward( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states) + q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings( + q1, q2, k1, k2, position_ids, position_embeddings + ) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs) + + # Flash Attention specific processing + q1, q2 = q1.transpose(1, 2), q2.transpose(1, 2) + k1, k2 = k1.transpose(1, 2), k2.transpose(1, 2) + v = v.transpose(1, 2) + + dropout_p = self.attention_dropout if self.training else 0.0 + + if self.split_heads: + v1, v2 = v.chunk(2, dim=-1) + attn11 = flash_attn_func(q1, k1, v1, dropout_p=dropout_p, causal=True) + attn12 = flash_attn_func(q1, k1, v2, dropout_p=dropout_p, causal=True) + attn1 = torch.cat([attn11, attn12], dim=-1) + + attn21 = flash_attn_func(q2, k2, v1, dropout_p=dropout_p, causal=True) + attn22 = flash_attn_func(q2, k2, v2, dropout_p=dropout_p, causal=True) + attn2 = torch.cat([attn21, attn22], dim=-1) + else: + attn1 = flash_attn_func(q1, k1, v, dropout_p=dropout_p, causal=True) + attn2 = flash_attn_func(q2, k2, v, dropout_p=dropout_p, causal=True) + + attn1, attn2 = attn1.transpose(1, 2), attn2.transpose(1, 2) + + lambda_full = self._compute_lambda(q1) + attn = attn1 - lambda_full * attn2 + + attn = self._process_attention_output(attn, bsz, q_len) + return attn, None, past_key_value diff --git 
a/src/axolotl/integrations/differential_transformer/README.md b/src/axolotl/integrations/differential_transformer/README.md deleted file mode 100644 index f7bd74cbdb..0000000000 --- a/src/axolotl/integrations/differential_transformer/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Differential Transformer - -### Usage - -```yaml -plugins: - - axolotl.integrations.differential_transformer.DifferentialTransformerPlugin - -differential_attention: true -``` diff --git a/src/axolotl/integrations/differential_transformer/differential_attention.py b/src/axolotl/integrations/differential_transformer/differential_attention.py deleted file mode 100644 index af7473436c..0000000000 --- a/src/axolotl/integrations/differential_transformer/differential_attention.py +++ /dev/null @@ -1,641 +0,0 @@ -"""Re-implemention of differential attention.""" -# pylint: disable=invalid-name -import logging -import math -from typing import Any, Optional, Tuple - -import torch -import torch.nn.functional as F -import transformers -from flash_attn.flash_attn_interface import flash_attn_func -from torch import nn -from transformers.cache_utils import Cache -from transformers.models.llama.modeling_llama import ( - LlamaRMSNorm, - LlamaRotaryEmbedding, - apply_rotary_pos_emb, -) - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" - batch_size, n_kv_heads, slen, head_dim = x.shape - if n_rep == 1: - return x - return ( - x[:, :, None, :, :] - .expand(batch_size, n_kv_heads, n_rep, slen, head_dim) - .reshape(batch_size, n_kv_heads * n_rep, slen, head_dim) - ) - - -def lambda_init_fn(depth): - return 0.8 - 0.6 * math.exp(-0.3 * depth) - - -class LlamaDifferentialAttention(nn.Module): - """Differential Attention implementation as described in the Diff Transformer paper. - - This implements a modified attention mechanism that computes the difference between - two attention patterns, scaled by learned lambda parameters. The mechanism helps - reduce noise in the attention weights for irrelevant / less relevant tokens. - - Key components: - - Split head dimension for differential computation - - Learned lambda parameters that control attention scaling - - Sublayer normalization on the attention output - - See: - - https://arxiv.org/abs/2410.05258 - - https://github.com/microsoft/unilm/tree/master/Diff-Transformer - - Args: - config: Model configuration object containing hidden size, number of heads etc. 
- layer_idx: Index of this layer in the transformer stack - dtype: Data type for the layer parameters - """ - - def __init__( - self, - config: Any, - layer_idx: int, - ): - super().__init__() - - # Base model config - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.base_num_heads = config.num_attention_heads - self.base_num_kv_heads = config.num_key_value_heads - - if config.split_heads: - self.head_dim = config.hidden_size // config.num_attention_heads // 2 - else: - self.head_dim = config.hidden_size // config.num_attention_heads - - self.layer_idx = layer_idx - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.split_heads = config.split_heads - - if config.split_heads: - # Split heads mode - # assert ( - # self.base_num_heads % 2 == 0 - # ), "Number of heads must be even for splitting" - self.heads_per_component = self.base_num_heads // 2 - - # Single projections - self.q_proj = nn.Linear( - self.hidden_size, - self.hidden_size, - bias=False, - ) - self.k_proj = nn.Linear( - self.hidden_size, - self.hidden_size // self.base_num_heads * self.base_num_kv_heads, - bias=False, - ) - else: - # Double projection mode - self.heads_per_component = self.base_num_heads - - # Double-sized projections - self.q_proj = nn.Linear( - self.hidden_size, - self.hidden_size * 2, - bias=False, - ) - self.k_proj = nn.Linear( - self.hidden_size, - self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2, - bias=False, - ) - - # Single V projection - self.v_proj = nn.Linear( - self.hidden_size, - self.hidden_size // self.base_num_heads * self.base_num_kv_heads, - bias=False, - ) - - # Output projection - self.o_proj = nn.Linear( - self.hidden_size, - self.hidden_size, - bias=False, - ) - - # Initialize differential attention parameters - self.lambda_init = nn.Parameter( - torch.full((), lambda_init_fn(self.layer_idx)), - requires_grad=False, - ) - self.lambda_q1 = nn.Parameter( - torch.zeros(self.head_dim).normal_(mean=0, std=0.1) - ) - self.lambda_k1 = nn.Parameter( - torch.zeros(self.head_dim).normal_(mean=0, std=0.1) - ) - self.lambda_q2 = nn.Parameter( - torch.zeros(self.head_dim).normal_(mean=0, std=0.1) - ) - self.lambda_k2 = nn.Parameter( - torch.zeros(self.head_dim).normal_(mean=0, std=0.1) - ) - - self.rotary_emb = LlamaRotaryEmbedding(config=config) - sublayer_norm = getattr(config, "sublayer_norm", True) - - if self.split_heads: - subln_dim = 2 * self.head_dim - else: - subln_dim = self.head_dim - - self.subln = ( - LlamaRMSNorm(hidden_size=subln_dim, eps=1e-5) - if sublayer_norm - else nn.Identity() - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, # pylint: disable=unused-argument - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, # pylint: disable=unused-argument - ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[tuple[torch.Tensor, torch.Tensor]], - ]: - bsz, q_len, _ = hidden_states.size() - - # Project to Q1,Q2 and K1,K2 - qp = self.q_proj(hidden_states) - kp = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - - # Split into Q1,Q2 and K1,K2 - q1, q2 = qp.chunk(2, dim=-1) - k1, k2 = kp.chunk(2, dim=-1) - - # Reshape Q1,Q2 for 
attention - q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Reshape K1,K2 for attention - k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Reshape V - if self.split_heads: - v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2) - else: - v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Apply rotary embeddings - if position_embeddings is None: - if position_ids is None: - position_ids = torch.arange(q_len, device=q1.device) - cos, sin = self.rotary_emb(q1, position_ids) - else: - cos, sin = position_embeddings - - if self.split_heads: - cos, _ = cos.chunk(2, dim=2) - sin, _ = sin.chunk(2, dim=2) - - q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) - q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - k = torch.stack([k1, k2], dim=1) - k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) - k1, k2 = k.unbind(dim=1) - - # Repeat KV heads to match Q heads - k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads) - k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads) - v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads) - - # Calculate attention scores for both parts - attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim) - attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : k1.shape[-2]] - attn1 = attn1 + causal_mask - attn2 = attn2 + causal_mask - - # Apply softmax - attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1) - attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2) - - # Apply dropout - attn1 = F.dropout(attn1, p=self.attention_dropout, training=self.training) - attn2 = F.dropout(attn2, p=self.attention_dropout, training=self.training) - - # Calculate lambda - lambda_1 = torch.exp( - torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float() - ).type_as(q1) - lambda_2 = torch.exp( - torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() - ).type_as(q1) - lambda_full = lambda_1 - lambda_2 + self.lambda_init - - # Compute differential attention (following paper's formula) - attn_weights = attn1 - lambda_full * attn2 - - # Apply attention weights to values - attn = torch.matmul(attn_weights, v) - - # Apply sublayer norm and scaling - attn = self.subln(attn) - attn = attn * (1 - self.lambda_init) - - # Reshape to output - attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) - attn = self.o_proj(attn) - - if output_attentions: - return attn, attn_weights, past_key_value - return attn, None, past_key_value - - -class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention): - """Differential Attention implementation as described in the Diff Transformer paper. - This implements the same logic as `LlamaDifferentialAttention`, but uses - `scaled_dot_product_attention` instead of "manually" computing it under the hood. - - This implements a modified attention mechanism that computes the difference between - two attention patterns, scaled by learned lambda parameters. The mechanism helps - reduce noise in the attention weights for irrelevant / less relevant tokens. 
- - Key components: - - Split head dimension for differential computation - - Learned lambda parameters that control attention scaling - - Sublayer normalization on the attention output - - See: - - https://arxiv.org/abs/2410.05258 - - https://github.com/microsoft/unilm/tree/master/Diff-Transformer - - Args: - config: Model configuration object containing hidden size, number of heads etc. - layer_idx: Index of this layer in the transformer stack - dtype: Data type for the layer parameters - """ - - # pylint: disable=duplicate-code - def forward( - self, - hidden_states: torch.Tensor, # [bsz, seq_len, hidden_size] - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, # pylint: disable=unused-argument - ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[tuple[torch.Tensor, torch.Tensor]], - ]: - if output_attentions: - transformers.logger.warning_once( - "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - # Project to Q1,Q2 and K1,K2 - qp = self.q_proj(hidden_states) - kp = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - - # Split into Q1,Q2 and K1,K2 - q1, q2 = qp.chunk(2, dim=-1) - k1, k2 = kp.chunk(2, dim=-1) - - # Reshape Q1,Q2 for attention - q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Reshape K1,K2 for attention - k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Reshape V - if self.split_heads: - v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2) - else: - v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Apply rotary embeddings - if position_embeddings is None: - if position_ids is None: - position_ids = torch.arange(q_len, device=q1.device) - cos, sin = self.rotary_emb(q1, position_ids) - else: - cos, sin = position_embeddings - - if self.split_heads: - cos, _ = cos.chunk(2, dim=2) - sin, _ = sin.chunk(2, dim=2) - - q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) - q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - k = torch.stack([k1, k2], dim=1) - k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) - k1, k2 = k.unbind(dim=1) - - # Repeat KV heads to match Q heads - k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads) - k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads) - v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads) - - causal_mask = None - if 
attention_mask is not None: - causal_mask = attention_mask - causal_mask = causal_mask[:, :, :, : k1.shape[-2]] - - # SDPA with memory-efficient backend requires contiguous inputs on CUDA - if q1.device.type == "cuda" and causal_mask is not None: - q1, q2 = q1.contiguous(), q2.contiguous() - k1, k2 = k1.contiguous(), k2.contiguous() - v = v.contiguous() - - # Calculate attention using SDPA - is_causal = attention_mask is None and q_len > 1 - - dropout_p = self.attention_dropout if self.training else 0.0 - attn1 = F.scaled_dot_product_attention( - q1, - k1, - v, - attn_mask=causal_mask, - dropout_p=dropout_p, - is_causal=is_causal, - ) - attn2 = F.scaled_dot_product_attention( - q2, - k2, - v, - attn_mask=causal_mask, - dropout_p=dropout_p, - is_causal=is_causal, - ) - - # Calculate lambda - lambda_1 = torch.exp( - torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float() - ).type_as(q1) - lambda_2 = torch.exp( - torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() - ).type_as(q1) - lambda_full = lambda_1 - lambda_2 + self.lambda_init - - # Combine the attention outputs - attn = attn1 - lambda_full * attn2 - - # Apply sublayer norm and scaling - attn = self.subln(attn) - attn = attn * (1 - self.lambda_init) - - # Reshape to output - attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) - attn = self.o_proj(attn) - - if output_attentions: - return ( - attn, - None, - past_key_value, - ) # Note: can't return attn_weights with SDPA - return attn, None, past_key_value - - -class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention): - """Differential Attention implementation using Flash Attention 2. - This implements the same logic as `LlamaDifferentialAttention`, but uses - Flash Attention 2 for more efficient computation. - - This implements a modified attention mechanism that computes the difference between - two attention patterns, scaled by learned lambda parameters. The mechanism helps - reduce noise in the attention weights for irrelevant / less relevant tokens. - - Key components: - - Split head dimension for differential computation - - Learned lambda parameters that control attention scaling - - Sublayer normalization on the attention output - - Flash Attention 2 for efficient attention computation - - See: - - https://arxiv.org/abs/2410.05258 - - https://github.com/microsoft/unilm/tree/master/Diff-Transformer - - Args: - config: Model configuration object containing hidden size, number of heads etc. - layer_idx: Index of this layer in the transformer stack - dtype: Data type for the layer parameters - """ - - # pylint: disable=duplicate-code - def forward( - self, - hidden_states: torch.Tensor, # [bsz, seq_len, hidden_size] - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, - ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[tuple[torch.Tensor, torch.Tensor]], - ]: - if output_attentions: - transformers.logger.warning_once( - "LlamaModel is using LlamaFlashAttention, but Flash Attention does not support `output_attentions=True`. " - "Falling back to the manual attention implementation." 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - # Project to Q1,Q2 and K1,K2 - qp = self.q_proj(hidden_states) - kp = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - - # Split into Q1,Q2 and K1,K2 - q1, q2 = qp.chunk(2, dim=-1) - k1, k2 = kp.chunk(2, dim=-1) - - # Reshape Q1,Q2 for attention - q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Reshape K1,K2 for attention - k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Reshape V - if self.split_heads: - v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2) - else: - v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Apply rotary embeddings - if position_embeddings is None: - if position_ids is None: - position_ids = torch.arange(q_len, device=q1.device) - cos, sin = self.rotary_emb(q1, position_ids) - else: - cos, sin = position_embeddings - - if self.split_heads: - cos, _ = cos.chunk(2, dim=2) - sin, _ = sin.chunk(2, dim=2) - - q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) - q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - k = torch.stack([k1, k2], dim=1) - k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) - k1, k2 = k.unbind(dim=1) - - # Repeat KV heads to match Q heads - k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads) - k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads) - v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads) - - q1 = q1.transpose(1, 2) - q2 = q2.transpose(1, 2) - k1 = k1.transpose(1, 2) - k2 = k2.transpose(1, 2) - v = v.transpose(1, 2) - - # Calculate attention using Flash Attention - dropout_p = self.attention_dropout if self.training else 0.0 - if self.split_heads: - v1, v2 = v.chunk(2, dim=-1) - attn11 = flash_attn_func( - q1, - k1, - v1, - dropout_p=dropout_p, - causal=True, - ) - attn12 = flash_attn_func( - q1, - k1, - v2, - dropout_p=dropout_p, - causal=True, - ) - attn1 = torch.cat([attn11, attn12], dim=-1) - - attn21 = flash_attn_func( - q2, - k2, - v1, - dropout_p=dropout_p, - causal=True, - ) - attn22 = flash_attn_func( - q2, - k2, - v2, - dropout_p=dropout_p, - causal=True, - ) - attn2 = torch.cat([attn21, attn22], dim=-1) - else: - attn1 = flash_attn_func( - q1, - k1, - v, - dropout_p=dropout_p, - causal=True, - ) - attn2 = flash_attn_func( - q2, - k2, - v, - dropout_p=dropout_p, - causal=True, - ) - - attn1 = attn1.transpose(1, 2) - attn2 = attn2.transpose(1, 2) - - # Calculate lambda - lambda_1 = torch.exp( - torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float() - ).type_as(q1) - lambda_2 = torch.exp( - torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() - ).type_as(q1) - lambda_full = lambda_1 - lambda_2 + self.lambda_init - - # Combine the attention outputs - attn = attn1 - lambda_full * attn2 - - # Apply sublayer norm and scaling - attn = self.subln(attn) - attn = attn * (1 - self.lambda_init) - - # Reshape to output - attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) - attn = self.o_proj(attn) - - if output_attentions: - return 
( - attn, - None, - past_key_value, - ) # Note: can't return attn_weights with Flash Attention - return attn, None, past_key_value diff --git a/src/axolotl/monkeypatch/attention/differential.py b/src/axolotl/monkeypatch/attention/differential.py index a07b629b6b..635573a4be 100644 --- a/src/axolotl/monkeypatch/attention/differential.py +++ b/src/axolotl/monkeypatch/attention/differential.py @@ -3,7 +3,7 @@ from transformers import PreTrainedModel from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES -from axolotl.integrations.differential_transformer.differential_attention import ( +from axolotl.integrations.diff_transformer.diff_attn import ( LlamaDifferentialAttention, LlamaDifferentialFlashAttention2, LlamaDifferentialSdpaAttention, diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 6eaa020da0..37cbc08713 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -714,7 +714,7 @@ def set_attention_config(self) -> None: if not self.cfg.sample_packing and self.cfg.s2_attention: pass - if self.cfg.differential_attention: + if self.cfg.diff_attention: self.model_kwargs[ "attn_implementation" ] = "differential_flash_attention_2" @@ -727,7 +727,7 @@ def set_attention_config(self) -> None: "flash_attention_2" ) elif self.cfg.sdp_attention: - if self.cfg.differential_attention: + if self.cfg.diff_attention: self.model_kwargs["attn_implementation"] = "differential_sdpa" self.model_config._attn_implementation = ( # pylint: disable=protected-access "differential_sdpa" @@ -738,7 +738,7 @@ def set_attention_config(self) -> None: "sdpa" ) elif self.cfg.eager_attention: - if self.cfg.differential_attention: + if self.cfg.diff_attention: self.model_kwargs["attn_implementation"] = "differential_eager" self.model_config._attn_implementation = ( # pylint: disable=protected-access "differential_eager" @@ -748,7 +748,7 @@ def set_attention_config(self) -> None: self.model_config._attn_implementation = ( # pylint: disable=protected-access "eager" ) - elif self.cfg.differential_attention: + elif self.cfg.diff_attention: self.model_kwargs["attn_implementation"] = "differential_eager" self.model_config._attn_implementation = ( # pylint: disable=protected-access "differential_eager" diff --git a/src/axolotl/utils/yaml.py b/src/axolotl/utils/yaml.py new file mode 100644 index 0000000000..107afafcf4 --- /dev/null +++ b/src/axolotl/utils/yaml.py @@ -0,0 +1,151 @@ +"""Utilities for YAML files.""" + +from collections import OrderedDict +from typing import Any, Dict, List, Set, Tuple, Union + +import yaml + + +class YAMLOrderTracker: + """Tracks the order of keys and section breaks in YAML files.""" + + def __init__(self, yaml_path: str): + self.yaml_path = yaml_path + self.structure, self.needs_break = self._parse_yaml_structure() + + def _get_indentation_level(self, line: str) -> int: + """Get the indentation level of a line.""" + return len(line) - len(line.lstrip()) + + def _parse_yaml_structure( + self, + ) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]: + """Parse the YAML file to extract structure and identify section breaks.""" + with open(self.yaml_path, "r", encoding="utf-8") as file: + contents = file.readlines() + + structure: OrderedDict = OrderedDict() + needs_break = set() # Track which keys should have a break before them + current_path = [] + last_indentation = -1 + had_empty_line = False + + for line in contents: + # Track empty lines and comments + if not line.strip() or line.strip().startswith("#"): + had_empty_line = True +
continue + + # Get indentation level and content + indentation = self._get_indentation_level(line) + content = line.strip() + + # Skip lines that don't define keys + if ":" not in content: + continue + + # Extract key + key = content.split(":")[0].strip() + + # If this is a top-level key and we had an empty line, mark it + if indentation == 0: + if had_empty_line: + needs_break.add(key) + had_empty_line = False + + # Handle indentation changes + if indentation > last_indentation: + current_path.append(key) + elif indentation < last_indentation: + levels_up = (last_indentation - indentation) // 2 + current_path = current_path[:-levels_up] + current_path[-1] = key + else: + if current_path: + current_path[-1] = key + + # Update structure + current_dict = structure + for path_key in current_path[:-1]: + if path_key not in current_dict: + current_dict[path_key] = OrderedDict() + current_dict = current_dict[path_key] + + if current_path: + if current_path[-1] not in current_dict: + current_dict[current_path[-1]] = OrderedDict() + + last_indentation = indentation + + return structure, needs_break + + +class OrderedDumper(yaml.SafeDumper): + """Custom YAML dumper that maintains dictionary order.""" + + +def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any: + """Custom representer for dictionaries that maintains order.""" + return dumper.represent_mapping("tag:yaml.org,2002:map", data.items()) + + +def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict: + """Reorder a dictionary based on a reference structure.""" + ordered = OrderedDict() + + # First add keys that are in the reference order + for key in reference_structure: + if key in data: + if isinstance(reference_structure[key], dict) and isinstance( + data[key], dict + ): + ordered[key] = reorder_dict(data[key], reference_structure[key]) + else: + ordered[key] = data[key] + + # Then add any remaining keys that weren't in the reference + for key in data: + if key not in ordered: + ordered[key] = data[key] + + return ordered + + +def dump_yaml_preserved_order( + data: Dict, reference_yaml_path: str, output_path: str +) -> None: + """Dump YAML file while preserving nested order and normalized spacing.""" + # Get reference structure and spacing + tracker = YAMLOrderTracker(reference_yaml_path) + + # Reorder the data + ordered_data = reorder_dict(data, tracker.structure) + + # Register the custom representer + OrderedDumper.add_representer(dict, ordered_dict_representer) + OrderedDumper.add_representer(OrderedDict, ordered_dict_representer) + + # First dump to string + yaml_str = yaml.dump( + ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False + ) + + # Add spacing according to reference + lines = yaml_str.split("\n") + result_lines: List[str] = [] + current_line = 0 + + while current_line < len(lines): + line = lines[current_line] + if line.strip() and ":" in line and not line.startswith(" "): # Top-level key + key = line.split(":")[0].strip() + if key in tracker.needs_break: + # Add single empty line before this key + if result_lines and result_lines[-1] != "": + result_lines.append("") + result_lines.append(line) + current_line += 1 + + # Write the final result + with open(output_path, "w", encoding="utf-8") as file: + file.write("\n".join(result_lines)) diff --git a/tests/e2e/integrations/convert_differential_transformer/__init__.py b/tests/e2e/integrations/convert_diff_transformer/__init__.py similarity index 100% rename from 
tests/e2e/integrations/convert_differential_transformer/__init__.py rename to tests/e2e/integrations/convert_diff_transformer/__init__.py diff --git a/tests/e2e/integrations/convert_differential_transformer/conftest.py b/tests/e2e/integrations/convert_diff_transformer/conftest.py similarity index 85% rename from tests/e2e/integrations/convert_differential_transformer/conftest.py rename to tests/e2e/integrations/convert_diff_transformer/conftest.py index ed1eb3f363..d4ffeb7597 100644 --- a/tests/e2e/integrations/convert_differential_transformer/conftest.py +++ b/tests/e2e/integrations/convert_diff_transformer/conftest.py @@ -9,9 +9,6 @@ def base_config(): """Basic config for testing.""" return { "base_model": "HuggingFaceTB/SmolLM2-135M", - "plugins": [ - "axolotl.integrations.differential_transformer.DifferentialTransformerPlugin", - ], "datasets": [ { "path": "axolotl-ai-co/alpaca_100_test", diff --git a/tests/e2e/integrations/convert_differential_transformer/test_convert_and_evaluate.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_and_evaluate.py similarity index 89% rename from tests/e2e/integrations/convert_differential_transformer/test_convert_and_evaluate.py rename to tests/e2e/integrations/convert_diff_transformer/test_convert_and_evaluate.py index 1cf569693c..d5915f8a55 100644 --- a/tests/e2e/integrations/convert_differential_transformer/test_convert_and_evaluate.py +++ b/tests/e2e/integrations/convert_diff_transformer/test_convert_and_evaluate.py @@ -8,9 +8,7 @@ from axolotl.cli import load_cfg from axolotl.cli.evaluate import do_evaluate -from axolotl.cli.integrations.convert_differential_transformer import ( - convert_differential_transformer, -) +from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs @@ -26,7 +24,7 @@ def test_conversion_and_eval_cli(tmp_path: Path, base_config): cli_args = ConvertDiffTransformerCliArgs( debug=True, zero_init=True, sublayer_norm=False ) - _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) assert debug_info["generations_match"] is True assert (output_dir / "model.safetensors").exists() diff --git a/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py similarity index 79% rename from tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py rename to tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py index 42ce3e6127..e616a8ef12 100644 --- a/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py +++ b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py @@ -10,23 +10,19 @@ import yaml from axolotl.cli import load_cfg -from axolotl.cli.integrations.convert_differential_transformer import ( - convert_differential_transformer, -) +from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer from axolotl.cli.main import cli from axolotl.common.cli import ConvertDiffTransformerCliArgs def test_cli_validation(cli_runner): # Test missing config file - result = cli_runner.invoke(cli, ["convert-differential-transformer"]) + result = cli_runner.invoke(cli, ["convert-diff-transformer"]) assert result.exit_code != 0 assert "Error: Missing 
argument 'CONFIG'." in result.output # Test non-existent config file - result = cli_runner.invoke( - cli, ["convert-differential-transformer", "nonexistent.yml"] - ) + result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"]) assert result.exit_code != 0 assert "Error: Invalid value for 'CONFIG'" in result.output @@ -37,11 +33,9 @@ def test_basic_execution(cli_runner, tmp_path: Path, base_config): yaml.dump(base_config, file) with patch( - "axolotl.cli.integrations.convert_differential_transformer.do_cli" + "axolotl.cli.integrations.convert_diff_transformer.do_cli" ) as mock_do_cli: - result = cli_runner.invoke( - cli, ["convert-differential-transformer", str(config_path)] - ) + result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)]) assert result.exit_code == 0 mock_do_cli.assert_called_once() @@ -56,14 +50,9 @@ def test_conversion_cli_basic(tmp_path: Path, base_config): with open(config_path, "w", encoding="utf-8") as file: yaml.dump(base_config, file) - # Load config the same way do_cli does cfg = load_cfg(str(config_path)) - - # Create CLI args cli_args = ConvertDiffTransformerCliArgs() - - # Call convert_differential_transformer directly - _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) assert not debug_info assert (output_dir / "model.safetensors").exists() @@ -79,14 +68,9 @@ def test_conversion_cli_debug(tmp_path: Path, base_config): with open(config_path, "w", encoding="utf-8") as file: yaml.dump(base_config, file) - # Load config the same way do_cli does cfg = load_cfg(str(config_path)) - - # Create CLI args cli_args = ConvertDiffTransformerCliArgs(debug=True) - - # Call convert_differential_transformer directly - _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) assert not debug_info["generations_match"] assert not debug_info["match_expected"] @@ -107,7 +91,7 @@ def test_conversion_cli_reproduce(tmp_path: Path, base_config): cli_args = ConvertDiffTransformerCliArgs( debug=True, zero_init=True, sublayer_norm=False ) - _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) assert debug_info["generations_match"] is True assert (output_dir / "model.safetensors").exists() @@ -133,7 +117,7 @@ def test_conversion_cli_repoduce_attentions( cli_args = ConvertDiffTransformerCliArgs( debug=True, zero_init=True, sublayer_norm=False ) - _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) assert debug_info["generations_match"] is True assert (output_dir / "model.safetensors").exists() @@ -155,7 +139,7 @@ def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str) cfg = load_cfg(str(config_path)) cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True) - _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) assert debug_info["generations_match"] is False assert (output_dir / "model.safetensors").exists() From a1a3f1d4d3add061a0f1b7a30980e4c102f01b93 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Sat, 21 Dec 2024 16:56:57 +0000 Subject: [PATCH 21/25] refactor and fixing test isolation issues 
--- .../integrations/convert_diff_transformer.py | 4 +- src/axolotl/common/cli.py | 18 +- .../integrations/diff_transformer/convert.py | 10 +- .../diff_transformer/diff_attn.py | 3 +- src/axolotl/utils/yaml.py | 8 +- .../convert_diff_transformer/conftest.py | 4 +- .../test_convert_diff_transformer.py | 258 +++++++++--------- 7 files changed, 156 insertions(+), 149 deletions(-) diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py index d91278fed0..360832dbb8 100644 --- a/src/axolotl/cli/integrations/convert_diff_transformer.py +++ b/src/axolotl/cli/integrations/convert_diff_transformer.py @@ -1,4 +1,5 @@ -"""CLI to convert a transformers model's attns to diff attns.""" +"""CLI to convert a transformers model's attention layers to differential attention layers.""" + import logging import warnings from pathlib import Path @@ -127,6 +128,7 @@ def convert_diff_transformer(cfg, cli_args, config_path): else: modified_cfg["plugins"] = [plugin_class] + # Write out the updated axolotl config while preserving original ordering / formatting dump_yaml_preserved_order( data=modified_cfg, reference_yaml_path=config_path, diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index ea3b91c0c2..ebe098ca6b 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -12,14 +12,12 @@ from axolotl.utils.models import load_model, load_tokenizer configure_logging() -LOG = logging.getLogger("axolotl.common.cli") +LOG = logging.getLogger(__name__) @dataclass class PreprocessCliArgs: - """ - dataclass with arguments for preprocessing only - """ + """dataclass with arguments for preprocessing only""" debug: bool = field(default=False) debug_text_only: bool = field(default=False) @@ -30,9 +28,7 @@ class PreprocessCliArgs: @dataclass class TrainerCliArgs: - """ - dataclass with various non-training arguments - """ + """dataclass with various non-training arguments""" debug: bool = field(default=False) debug_text_only: bool = field(default=False) @@ -45,9 +41,7 @@ class TrainerCliArgs: @dataclass class EvaluateCliArgs: - """ - dataclass with various evaluation arguments - """ + """dataclass with various evaluation arguments""" debug: bool = field(default=False) debug_text_only: bool = field(default=False) @@ -56,9 +50,7 @@ class EvaluateCliArgs: @dataclass class ConvertDiffTransformerCliArgs: - """ - dataclass with arguments for convert-diff-transformer CLI - """ + """dataclass with arguments for convert-diff-transformer CLI""" debug: bool = field(default=False) zero_init: bool = field(default=False) diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py index 5c10f2137a..d942567d57 100644 --- a/src/axolotl/integrations/diff_transformer/convert.py +++ b/src/axolotl/integrations/diff_transformer/convert.py @@ -98,9 +98,13 @@ def convert_module(module): # Iterate through module children, convert any attn layers to diff attn for name, child in module.named_children(): - if isinstance(child, tuple(ATTENTION_MAPPING.keys())): - # Choose appropriate differential attention class - attention_class = ATTENTION_MAPPING[type(child)] + child_class_name = type(child).__name__ + if child_class_name in [k.__name__ for k in ATTENTION_MAPPING]: + # Find matching attention class by name + for orig_class, diff_class in ATTENTION_MAPPING.items(): + if orig_class.__name__ == child_class_name: + attention_class = diff_class + break layer_type = type(child).__name__ 
logger.info( diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py index edf532c418..a8d7536dd6 100644 --- a/src/axolotl/integrations/diff_transformer/diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/diff_attn.py @@ -21,7 +21,6 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" batch_size, n_kv_heads, slen, head_dim = x.shape if n_rep == 1: return x @@ -249,6 +248,7 @@ def forward( class LlamaDifferentialSdpaAttention(DifferentialAttentionBase): """SDPA-based implementation of differential attention.""" + # pylint: disable=duplicate-code def forward( self, hidden_states: torch.Tensor, @@ -312,6 +312,7 @@ def forward( class LlamaDifferentialFlashAttention2(DifferentialAttentionBase): """Flash Attention 2-based implementation of differential attention.""" + # pylint: disable=duplicate-code def forward( self, hidden_states: torch.Tensor, diff --git a/src/axolotl/utils/yaml.py b/src/axolotl/utils/yaml.py index 107afafcf4..c5c9e74ae4 100644 --- a/src/axolotl/utils/yaml.py +++ b/src/axolotl/utils/yaml.py @@ -84,6 +84,11 @@ class OrderedDumper(yaml.SafeDumper): """Custom YAML dumper that maintains dictionary order.""" +def represent_none(self, _): + """Represent None values as empty fields.""" + return self.represent_scalar("tag:yaml.org,2002:null", "") + + def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any: """Custom representer for dictionaries that maintains order.""" return dumper.represent_mapping("tag:yaml.org,2002:map", data.items()) @@ -121,7 +126,8 @@ def dump_yaml_preserved_order( # Reorder the data ordered_data = reorder_dict(data, tracker.structure) - # Register the custom representer + # Register the custom representers + OrderedDumper.add_representer(type(None), represent_none) OrderedDumper.add_representer(dict, ordered_dict_representer) OrderedDumper.add_representer(OrderedDict, ordered_dict_representer) diff --git a/tests/e2e/integrations/convert_diff_transformer/conftest.py b/tests/e2e/integrations/convert_diff_transformer/conftest.py index d4ffeb7597..3964df0527 100644 --- a/tests/e2e/integrations/convert_diff_transformer/conftest.py +++ b/tests/e2e/integrations/convert_diff_transformer/conftest.py @@ -4,7 +4,7 @@ from click.testing import CliRunner -@pytest.fixture() +@pytest.fixture(scope="class") def base_config(): """Basic config for testing.""" return { @@ -26,6 +26,6 @@ def base_config(): } -@pytest.fixture +@pytest.fixture(scope="class") def cli_runner(): return CliRunner() diff --git a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py index e616a8ef12..02939ee1ca 100644 --- a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py +++ b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py @@ -15,133 +15,135 @@ from axolotl.common.cli import ConvertDiffTransformerCliArgs -def test_cli_validation(cli_runner): - # Test missing config file - result = cli_runner.invoke(cli, ["convert-diff-transformer"]) - assert result.exit_code != 0 - assert "Error: Missing argument 'CONFIG'." 
in result.output - - # Test non-existent config file - result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"]) - assert result.exit_code != 0 - assert "Error: Invalid value for 'CONFIG'" in result.output - - -def test_basic_execution(cli_runner, tmp_path: Path, base_config): - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) - - with patch( - "axolotl.cli.integrations.convert_diff_transformer.do_cli" - ) as mock_do_cli: - result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)]) - assert result.exit_code == 0 - - mock_do_cli.assert_called_once() - assert mock_do_cli.call_args.kwargs["config"] == str(config_path) - - -def test_conversion_cli_basic(tmp_path: Path, base_config): - output_dir = tmp_path / "converted" - base_config["output_dir"] = str(output_dir) - - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) - - cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs() - _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - - assert not debug_info - assert (output_dir / "model.safetensors").exists() - assert (output_dir / "config.json").exists() - assert (output_dir / "axolotl_config.yml").exists() - - -def test_conversion_cli_debug(tmp_path: Path, base_config): - output_dir = tmp_path / "converted" - base_config["output_dir"] = str(output_dir) - - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) - - cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs(debug=True) - _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - - assert not debug_info["generations_match"] - assert not debug_info["match_expected"] - assert (output_dir / "model.safetensors").exists() - assert (output_dir / "config.json").exists() - assert (output_dir / "axolotl_config.yml").exists() - - -def test_conversion_cli_reproduce(tmp_path: Path, base_config): - output_dir = tmp_path / "converted" - base_config["output_dir"] = str(output_dir) - - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) - - cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs( - debug=True, zero_init=True, sublayer_norm=False +@pytest.mark.usefixtures("base_config", "cli_runner") +class TestDiffTransformer: + """Tests for convert-diff-transformer CLI command""" + + def test_cli_validation(self, cli_runner): + # Test missing config file + result = cli_runner.invoke(cli, ["convert-diff-transformer"]) + assert result.exit_code != 0 + assert "Error: Missing argument 'CONFIG'." 
in result.output + + # Test non-existent config file + result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"]) + assert result.exit_code != 0 + assert "Error: Invalid value for 'CONFIG'" in result.output + + def test_basic_execution(self, cli_runner, tmp_path: Path, base_config): + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + with patch( + "axolotl.cli.integrations.convert_diff_transformer.do_cli" + ) as mock_do_cli: + result = cli_runner.invoke( + cli, ["convert-diff-transformer", str(config_path)] + ) + assert result.exit_code == 0 + + mock_do_cli.assert_called_once() + assert mock_do_cli.call_args.kwargs["config"] == str(config_path) + + def test_conversion_cli_basic(self, tmp_path: Path, base_config): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs() + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) + + assert not debug_info + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() + + def test_conversion_cli_debug(self, tmp_path: Path, base_config): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs(debug=True) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) + + assert not debug_info["generations_match"] + assert not debug_info["match_expected"] + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() + + def test_conversion_cli_reproduce(self, tmp_path: Path, base_config): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs( + debug=True, zero_init=True, sublayer_norm=False + ) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) + + assert debug_info["generations_match"] is True + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() + + @pytest.mark.parametrize( + "attention", ["eager_attention", "sdp_attention", "flash_attention"] ) - _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - - assert debug_info["generations_match"] is True - assert (output_dir / "model.safetensors").exists() - assert (output_dir / "config.json").exists() - assert (output_dir / "axolotl_config.yml").exists() - - -@pytest.mark.parametrize( - "attention", ["eager_attention", "sdp_attention", "flash_attention"] -) -def test_conversion_cli_repoduce_attentions( - tmp_path: Path, base_config, attention: Optional[str] -): - output_dir = tmp_path / "converted" - base_config["output_dir"] = str(output_dir) - base_config[attention] = True - - config_path = tmp_path / "config.yml" - with open(config_path, 
"w", encoding="utf-8") as file: - yaml.dump(base_config, file) - - cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs( - debug=True, zero_init=True, sublayer_norm=False + def test_conversion_cli_repoduce_attentions( + self, tmp_path: Path, base_config, attention: Optional[str] + ): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + base_config[attention] = True + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs( + debug=True, zero_init=True, sublayer_norm=False + ) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) + + assert debug_info["generations_match"] is True + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() + + @pytest.mark.parametrize( + "attention", ["eager_attention", "sdp_attention", "flash_attention"] ) - _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - - assert debug_info["generations_match"] is True - assert (output_dir / "model.safetensors").exists() - assert (output_dir / "config.json").exists() - assert (output_dir / "axolotl_config.yml").exists() - - -@pytest.mark.parametrize( - "attention", ["eager_attention", "sdp_attention", "flash_attention"] -) -def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str): - output_dir = tmp_path / "converted" - base_config["output_dir"] = str(output_dir) - base_config[attention] = True - - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) - - cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True) - _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - - assert debug_info["generations_match"] is False - assert (output_dir / "model.safetensors").exists() - assert (output_dir / "config.json").exists() - assert (output_dir / "axolotl_config.yml").exists() + def test_conversion_cli_split_heads( + self, tmp_path: Path, base_config, attention: str + ): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + base_config[attention] = True + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) + + assert debug_info["generations_match"] is False + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() From c6def2732d8322f9ae65e7b57b73777aa4f11fb6 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 23 Dec 2024 14:14:51 -0500 Subject: [PATCH 22/25] added modeling code; cleanup + refactor --- .../integrations/diff_transformer/convert.py | 3 +- .../diff_transformer/diff_attn.py | 100 +++-- .../diff_transformer/modeling_diff_attn.py | 370 ++++++++++++++++++ 3 files changed, 439 insertions(+), 34 deletions(-) create mode 100644 src/axolotl/integrations/diff_transformer/modeling_diff_attn.py diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py index 
d942567d57..298a0232ed 100644 --- a/src/axolotl/integrations/diff_transformer/convert.py +++ b/src/axolotl/integrations/diff_transformer/convert.py @@ -50,7 +50,7 @@ def copy_attention_weights( new_attn.q_proj.weight.data.copy_(new_q) # For K projection (K1 and K2) - old_kv_size = old_attn.k_proj.weight.data.size(0) # Size for 3 heads + old_kv_size = old_attn.k_proj.weight.data.size(0) new_k = torch.empty_like(new_attn.k_proj.weight.data) new_k[:old_kv_size] = old_attn.k_proj.weight.data # K1 if zero_init: @@ -99,6 +99,7 @@ def convert_module(module): # Iterate through module children, convert any attn layers to diff attn for name, child in module.named_children(): child_class_name = type(child).__name__ + if child_class_name in [k.__name__ for k in ATTENTION_MAPPING]: # Find matching attention class by name for orig_class, diff_class in ATTENTION_MAPPING.items(): diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py index a8d7536dd6..cccb0adebd 100644 --- a/src/axolotl/integrations/diff_transformer/diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/diff_attn.py @@ -7,7 +7,6 @@ import torch import torch.nn.functional as F -from flash_attn.flash_attn_interface import flash_attn_func from torch import nn from transformers.cache_utils import Cache from transformers.models.llama.modeling_llama import ( @@ -17,7 +16,14 @@ ) logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +LOG = logging.getLogger(__name__) + +try: + from flash_attn.flash_attn_interface import flash_attn_func + + FLASH_ATTENTION_AVAILABLE = True +except ImportError: + FLASH_ATTENTION_AVAILABLE = False def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -35,11 +41,12 @@ def lambda_init_fn(depth): return 0.8 - 0.6 * math.exp(-0.3 * depth) -class DifferentialAttentionBase(nn.Module): +class LlamaDifferentialAttentionBase(nn.Module): """Base class for differential attention implementations.""" def __init__(self, config: Any, layer_idx: int): super().__init__() + self.config = config self._init_config(config, layer_idx) self._init_projections() self._init_differential_params() @@ -59,9 +66,9 @@ def _init_config(self, config: Any, layer_idx: int): if config.split_heads: # Split heads mode - single projections - self.head_dim = config.hidden_size // config.num_attention_heads // 2 + self.head_dim = config.hidden_size // config.num_attention_heads # NOTE: This rounds down `base_num_heads / 2` as opposed to the original - # implementation, which asserts `self.base_num_heads` is even. 
+ # implementation, which asserts `self.base_num_heads` is even self.heads_per_component = self.base_num_heads // 2 self.value_head_dim = 2 * self.head_dim else: @@ -110,36 +117,43 @@ def _init_differential_params(self): self.lambda_k2 = nn.Parameter( torch.zeros(self.head_dim).normal_(mean=0, std=0.1) ) - self.rotary_emb = LlamaRotaryEmbedding( - self.max_position_embeddings, self.head_dim, self.rope_theta - ) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) def _init_normalization(self, config): """Initialize normalization layers.""" sublayer_norm = getattr(config, "sublayer_norm", True) - self.subln = ( - LlamaRMSNorm(self.value_head_dim, eps=1e-5) - if sublayer_norm - else nn.Identity() - ) + if sublayer_norm: + self.subln = LlamaRMSNorm(self.value_head_dim, eps=config.rms_norm_eps) + else: + self.subln = nn.Identity() def _prepare_attention_inputs(self, hidden_states: torch.Tensor): """Prepare inputs for attention computation.""" bsz, q_len, _ = hidden_states.size() # Project and split - qp = self.q_proj(hidden_states) - kp = self.k_proj(hidden_states) + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) v = self.v_proj(hidden_states) - q1, q2 = qp.chunk(2, dim=-1) - k1, k2 = kp.chunk(2, dim=-1) + q1, q2 = q.chunk(2, dim=-1) + k1, k2 = k.chunk(2, dim=-1) # Reshape - q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - v = v.view(bsz, q_len, -1, self.value_head_dim).transpose(1, 2) + q1 = q1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose( + 1, 2 + ) + q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose( + 1, 2 + ) + k1 = k1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose( + 1, 2 + ) + k2 = k2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose( + 1, 2 + ) + v = v.view(bsz, q_len, self.heads_per_component, self.value_head_dim).transpose( + 1, 2 + ) return q1, q2, k1, k2, v @@ -148,16 +162,16 @@ def _apply_rotary_embeddings( ): """Apply rotary embeddings to queries and keys.""" if position_embeddings is None: - if position_ids is None: - position_ids = torch.arange(q1.size(-2), device=q1.device) + LOG.warning( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) cos, sin = self.rotary_emb(q1, position_ids) else: cos, sin = position_embeddings - if self.split_heads: - cos, _ = cos.chunk(2, dim=2) - sin, _ = sin.chunk(2, dim=2) - q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) @@ -195,7 +209,7 @@ def _process_attention_output(self, attn, bsz, q_len): return self.o_proj(attn) -class LlamaDifferentialAttention(DifferentialAttentionBase): +class LlamaDifferentialAttention(LlamaDifferentialAttentionBase): """Standard implementation of differential attention.""" def forward( @@ -237,15 +251,16 @@ def forward( lambda_full = self._compute_lambda(q1) attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v) - attn = self._process_attention_output(attn, bsz, q_len) if output_attentions: - return attn, attn1 - lambda_full * attn2, past_key_value + attn_weights = attn1 - lambda_full * attn2 + attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1) + return attn, attn_weights, past_key_value return attn, None, past_key_value -class LlamaDifferentialSdpaAttention(DifferentialAttentionBase): +class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase): """SDPA-based implementation of differential attention.""" # pylint: disable=duplicate-code @@ -262,6 +277,11 @@ def forward( **kwargs, # pylint: disable=unused-argument ): if output_attentions: + LOG.warning( + "LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but " + + "`torch.nn.functional.scaled_dot_product_attention` does not support " + + "`output_attentions=True`. Falling back to the eager attention implementation." + ) return LlamaDifferentialAttention.forward( self, hidden_states, @@ -309,9 +329,18 @@ def forward( return attn, None, past_key_value -class LlamaDifferentialFlashAttention2(DifferentialAttentionBase): +class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase): """Flash Attention 2-based implementation of differential attention.""" + def __init__(self, *args, **kwargs): + if not FLASH_ATTENTION_AVAILABLE: + raise ImportError( + "LlamaDifferentialFlashAttention2 requires flash-attn library. " + "Please install with `pip install flash-attn --no-build-isolation`" + ) + + super().__init__(*args, **kwargs) + # pylint: disable=duplicate-code def forward( self, @@ -326,6 +355,11 @@ def forward( **kwargs, # pylint: disable=unused-argument ): if output_attentions: + LOG.warning( + "LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but " + + "flash attenion does not support `output_attentions=True`. Falling back " + + "to the eager attention implementation." 
+ ) return LlamaDifferentialAttention.forward( self, hidden_states, diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py new file mode 100644 index 0000000000..5949707164 --- /dev/null +++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py @@ -0,0 +1,370 @@ +"""Modeling for differential transformers.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers.cache_utils import Cache +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaMLP, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, +) + +from .diff_attn import ( + LlamaDifferentialAttention, + LlamaDifferentialAttentionBase, + LlamaDifferentialFlashAttention2, + LlamaDifferentialSdpaAttention, +) + + +class LlamaDifferentialConfig(LlamaConfig): + """Configuration class for Differential LLaMA model.""" + + def __init__( + self, + split_heads: bool = False, + sublayer_norm: bool = True, + zero_init: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.split_heads = split_heads + self.sublayer_norm = sublayer_norm + self.zero_init = zero_init + self.architectures = ["LlamaDifferentialModel"] + self._attn_implementations = { + "eager": "differential_eager", + "sdpa": "differential_sdpa", + "flash_attention_2": "differential_flash_attention_2", + } + + +class LlamaDifferentialPreTrainedModel(LlamaPreTrainedModel): + """Base class for differential LLaMA models.""" + + config_class = LlamaDifferentialConfig + base_model_prefix = "llama_differential" + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (LlamaDifferentialAttentionBase, LlamaModel)): + module.gradient_checkpointing = value + + +def lambda_init_fn(depth: int) -> float: + """Initialize lambda parameter based on layer depth.""" + return 0.8 - 0.6 * math.exp(-0.3 * depth) + + +class LlamaDifferentialModel(LlamaDifferentialPreTrainedModel): + """Differential version of the LLaMA model.""" + + def __init__(self, config: LlamaDifferentialConfig): + super().__init__(config) + # Map attn implementations to classes + self.attn_implementation_to_class = { + "differential_eager": LlamaDifferentialAttention, + "differential_sdpa": LlamaDifferentialSdpaAttention, + "differential_flash_attention_2": LlamaDifferentialFlashAttention2, + } + + # Get correct attention implementation + attn_implementation = getattr(config, "_attn_implementation", "eager") + if attn_implementation in config._attn_implementations: + attn_implementation = config._attn_implementations[attn_implementation] + + self.attention_class = self.attn_implementation_to_class.get( + attn_implementation, LlamaDifferentialAttention + ) + + # Initialize model components + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, config.pad_token_id + ) + self.layers = nn.ModuleList( + [ + LlamaDifferentialDecoderLayer( + config=config, layer_idx=i, attention_class=self.attention_class + ) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = 
None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # Check if either input_ids or inputs_embeds is provided + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + if input_ids is not None: + batch_size, seq_length = input_ids.shape + device = input_ids.device + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Initialize past_key_values if needed + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + + # Create attention mask if not provided + if attention_mask is not None: + attention_mask = self._prepare_attention_mask( + attention_mask, (batch_size, seq_length), device + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # Initialize lists to store outputs + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_cache = () if use_cache else None + + for _, (layer, past_key_value) in enumerate(zip(self.layers, past_key_values)): + if output_hidden_states: + all_hidden_states += (hidden_states,) # type: ignore + + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_cache += (layer_outputs[-1],) # type: ignore + + if output_attentions: + all_self_attns += (layer_outputs[1],) # type: ignore + + # Add last hidden state + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) # type: ignore + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _prepare_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + device: torch.device, + ) -> torch.Tensor: + """Prepare attention mask for computing attention.""" + # Create causal mask + # [batch_size, seq_length] -> [batch_size, 1, seq_length, seq_length] + combined_attention_mask = None + _, seq_length = input_shape + + if self.config.is_decoder: + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(1, seq_length, 1) + <= seq_ids[None, :, None] + ) + causal_mask = 
causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1:] != (seq_length, seq_length): + causal_mask = causal_mask[:, :seq_length, :seq_length] + + # Extend attention mask + combined_attention_mask = ( + causal_mask[None, None, :, :] * attention_mask[:, None, None, :] + ) + else: + combined_attention_mask = attention_mask[:, None, None, :] + + return combined_attention_mask + + @classmethod + def from_llama( + cls, + llama_model: LlamaModel, + differential_config: Optional[LlamaDifferentialConfig] = None, + ) -> "LlamaDifferentialModel": + """Convert a standard LLaMA model to use differential attention.""" + if differential_config is None: + # pylint: disable=protected-access + differential_config = LlamaDifferentialConfig.from_pretrained( + llama_model.config._name_or_path + ) + + # Create new model + new_model = cls(differential_config) + + # Copy non-attention weights directly + new_model.embed_tokens.load_state_dict(llama_model.embed_tokens.state_dict()) + new_model.norm.load_state_dict(llama_model.norm.state_dict()) + + # Copy layer weights, handling attention layers specially + for new_layer, old_layer in zip(new_model.layers, llama_model.layers): + # Copy self-attention weights with special handling + if differential_config.split_heads: + # Split heads mode + new_layer.self_attn.q_proj.weight.data.copy_( + old_layer.self_attn.q_proj.weight.data + ) + new_layer.self_attn.k_proj.weight.data.copy_( + old_layer.self_attn.k_proj.weight.data + ) + else: + # Double projection mode - copy weights to positive components + new_layer.self_attn.q_proj.weight.data[ + : differential_config.hidden_size + ].copy_(old_layer.self_attn.q_proj.weight.data) + new_layer.self_attn.k_proj.weight.data[ + : differential_config.hidden_size + ].copy_(old_layer.self_attn.k_proj.weight.data) + + # Zero out relevant parameters for exact equivalence + if differential_config.zero_init: + old_kv_size = old_layer.self_attn.k_proj.weight.data.size(0) + new_layer.self_attn.q_proj.weight.data[ + new_layer.self_attn.hidden_size : + ] = 0 + new_layer.self_attn.k_proj.weight.data[old_kv_size:] = 0 + nn.init.zeros_(new_layer.self_attn.lambda_q1) + nn.init.zeros_(new_layer.self_attn.lambda_k1) + nn.init.zeros_(new_layer.self_attn.lambda_q2) + nn.init.zeros_(new_layer.self_attn.lambda_k2) + nn.init.zeros_(new_layer.self_attn.lambda_init) + + # Copy remaining weights + new_layer.self_attn.v_proj.load_state_dict( + old_layer.self_attn.v_proj.state_dict() + ) + new_layer.self_attn.o_proj.load_state_dict( + old_layer.self_attn.o_proj.state_dict() + ) + + # Copy MLP and layer norm weights + new_layer.mlp.load_state_dict(old_layer.mlp.state_dict()) + new_layer.input_layernorm.load_state_dict( + old_layer.input_layernorm.state_dict() + ) + new_layer.post_attention_layernorm.load_state_dict( + old_layer.post_attention_layernorm.state_dict() + ) + + return new_model + + +class LlamaDifferentialDecoderLayer(nn.Module): + """Custom decoder layer for diffrential Llama model.""" + + def __init__( + self, config: LlamaDifferentialConfig, layer_idx: int, attention_class + ): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = attention_class(config, layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: 
Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Layer forward pass with differential attention. + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) # type: ignore + + if use_cache: + outputs += (present_key_value,) # type: ignore + + return outputs # type: ignore From 7d9ec2c77ecb465cd76bb95a99b1651b594385e0 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 23 Dec 2024 14:22:33 -0500 Subject: [PATCH 23/25] fix duplicate-code warnings --- src/axolotl/integrations/diff_transformer/modeling_diff_attn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py index 5949707164..4b97bfe10a 100644 --- a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py @@ -96,6 +96,7 @@ def __init__(self, config: LlamaDifferentialConfig): ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # pylint: disable=duplicate-code def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -322,6 +323,7 @@ def __init__( config.hidden_size, eps=config.rms_norm_eps ) + # pylint: disable=duplicate-code def forward( self, hidden_states: torch.Tensor, From 44e4b837e263dcffb5346433fefbb6b21877041c Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 23 Dec 2024 20:40:55 -0500 Subject: [PATCH 24/25] updated custom modeling code --- .../integrations/convert_diff_transformer.py | 19 +- .../diff_transformer/modeling_diff_attn.py | 373 ++++-------------- 2 files changed, 90 insertions(+), 302 deletions(-) diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py index 360832dbb8..db4b0df4d4 100644 --- a/src/axolotl/cli/integrations/convert_diff_transformer.py +++ b/src/axolotl/cli/integrations/convert_diff_transformer.py @@ -15,7 +15,10 @@ from axolotl.cli import load_cfg, print_axolotl_text_art from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer -from axolotl.integrations.diff_transformer.convert import convert_to_diff_attn +from axolotl.integrations.diff_transformer.modeling_diff_attn import ( + LlamaDifferentialConfig, + LlamaDifferentialForCausalLM, +) from axolotl.utils.yaml import dump_yaml_preserved_order LOG = logging.getLogger(__name__) @@ -86,13 +89,15 @@ def convert_diff_transformer(cfg, cli_args, config_path): + Fore.RESET ) try: - model = convert_to_diff_attn( - model=model, - zero_init=cli_args.zero_init, - sublayer_norm=cli_args.sublayer_norm, - split_heads=cli_args.split_heads, + 
LlamaDifferentialForCausalLM.from_llama( + model, + LlamaDifferentialConfig( + **model.config.__dict__, + zero_init=cli_args.zero_init, + sublayer_norm=cli_args.sublayer_norm, + split_heads=cli_args.split_heads, + ), ) - model.to(cfg.device, dtype=cfg.torch_dtype) except Exception as exc: LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc)) raise diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py index 4b97bfe10a..a3d31382da 100644 --- a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py @@ -1,18 +1,13 @@ """Modeling for differential transformers.""" -import math -from typing import List, Optional, Tuple, Union +from typing import Optional import torch -from torch import nn -from transformers.cache_utils import Cache -from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import ( - LlamaMLP, + LlamaForCausalLM, LlamaModel, LlamaPreTrainedModel, - LlamaRMSNorm, ) from .diff_attn import ( @@ -56,210 +51,57 @@ def _set_gradient_checkpointing(self, module, value=False): module.gradient_checkpointing = value -def lambda_init_fn(depth: int) -> float: - """Initialize lambda parameter based on layer depth.""" - return 0.8 - 0.6 * math.exp(-0.3 * depth) +class LlamaDifferentialModel(LlamaModel): + """LlamaModel with differential attention.""" - -class LlamaDifferentialModel(LlamaDifferentialPreTrainedModel): - """Differential version of the LLaMA model.""" - - def __init__(self, config: LlamaDifferentialConfig): + def __init__(self, config): super().__init__(config) - # Map attn implementations to classes - self.attn_implementation_to_class = { - "differential_eager": LlamaDifferentialAttention, - "differential_sdpa": LlamaDifferentialSdpaAttention, - "differential_flash_attention_2": LlamaDifferentialFlashAttention2, - } - - # Get correct attention implementation - attn_implementation = getattr(config, "_attn_implementation", "eager") - if attn_implementation in config._attn_implementations: - attn_implementation = config._attn_implementations[attn_implementation] - - self.attention_class = self.attn_implementation_to_class.get( - attn_implementation, LlamaDifferentialAttention - ) - - # Initialize model components - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, config.pad_token_id - ) - self.layers = nn.ModuleList( - [ - LlamaDifferentialDecoderLayer( - config=config, layer_idx=i, attention_class=self.attention_class + # Replace standard attention with differential attention in each layer + for layer in self.layers: + attn_impl = config._attn_implementation or "eager" + if attn_impl == "eager": + layer.self_attn = LlamaDifferentialAttention(config, layer.layer_idx) + elif attn_impl == "sdpa": + layer.self_attn = LlamaDifferentialSdpaAttention( + config, layer.layer_idx + ) + elif attn_impl == "flash_attention_2": + layer.self_attn = LlamaDifferentialFlashAttention2( + config, layer.layer_idx ) - for i in range(config.num_hidden_layers) - ] - ) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - # pylint: disable=duplicate-code - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: 
Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # Check if either input_ids or inputs_embeds is provided - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - if input_ids is not None: - batch_size, seq_length = input_ids.shape - device = input_ids.device - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - device = inputs_embeds.device - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if position_ids is None: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) - - # Initialize past_key_values if needed - if past_key_values is None: - past_key_values = tuple([None] * len(self.layers)) - - # Create attention mask if not provided - if attention_mask is not None: - attention_mask = self._prepare_attention_mask( - attention_mask, (batch_size, seq_length), device - ) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - hidden_states = inputs_embeds - - # Initialize lists to store outputs - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_cache = () if use_cache else None - - for _, (layer, past_key_value) in enumerate(zip(self.layers, past_key_values)): - if output_hidden_states: - all_hidden_states += (hidden_states,) # type: ignore - - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_cache += (layer_outputs[-1],) # type: ignore - - if output_attentions: - all_self_attns += (layer_outputs[1],) # type: ignore - # Add last hidden state - hidden_states = self.norm(hidden_states) + @classmethod + def from_llama( + cls, model: LlamaModel, config: Optional[LlamaDifferentialConfig] = None + ) -> "LlamaDifferentialModel": + """Convert a LlamaModel to use differential attention.""" + if config is None: + config = LlamaDifferentialConfig(**model.config.__dict__) - if output_hidden_states: - all_hidden_states += (hidden_states,) # type: ignore + new_model = cls(config) + # Copy all weights except attention + new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict()) + new_model.norm.load_state_dict(model.norm.state_dict()) - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None + for new_layer, old_layer in zip(new_model.layers, model.layers): + # Copy everything except attention weights + new_layer.mlp.load_state_dict(old_layer.mlp.state_dict()) + 
new_layer.input_layernorm.load_state_dict( + old_layer.input_layernorm.state_dict() ) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - def _prepare_attention_mask( - self, - attention_mask: torch.Tensor, - input_shape: Tuple[int, int], - device: torch.device, - ) -> torch.Tensor: - """Prepare attention mask for computing attention.""" - # Create causal mask - # [batch_size, seq_length] -> [batch_size, 1, seq_length, seq_length] - combined_attention_mask = None - _, seq_length = input_shape - - if self.config.is_decoder: - seq_ids = torch.arange(seq_length, device=device) - causal_mask = ( - seq_ids[None, None, :].repeat(1, seq_length, 1) - <= seq_ids[None, :, None] + new_layer.post_attention_layernorm.load_state_dict( + old_layer.post_attention_layernorm.state_dict() ) - causal_mask = causal_mask.to(attention_mask.dtype) - if causal_mask.shape[1:] != (seq_length, seq_length): - causal_mask = causal_mask[:, :seq_length, :seq_length] - - # Extend attention mask - combined_attention_mask = ( - causal_mask[None, None, :, :] * attention_mask[:, None, None, :] + # Handle attention weights + new_layer.self_attn.v_proj.load_state_dict( + old_layer.self_attn.v_proj.state_dict() ) - else: - combined_attention_mask = attention_mask[:, None, None, :] - - return combined_attention_mask - - @classmethod - def from_llama( - cls, - llama_model: LlamaModel, - differential_config: Optional[LlamaDifferentialConfig] = None, - ) -> "LlamaDifferentialModel": - """Convert a standard LLaMA model to use differential attention.""" - if differential_config is None: - # pylint: disable=protected-access - differential_config = LlamaDifferentialConfig.from_pretrained( - llama_model.config._name_or_path + new_layer.self_attn.o_proj.load_state_dict( + old_layer.self_attn.o_proj.state_dict() ) - # Create new model - new_model = cls(differential_config) - - # Copy non-attention weights directly - new_model.embed_tokens.load_state_dict(llama_model.embed_tokens.state_dict()) - new_model.norm.load_state_dict(llama_model.norm.state_dict()) - - # Copy layer weights, handling attention layers specially - for new_layer, old_layer in zip(new_model.layers, llama_model.layers): - # Copy self-attention weights with special handling - if differential_config.split_heads: - # Split heads mode + if config.split_heads: new_layer.self_attn.q_proj.weight.data.copy_( old_layer.self_attn.q_proj.weight.data ) @@ -267,106 +109,47 @@ def from_llama( old_layer.self_attn.k_proj.weight.data ) else: - # Double projection mode - copy weights to positive components - new_layer.self_attn.q_proj.weight.data[ - : differential_config.hidden_size - ].copy_(old_layer.self_attn.q_proj.weight.data) - new_layer.self_attn.k_proj.weight.data[ - : differential_config.hidden_size - ].copy_(old_layer.self_attn.k_proj.weight.data) - - # Zero out relevant parameters for exact equivalence - if differential_config.zero_init: - old_kv_size = old_layer.self_attn.k_proj.weight.data.size(0) - new_layer.self_attn.q_proj.weight.data[ - new_layer.self_attn.hidden_size : - ] = 0 - new_layer.self_attn.k_proj.weight.data[old_kv_size:] = 0 - nn.init.zeros_(new_layer.self_attn.lambda_q1) - nn.init.zeros_(new_layer.self_attn.lambda_k1) - nn.init.zeros_(new_layer.self_attn.lambda_q2) - nn.init.zeros_(new_layer.self_attn.lambda_k2) - nn.init.zeros_(new_layer.self_attn.lambda_init) + new_layer.self_attn.q_proj.weight.data[: config.hidden_size].copy_( + 
old_layer.self_attn.q_proj.weight.data + ) + new_layer.self_attn.k_proj.weight.data[: config.hidden_size].copy_( + old_layer.self_attn.k_proj.weight.data + ) - # Copy remaining weights - new_layer.self_attn.v_proj.load_state_dict( - old_layer.self_attn.v_proj.state_dict() - ) - new_layer.self_attn.o_proj.load_state_dict( - old_layer.self_attn.o_proj.state_dict() - ) - - # Copy MLP and layer norm weights - new_layer.mlp.load_state_dict(old_layer.mlp.state_dict()) - new_layer.input_layernorm.load_state_dict( - old_layer.input_layernorm.state_dict() - ) - new_layer.post_attention_layernorm.load_state_dict( - old_layer.post_attention_layernorm.state_dict() - ) + if config.zero_init: + # Zero out components as needed + with torch.no_grad(): + new_layer.self_attn.q_proj.weight.data[ + config.hidden_size : + ].zero_() + new_layer.self_attn.k_proj.weight.data[ + config.hidden_size : + ].zero_() + new_layer.self_attn.lambda_q1.zero_() + new_layer.self_attn.lambda_k1.zero_() + new_layer.self_attn.lambda_q2.zero_() + new_layer.self_attn.lambda_k2.zero_() + new_layer.self_attn.lambda_init.zero_() return new_model -class LlamaDifferentialDecoderLayer(nn.Module): - """Custom decoder layer for diffrential Llama model.""" - - def __init__( - self, config: LlamaDifferentialConfig, layer_idx: int, attention_class - ): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = attention_class(config, layer_idx) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - - # pylint: disable=duplicate-code - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Layer forward pass with differential attention. 
- """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) +class LlamaDifferentialForCausalLM(LlamaForCausalLM): + """LlamaForCausalLM with differential attention.""" - if output_attentions: - outputs += (self_attn_weights,) # type: ignore - - if use_cache: - outputs += (present_key_value,) # type: ignore + def __init__(self, config): + super().__init__(config) + self.model = LlamaDifferentialModel(config) - return outputs # type: ignore + @classmethod + def from_llama( + cls, model: LlamaForCausalLM, config: Optional[LlamaDifferentialConfig] = None + ) -> "LlamaDifferentialForCausalLM": + """Convert a LlamaForCausalLM to use differential attention.""" + if config is None: + config = LlamaDifferentialConfig(**model.config.__dict__) + + new_model = cls(config) + new_model.model = LlamaDifferentialModel.from_llama(model.model, config) + new_model.lm_head.load_state_dict(model.lm_head.state_dict()) + return new_model From 6945bdd1dbcbce11da88765bc94c641a274893bb Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 24 Dec 2024 05:30:46 +0000 Subject: [PATCH 25/25] progress on modeling code --- .../integrations/convert_diff_transformer.py | 48 ++-- .../integrations/diff_transformer/README.md | 3 + .../integrations/diff_transformer/__init__.py | 13 +- .../integrations/diff_transformer/args.py | 3 + .../integrations/diff_transformer/convert.py | 135 --------- .../diff_transformer/diff_attn.py | 36 ++- .../diff_transformer/modeling_diff_attn.py | 67 +++-- src/axolotl/utils/models.py | 21 +- .../test_convert_diff_transformer.py | 258 +++++++++--------- 9 files changed, 242 insertions(+), 342 deletions(-) delete mode 100644 src/axolotl/integrations/diff_transformer/convert.py diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py index db4b0df4d4..28cc87bbd3 100644 --- a/src/axolotl/cli/integrations/convert_diff_transformer.py +++ b/src/axolotl/cli/integrations/convert_diff_transformer.py @@ -26,34 +26,27 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"): """Run test inference and return generation time""" - try: - inputs = tokenizer(prompt, return_tensors="pt") - inputs = { - k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items() - } - - start = time() - with torch.no_grad(): - outputs = model.generate( - **inputs, - max_new_tokens=20, - num_beams=1, - do_sample=False, - pad_token_id=tokenizer.pad_token_id, - use_cache=False, - ) - elapsed = time() - start + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()} + + start = time() + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=20, + num_beams=1, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + use_cache=False, + ) + elapsed = time() - start - generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) - 
LOG.info("Prompt: %s", prompt) - LOG.info("Generated: %s", generated_text) - LOG.info("Generation time: %.2fs", elapsed) + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + LOG.info("Prompt: %s", prompt) + LOG.info("Generated: %s", generated_text) + LOG.info("Generation time: %.2fs", elapsed) - return elapsed, generated_text - - except Exception as exc: - LOG.error("Inference failed: %s", str(exc)) - raise + return elapsed, generated_text def convert_diff_transformer(cfg, cli_args, config_path): @@ -89,7 +82,7 @@ def convert_diff_transformer(cfg, cli_args, config_path): + Fore.RESET ) try: - LlamaDifferentialForCausalLM.from_llama( + model = LlamaDifferentialForCausalLM.from_llama( model, LlamaDifferentialConfig( **model.config.__dict__, @@ -98,6 +91,7 @@ def convert_diff_transformer(cfg, cli_args, config_path): split_heads=cli_args.split_heads, ), ) + model.to(cfg.device, dtype=cfg.torch_dtype) except Exception as exc: LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc)) raise diff --git a/src/axolotl/integrations/diff_transformer/README.md b/src/axolotl/integrations/diff_transformer/README.md index 14473f7537..a683fdf1df 100644 --- a/src/axolotl/integrations/diff_transformer/README.md +++ b/src/axolotl/integrations/diff_transformer/README.md @@ -7,4 +7,7 @@ plugins: - axolotl.integrations.diff_transformer.DifferentialTransformerPlugin diff_attention: true +diff_attn_zero_init: false +diff_attn_sublayer_norm: true +diff_attn_split_heads: false ``` diff --git a/src/axolotl/integrations/diff_transformer/__init__.py b/src/axolotl/integrations/diff_transformer/__init__.py index 70459e0266..461ede4fd4 100644 --- a/src/axolotl/integrations/diff_transformer/__init__.py +++ b/src/axolotl/integrations/diff_transformer/__init__.py @@ -8,18 +8,7 @@ class DifferentialTransformerPlugin(BasePlugin): - """ - Plugin for differential transformer integration with Axolotl. 
- """ + """Plugin for differential transformer integration with Axolotl.""" def get_input_args(self): return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs" - - def pre_model_load(self, cfg): - """Apply differential attention patch before model loading if enabled.""" - if cfg.diff_attention: - from axolotl.monkeypatch.attention.differential import ( - patch_llama_attention_classes, - ) - - patch_llama_attention_classes() diff --git a/src/axolotl/integrations/diff_transformer/args.py b/src/axolotl/integrations/diff_transformer/args.py index 47c1fe1104..332c0b4aa0 100644 --- a/src/axolotl/integrations/diff_transformer/args.py +++ b/src/axolotl/integrations/diff_transformer/args.py @@ -12,3 +12,6 @@ class DifferentialTransformerArgs(BaseModel): """Input args for differential transformer.""" diff_attention: Optional[bool] = None + diff_attn_zero_init: Optional[bool] = None + diff_attn_sublayer_norm: Optional[bool] = None + diff_attn_split_heads: Optional[bool] = None diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py deleted file mode 100644 index 298a0232ed..0000000000 --- a/src/axolotl/integrations/diff_transformer/convert.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Differential attention conversion logic for a huggingface pre-trained model.""" -import logging -from typing import Union - -import torch -from torch import nn -from transformers import PreTrainedModel -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaFlashAttention2, - LlamaSdpaAttention, -) - -from .diff_attn import ( - LlamaDifferentialAttention, - LlamaDifferentialFlashAttention2, - LlamaDifferentialSdpaAttention, -) - -logger = logging.getLogger(__name__) - -ATTENTION_MAPPING = { - LlamaAttention: LlamaDifferentialAttention, - LlamaSdpaAttention: LlamaDifferentialSdpaAttention, - LlamaFlashAttention2: LlamaDifferentialFlashAttention2, -} - - -def copy_attention_weights( - old_attn: Union[LlamaAttention, LlamaSdpaAttention, LlamaFlashAttention2], - new_attn: Union[ - LlamaDifferentialAttention, - LlamaDifferentialSdpaAttention, - LlamaDifferentialFlashAttention2, - ], - zero_init: bool = False, -) -> None: - """ - Copy weights from old attention layer to new differential attention layer. - Copies old weights to Q1 and K1, zeros out Q2 and K2 for exact equivalence - to original attention mechanism. 
-    """
-    # For Q projection (Q1 and Q2)
-    new_q = torch.empty_like(new_attn.q_proj.weight.data)
-    new_q[: new_attn.hidden_size] = old_attn.q_proj.weight.data  # Q1
-    if zero_init:
-        new_q[new_attn.hidden_size :] = 0
-    else:
-        nn.init.normal_(new_q[new_attn.hidden_size :], mean=0, std=0.1)
-    new_attn.q_proj.weight.data.copy_(new_q)
-
-    # For K projection (K1 and K2)
-    old_kv_size = old_attn.k_proj.weight.data.size(0)
-    new_k = torch.empty_like(new_attn.k_proj.weight.data)
-    new_k[:old_kv_size] = old_attn.k_proj.weight.data  # K1
-    if zero_init:
-        new_k[old_kv_size:] = 0
-    else:
-        nn.init.normal_(new_k[old_kv_size:], mean=0, std=0.1)
-    new_attn.k_proj.weight.data.copy_(new_k)
-
-    # For V projection (single V)
-    new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data)
-
-    # Output projection remains the same
-    new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data)
-
-    # Zero out lambda parameters for exact equivalence
-    if zero_init:
-        nn.init.zeros_(new_attn.lambda_q1)
-        nn.init.zeros_(new_attn.lambda_k1)
-        nn.init.zeros_(new_attn.lambda_q2)
-        nn.init.zeros_(new_attn.lambda_k2)
-        nn.init.zeros_(new_attn.lambda_init)
-
-    logger.debug(
-        "Copied positive attention weights from %s to %s",
-        type(old_attn).__name__,
-        type(new_attn).__name__,
-    )
-
-
-def convert_to_diff_attn(
-    model: PreTrainedModel,
-    zero_init: bool = False,
-    sublayer_norm: bool = True,
-    split_heads: bool = True,
-) -> PreTrainedModel:
-    """Convert a pre-trained model's attention layers to differential attention"""
-    layer_idx = 0
-
-    # Set sublayer norm as config on the model.
-    model.config.sublayer_norm = sublayer_norm
-    model.config.split_heads = split_heads
-
-    def convert_module(module):
-        nonlocal layer_idx
-
-        # Iterate through module children, convert any attn layers to diff attn
-        for name, child in module.named_children():
-            child_class_name = type(child).__name__
-
-            if child_class_name in [k.__name__ for k in ATTENTION_MAPPING]:
-                # Find matching attention class by name
-                for orig_class, diff_class in ATTENTION_MAPPING.items():
-                    if orig_class.__name__ == child_class_name:
-                        attention_class = diff_class
-                        break
-
-                layer_type = type(child).__name__
-                logger.info(
-                    f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}"
-                )
-
-                # Create new diff attn layer
-                new_attention = attention_class(
-                    config=module.config if hasattr(module, "config") else model.config,
-                    layer_idx=layer_idx,
-                )
-
-                # Copy weights from old attention to new attention
-                new_attention.to(child.q_proj.weight.device)
-                if not split_heads:
-                    copy_attention_weights(child, new_attention, zero_init=zero_init)
-
-                # Replace the layer
-                setattr(module, name, new_attention)
-                layer_idx += 1
-            elif len(list(child.children())) > 0:
-                convert_module(child)
-
-    convert_module(model)
-    logger.info(f"Converted {layer_idx} attention layers to differential attention")
-
-    return model
diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py
index cccb0adebd..5ae5034646 100644
--- a/src/axolotl/integrations/diff_transformer/diff_attn.py
+++ b/src/axolotl/integrations/diff_transformer/diff_attn.py
@@ -56,8 +56,10 @@ def _init_config(self, config: Any, layer_idx: int):
         """Initialize configuration parameters."""
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
+        self.head_dim = config.hidden_size // config.num_attention_heads
         self.base_num_heads = config.num_attention_heads
         self.base_num_kv_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.base_num_heads // self.base_num_kv_heads
         self.layer_idx = layer_idx
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
@@ -66,15 +68,15 @@ def _init_config(self, config: Any, layer_idx: int):
 
         if config.split_heads:
             # Split heads mode - single projections
-            self.head_dim = config.hidden_size // config.num_attention_heads
             # NOTE: This rounds down `base_num_heads / 2` as opposed to the original
             # implementation, which asserts `self.base_num_heads` is even
             self.heads_per_component = self.base_num_heads // 2
+            self.kv_heads_per_component = self.base_num_kv_heads // 2
             self.value_head_dim = 2 * self.head_dim
         else:
             # Double projection mode
-            self.head_dim = config.hidden_size // config.num_attention_heads
             self.heads_per_component = self.base_num_heads
+            self.kv_heads_per_component = self.base_num_kv_heads
             self.value_head_dim = self.head_dim
 
     def _init_projections(self):
@@ -90,14 +92,22 @@ def _init_projections(self):
             self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2
         )
-        self.q_proj = nn.Linear(self.hidden_size, q_out_dim, bias=False)
-        self.k_proj = nn.Linear(self.hidden_size, k_out_dim, bias=False)
+        self.q_proj = nn.Linear(
+            self.hidden_size, q_out_dim, bias=self.config.attention_bias
+        )
+        self.k_proj = nn.Linear(
+            self.hidden_size, k_out_dim, bias=self.config.attention_bias
+        )
         self.v_proj = nn.Linear(
             self.hidden_size,
             self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
-            bias=False,
+            bias=self.config.attention_bias,
+        )
+        self.o_proj = nn.Linear(
+            self.base_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=self.config.attention_bias,
         )
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
 
     def _init_differential_params(self):
         """Initialize differential attention parameters."""
@@ -145,13 +155,13 @@ def _prepare_attention_inputs(self, hidden_states: torch.Tensor):
         q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
             1, 2
         )
-        k1 = k1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
+        k1 = k1.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
             1, 2
         )
-        k2 = k2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
+        k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
             1, 2
         )
-        v = v.view(bsz, q_len, self.heads_per_component, self.value_head_dim).transpose(
+        v = v.view(bsz, q_len, self.base_num_kv_heads, self.value_head_dim).transpose(
             1, 2
         )
 
@@ -184,10 +194,10 @@ def _handle_cache(self, k1, k2, v, past_key_value, cache_kwargs):
             k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
             k1, k2 = k.unbind(dim=1)
 
-        # Repeat KV heads
-        k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
-        k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
-        v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
+        # Repeat KV heads to match number of query heads
+        k1 = repeat_kv(k1, self.num_key_value_groups)
+        k2 = repeat_kv(k2, self.num_key_value_groups)
+        v = repeat_kv(v, self.num_key_value_groups)
 
         return k1, k2, v
diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
index a3d31382da..b84dfcd166 100644
--- a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
+++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
@@ -56,19 +56,16 @@ class LlamaDifferentialModel(LlamaModel):
     def __init__(self, config):
         super().__init__(config)
+
         # Replace standard attention with differential attention in each layer
-        for layer in self.layers:
+        for idx, layer in enumerate(self.layers):
             attn_impl = config._attn_implementation or "eager"
             if attn_impl == "eager":
-                layer.self_attn = LlamaDifferentialAttention(config, layer.layer_idx)
+                layer.self_attn = LlamaDifferentialAttention(config, idx)
             elif attn_impl == "sdpa":
-                layer.self_attn = LlamaDifferentialSdpaAttention(
-                    config, layer.layer_idx
-                )
+                layer.self_attn = LlamaDifferentialSdpaAttention(config, idx)
             elif attn_impl == "flash_attention_2":
-                layer.self_attn = LlamaDifferentialFlashAttention2(
-                    config, layer.layer_idx
-                )
+                layer.self_attn = LlamaDifferentialFlashAttention2(config, idx)
 
     @classmethod
     def from_llama(
@@ -78,7 +75,21 @@ def from_llama(
         if config is None:
             config = LlamaDifferentialConfig(**model.config.__dict__)
 
+        # Validate head counts if using split heads mode
+        if config.split_heads:
+            if config.num_attention_heads % 2 != 0:
+                raise ValueError(
+                    f"Number of attention heads ({config.num_attention_heads}) must be even "
+                    "when using split_heads=True"
+                )
+            if config.num_key_value_heads % 2 != 0:
+                raise ValueError(
+                    f"Number of key/value heads ({config.num_key_value_heads}) must be even "
+                    "when using split_heads=True"
+                )
+
         new_model = cls(config)
+
         # Copy all weights except attention
         new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
         new_model.norm.load_state_dict(model.norm.state_dict())
@@ -97,34 +108,28 @@ def from_llama(
             new_layer.self_attn.v_proj.load_state_dict(
                 old_layer.self_attn.v_proj.state_dict()
             )
+            print(old_layer.self_attn.o_proj.weight.shape)
             new_layer.self_attn.o_proj.load_state_dict(
                 old_layer.self_attn.o_proj.state_dict()
             )
 
-            if config.split_heads:
-                new_layer.self_attn.q_proj.weight.data.copy_(
-                    old_layer.self_attn.q_proj.weight.data
-                )
-                new_layer.self_attn.k_proj.weight.data.copy_(
-                    old_layer.self_attn.k_proj.weight.data
-                )
-            else:
-                new_layer.self_attn.q_proj.weight.data[: config.hidden_size].copy_(
+            # Get the original projection sizes
+            old_q_size = old_layer.self_attn.q_proj.weight.size(0)
+            old_k_size = old_layer.self_attn.k_proj.weight.size(0)
+
+            if not config.split_heads:
+                new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
                     old_layer.self_attn.q_proj.weight.data
                 )
-                new_layer.self_attn.k_proj.weight.data[: config.hidden_size].copy_(
+                new_layer.self_attn.k_proj.weight.data[:old_k_size].copy_(
                     old_layer.self_attn.k_proj.weight.data
                 )
 
             if config.zero_init:
                 # Zero out components as needed
                 with torch.no_grad():
-                    new_layer.self_attn.q_proj.weight.data[
-                        config.hidden_size :
-                    ].zero_()
-                    new_layer.self_attn.k_proj.weight.data[
-                        config.hidden_size :
-                    ].zero_()
+                    new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
+                    new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
                     new_layer.self_attn.lambda_q1.zero_()
                     new_layer.self_attn.lambda_k1.zero_()
                     new_layer.self_attn.lambda_q2.zero_()
@@ -149,7 +154,21 @@ def from_llama(
         if config is None:
             config = LlamaDifferentialConfig(**model.config.__dict__)
 
+        # Validate head counts if using split heads mode
+        if config.split_heads:
+            if config.num_attention_heads % 2 != 0:
+                raise ValueError(
+                    f"Number of attention heads ({config.num_attention_heads}) must be even "
+                    "when using split_heads=True"
+                )
+            if config.num_key_value_heads % 2 != 0:
+                raise ValueError(
+                    f"Number of key/value heads ({config.num_key_value_heads}) must be even "
+                    "when using split_heads=True"
+                )
+
         new_model = cls(config)
         new_model.model = LlamaDifferentialModel.from_llama(model.model, config)
         new_model.lm_head.load_state_dict(model.lm_head.state_dict())
+
         return new_model
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 37cbc08713..2c4d2513d3 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -710,11 +710,30 @@ def set_attention_config(self) -> None:
         """
        sample packing uses custom FA2 patch
         """
+        # if self.cfg.flash_attention:
+        #     if not self.cfg.sample_packing and self.cfg.s2_attention:
+        #         pass
+
+        #     self.model_kwargs["attn_implementation"] = "flash_attention_2"
+        #     self.model_config._attn_implementation = (  # pylint: disable=protected-access
+        #         "flash_attention_2"
+        #     )
+        # elif self.cfg.sdp_attention:
+        #     self.model_kwargs["attn_implementation"] = "sdpa"
+        #     self.model_config._attn_implementation = (  # pylint: disable=protected-access
+        #         "sdpa"
+        #     )
+        # elif self.cfg.eager_attention:
+        #     self.model_kwargs["attn_implementation"] = "eager"
+        #     self.model_config._attn_implementation = (  # pylint: disable=protected-access
+        #         "eager"
+        #     )
+
         if self.cfg.flash_attention:
             if not self.cfg.sample_packing and self.cfg.s2_attention:
                 pass
-            if self.cfg.differentiaion:
+            if self.cfg.diff_attention:
                 self.model_kwargs[
                     "attn_implementation"
                 ] = "differential_flash_attention_2"
diff --git a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py
index 02939ee1ca..e616a8ef12 100644
--- a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py
+++ b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py
@@ -15,135 +15,133 @@
 from axolotl.common.cli import ConvertDiffTransformerCliArgs
 
 
-@pytest.mark.usefixtures("base_config", "cli_runner")
-class TestDiffTransformer:
-    """Tests for convert-diff-transformer CLI command"""
-
-    def test_cli_validation(self, cli_runner):
-        # Test missing config file
-        result = cli_runner.invoke(cli, ["convert-diff-transformer"])
-        assert result.exit_code != 0
-        assert "Error: Missing argument 'CONFIG'." in result.output
-
-        # Test non-existent config file
-        result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
-        assert result.exit_code != 0
-        assert "Error: Invalid value for 'CONFIG'" in result.output
-
-    def test_basic_execution(self, cli_runner, tmp_path: Path, base_config):
-        config_path = tmp_path / "config.yml"
-        with open(config_path, "w", encoding="utf-8") as file:
-            yaml.dump(base_config, file)
-
-        with patch(
-            "axolotl.cli.integrations.convert_diff_transformer.do_cli"
-        ) as mock_do_cli:
-            result = cli_runner.invoke(
-                cli, ["convert-diff-transformer", str(config_path)]
-            )
-            assert result.exit_code == 0
-
-            mock_do_cli.assert_called_once()
-            assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
-
-    def test_conversion_cli_basic(self, tmp_path: Path, base_config):
-        output_dir = tmp_path / "converted"
-        base_config["output_dir"] = str(output_dir)
-
-        config_path = tmp_path / "config.yml"
-        with open(config_path, "w", encoding="utf-8") as file:
-            yaml.dump(base_config, file)
-
-        cfg = load_cfg(str(config_path))
-        cli_args = ConvertDiffTransformerCliArgs()
-        _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
-
-        assert not debug_info
-        assert (output_dir / "model.safetensors").exists()
-        assert (output_dir / "config.json").exists()
-        assert (output_dir / "axolotl_config.yml").exists()
-
-    def test_conversion_cli_debug(self, tmp_path: Path, base_config):
-        output_dir = tmp_path / "converted"
-        base_config["output_dir"] = str(output_dir)
-
-        config_path = tmp_path / "config.yml"
-        with open(config_path, "w", encoding="utf-8") as file:
-            yaml.dump(base_config, file)
-
-        cfg = load_cfg(str(config_path))
-        cli_args = ConvertDiffTransformerCliArgs(debug=True)
-        _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
-
-        assert not debug_info["generations_match"]
-        assert not debug_info["match_expected"]
-        assert (output_dir / "model.safetensors").exists()
-        assert (output_dir / "config.json").exists()
-        assert (output_dir / "axolotl_config.yml").exists()
-
-    def test_conversion_cli_reproduce(self, tmp_path: Path, base_config):
-        output_dir = tmp_path / "converted"
-        base_config["output_dir"] = str(output_dir)
-
-        config_path = tmp_path / "config.yml"
-        with open(config_path, "w", encoding="utf-8") as file:
-            yaml.dump(base_config, file)
-
-        cfg = load_cfg(str(config_path))
-        cli_args = ConvertDiffTransformerCliArgs(
-            debug=True, zero_init=True, sublayer_norm=False
-        )
-        _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
-
-        assert debug_info["generations_match"] is True
-        assert (output_dir / "model.safetensors").exists()
-        assert (output_dir / "config.json").exists()
-        assert (output_dir / "axolotl_config.yml").exists()
-
-    @pytest.mark.parametrize(
-        "attention", ["eager_attention", "sdp_attention", "flash_attention"]
+def test_cli_validation(cli_runner):
+    # Test missing config file
+    result = cli_runner.invoke(cli, ["convert-diff-transformer"])
+    assert result.exit_code != 0
+    assert "Error: Missing argument 'CONFIG'." in result.output
+
+    # Test non-existent config file
+    result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
+    assert result.exit_code != 0
+    assert "Error: Invalid value for 'CONFIG'" in result.output
+
+
+def test_basic_execution(cli_runner, tmp_path: Path, base_config):
+    config_path = tmp_path / "config.yml"
+    with open(config_path, "w", encoding="utf-8") as file:
+        yaml.dump(base_config, file)
+
+    with patch(
+        "axolotl.cli.integrations.convert_diff_transformer.do_cli"
+    ) as mock_do_cli:
+        result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)])
+        assert result.exit_code == 0
+
+        mock_do_cli.assert_called_once()
+        assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
+
+
+def test_conversion_cli_basic(tmp_path: Path, base_config):
+    output_dir = tmp_path / "converted"
+    base_config["output_dir"] = str(output_dir)
+
+    config_path = tmp_path / "config.yml"
+    with open(config_path, "w", encoding="utf-8") as file:
+        yaml.dump(base_config, file)
+
+    cfg = load_cfg(str(config_path))
+    cli_args = ConvertDiffTransformerCliArgs()
+    _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
+
+    assert not debug_info
+    assert (output_dir / "model.safetensors").exists()
+    assert (output_dir / "config.json").exists()
+    assert (output_dir / "axolotl_config.yml").exists()
+
+
+def test_conversion_cli_debug(tmp_path: Path, base_config):
+    output_dir = tmp_path / "converted"
+    base_config["output_dir"] = str(output_dir)
+
+    config_path = tmp_path / "config.yml"
+    with open(config_path, "w", encoding="utf-8") as file:
+        yaml.dump(base_config, file)
+
+    cfg = load_cfg(str(config_path))
+    cli_args = ConvertDiffTransformerCliArgs(debug=True)
+    _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
+
+    assert not debug_info["generations_match"]
+    assert not debug_info["match_expected"]
+    assert (output_dir / "model.safetensors").exists()
+    assert (output_dir / "config.json").exists()
+    assert (output_dir / "axolotl_config.yml").exists()
+
+
+def test_conversion_cli_reproduce(tmp_path: Path, base_config):
+    output_dir = tmp_path / "converted"
+    base_config["output_dir"] = str(output_dir)
+
+    config_path = tmp_path / "config.yml"
+    with open(config_path, "w", encoding="utf-8") as file:
+        yaml.dump(base_config, file)
+
+    cfg = load_cfg(str(config_path))
+    cli_args = ConvertDiffTransformerCliArgs(
+        debug=True, zero_init=True, sublayer_norm=False
     )
-    def test_conversion_cli_repoduce_attentions(
-        self, tmp_path: Path, base_config, attention: Optional[str]
-    ):
-        output_dir = tmp_path / "converted"
-        base_config["output_dir"] = str(output_dir)
-        base_config[attention] = True
-
-        config_path = tmp_path / "config.yml"
-        with open(config_path, "w", encoding="utf-8") as file:
-            yaml.dump(base_config, file)
-
-        cfg = load_cfg(str(config_path))
-        cli_args = ConvertDiffTransformerCliArgs(
-            debug=True, zero_init=True, sublayer_norm=False
-        )
-        _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
-
-        assert debug_info["generations_match"] is True
-        assert (output_dir / "model.safetensors").exists()
-        assert (output_dir / "config.json").exists()
-        assert (output_dir / "axolotl_config.yml").exists()
-
-    @pytest.mark.parametrize(
-        "attention", ["eager_attention", "sdp_attention", "flash_attention"]
+    _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
+
+    assert debug_info["generations_match"] is True
+    assert (output_dir / "model.safetensors").exists()
+    assert (output_dir / "config.json").exists()
+    assert (output_dir / "axolotl_config.yml").exists()
+
+
+@pytest.mark.parametrize(
+    "attention", ["eager_attention", "sdp_attention", "flash_attention"]
+)
+def test_conversion_cli_reproduce_attentions(
+    tmp_path: Path, base_config, attention: Optional[str]
+):
+    output_dir = tmp_path / "converted"
+    base_config["output_dir"] = str(output_dir)
+    base_config[attention] = True
+
+    config_path = tmp_path / "config.yml"
+    with open(config_path, "w", encoding="utf-8") as file:
+        yaml.dump(base_config, file)
+
+    cfg = load_cfg(str(config_path))
+    cli_args = ConvertDiffTransformerCliArgs(
+        debug=True, zero_init=True, sublayer_norm=False
    )
-    def test_conversion_cli_split_heads(
-        self, tmp_path: Path, base_config, attention: str
-    ):
-        output_dir = tmp_path / "converted"
-        base_config["output_dir"] = str(output_dir)
-        base_config[attention] = True
-
-        config_path = tmp_path / "config.yml"
-        with open(config_path, "w", encoding="utf-8") as file:
-            yaml.dump(base_config, file)
-
-        cfg = load_cfg(str(config_path))
-        cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
-        _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
-
-        assert debug_info["generations_match"] is False
-        assert (output_dir / "model.safetensors").exists()
-        assert (output_dir / "config.json").exists()
-        assert (output_dir / "axolotl_config.yml").exists()
+    _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
+
+    assert debug_info["generations_match"] is True
+    assert (output_dir / "model.safetensors").exists()
+    assert (output_dir / "config.json").exists()
+    assert (output_dir / "axolotl_config.yml").exists()
+
+
+@pytest.mark.parametrize(
+    "attention", ["eager_attention", "sdp_attention", "flash_attention"]
+)
+def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
+    output_dir = tmp_path / "converted"
+    base_config["output_dir"] = str(output_dir)
+    base_config[attention] = True
+
+    config_path = tmp_path / "config.yml"
+    with open(config_path, "w", encoding="utf-8") as file:
+        yaml.dump(base_config, file)
+
+    cfg = load_cfg(str(config_path))
+    cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
+    _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
+
+    assert debug_info["generations_match"] is False
+    assert (output_dir / "model.safetensors").exists()
+    assert (output_dir / "config.json").exists()
+    assert (output_dir / "axolotl_config.yml").exists()