Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

convert-diff-transformer CLI command / codepath #2197

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from

Conversation

djsaunde
Copy link
Contributor

@djsaunde djsaunde commented Dec 17, 2024

Description

This PR implements the differential attention layer from the Differential Transformer paper.

Motivation and Context

We wanted to add this attention implementation to axolotl so users can swap out the existing attention layers in their models for this more performant version. We matched the official implementation details as closely as possible, while adopting it to play nicely with the transformers attention implementations.

Since we were focused on being able to convert existing LLMs to having these differential attention layers, we wanted a way to not degrade the performance of the (possibly pre-trained) LLM while doing so.

To this end, the conversion process doubles the dimensionality of the query and key projections (since the differential attention requires both a positive and negative component of the attention) and (optionally; pass --zero-init) initializes the weights of the negative component to zero, while copying over the weights from the original attention modules to the positive components.

When doing this, the converted network computes the same function as the original (pass --debug to confirm this), but may suffer from a vanishing gradient problem. The default behavior is thus to initialize the weights of the negative components of the differential attention layers to 0-centered normally distributed values with a small variance.

Relevant links:

How has this been tested?

SmolLM2-135m on A40 Runpod instance on this feature branch. Workflow was:

  • Convert the model to use either eager or SDPA differential attention
    • With and without --zero-init and --debug flags for sanity checking exact model conversion (completions, logits, losses)
  • Run new axolotl evaluate command on the small mhenrichsen/alpaca_2k_test dataset with both the original and converted model and check that their evaluation metrics match

For example:

$ axolotl convert-diff-transformer ../configs/smollm.yaml --output-dir ../converted-model --zero-init --debug
...
[2024-12-17 05:15:26,910] [INFO] [axolotl.cli.convert_attention.convert_diff_transformer:75] [PID:94590] [RANK:0] Converting 
to differential attention...                                                                                                 
[2024-12-17 05:15:26,910] [INFO] [axolotl.integrations.diff_transformer.convert.convert_module:97] [PID:94590] [RANK:0] Conve
rting attention layer 0: LlamaSdpaAttention to LlamaDifferentialSdpaAttention                                                
[2024-12-17 05:15:26,921] [DEBUG] [axolotl.integrations.diff_transformer.convert.copy_attention_weights:64] [PID:94590] [RANK
:0] Copied positive attention weights from LlamaSdpaAttention to LlamaDifferentialSdpaAttention                              
[2024-12-17 05:15:26,921] [INFO] [axolotl.integrations.diff_transformer.convert.convert_module:97] [PID:94590] [RANK:0] Conve
rting attention layer 1: LlamaSdpaAttention to LlamaDifferentialSdpaAttention                                                
[2024-12-17 05:15:26,930] [DEBUG] [axolotl.integrations.diff_transformer.convert.copy_attention_weights:64] [PID:94590] [RANK
:0] Copied positive attention weights from LlamaSdpaAttention to LlamaDifferentialSdpaAttention
...
ANK:0] Converted 30 attention layers to differential attention
[2024-12-17 05:15:27,181] [INFO] [axolotl.cli.convert_attention.convert_diff_transformer:85] [PID:94590] [RANK:0] Testing con
verted model...
[2024-12-17 05:15:27,785] [INFO] [axolotl.cli.convert_attention.test_inference:43] [PID:94590] [RANK:0] Prompt: The quick brown fox                                                                                                                       
[2024-12-17 05:15:28,280] [INFO] [axolotl.cli.convert_attention.convert_diff_transformer:121] [PID:94590] [RANK:0] Generations match!
Model generation:
**************************************************
The quick brown fox jumps over the lazy dog

The quick brown fox jumps over the lazy dog.

The
**************************************************

Types of changes

  • axolotl.integrations.diff_transformer module, which implements the differential attention layers for the Llama LLM architecture and for various attention implementations (eager, SDPA, Flash Attention 2), and
  • axolotl.cli.integrations.convert_diff_transformer module (and updates to axolotl.cli.main), which implements the convert-diff-transformer CLI command, and
  • Monkeypatch in axolotl.cli.integrations.convert_diff_transformer.patches (to be moved) for updating LLAMA_ATTENTION_CLASSES constant in transformers.models.llama.modeling_llama.

TODO

  • Test coverage
  • Add Flash Attention 2 implementation
  • Move monkey patch
  • Refactor conversion module as plugin
  • Add conversion with same-sized Q, K projections
  • Experiments to demonstrate value
    • Blog post

@djsaunde djsaunde self-assigned this Dec 17, 2024
outputs Outdated Show resolved Hide resolved
@djsaunde djsaunde force-pushed the diff-transformer branch 2 times, most recently from f2c37e7 to 2717b97 Compare December 20, 2024 20:41
def dump_yaml_preserved_order(
data: Dict, reference_yaml_path: str, output_path: str
) -> None:
"""Dump YAML file while preserving nested order and normalized spacing."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔥

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could similarly have a function to normalize any config yaml file to have some expected ordering / formatting.

@ehartford
Copy link
Collaborator

I thought differential transformer requires model architecture change and modeling code change? Does this somehow automatically implement a modeling.py for the model?

@djsaunde
Copy link
Contributor Author

djsaunde commented Dec 21, 2024

I thought differential transformer requires model architecture change and modeling code change? Does this somehow automatically implement a modeling.py for the model?

Good question. I've implemented a monkeypatch in src/axolotl/monkeypatch/attention/differential.py that updates the PreTrainedModel._autoset_attn_implementation function to be aware of the differential attention implementation. I think it's a bit of a hack, though, so using custom modeling code might be a good change before merge. Happy to hear your thoughts / feedback!

As for the architecture change, we have src/axolotl/cli/integrations/convert_diff_transformer.py which does the actually swapping of (llama only, for now) attention layers with differential attention in the model.

@ehartford
Copy link
Collaborator

monkey patch only works in the context of Axolotl - we will need a modeling.py to make inference work properly in the wild (transformers, TGI, vllm, etc) right? (If I understand correctly)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants