convert-diff-transformer CLI command / codepath #2197

Draft: wants to merge 25 commits into base: main

Commits (25)
5b4d027
Basic evaluate CLI command / codepath (#2188)
djsaunde Dec 16, 2024
1e49a88
initial diff attn layer / model conversion implementation (support fo…
djsaunde Dec 11, 2024
8c4ff51
Adding script for doing conversion; fixes and updates
djsaunde Dec 12, 2024
8264c62
adding CLI command for convert-diff-transformer
djsaunde Dec 12, 2024
4bdbb2f
training fixes, patching, minor cleanup
djsaunde Dec 13, 2024
60a1668
various improvements
djsaunde Dec 13, 2024
32f1b3f
various improvements
djsaunde Dec 13, 2024
c1968ed
fix model save / load logic
djsaunde Dec 17, 2024
dbeea75
pre-commit fix
djsaunde Dec 17, 2024
81d9ff4
moving monkeypatch
djsaunde Dec 17, 2024
c74a290
differential flash attention 2; cleanup
djsaunde Dec 17, 2024
12d14cc
duplicate code ignore
djsaunde Dec 17, 2024
a6b5a5e
convert-differential-transformer test coverage
djsaunde Dec 17, 2024
6a9af88
plugin implementation
djsaunde Dec 18, 2024
b7294d4
fixes post-rebase
djsaunde Dec 18, 2024
c57d21e
isolating problematic test
djsaunde Dec 18, 2024
513b262
adding split_heads argument for retaining original (Q, K) dimensionan…
djsaunde Dec 18, 2024
313265f
moving tests around for flash_attn install
djsaunde Dec 18, 2024
53b4d80
removing extra pytest xdist args
djsaunde Dec 19, 2024
9262124
adding yaml dumper preserving input config format
djsaunde Dec 20, 2024
a1a3f1d
refactor and fixing test isolation issues
djsaunde Dec 21, 2024
c6def27
added modeling code; cleanup + refactor
Dec 23, 2024
7d9ec2c
fix duplicate-code warnings
Dec 23, 2024
44e4b83
updated custom modeling code
Dec 24, 2024
6945bdd
progress on modeling code
djsaunde Dec 24, 2024
3 changes: 3 additions & 0 deletions .gitignore
@@ -186,3 +186,6 @@ out/

# vim
*.swp

# symlinked to axolotl-artifacts in docker containers
outputs
1 change: 0 additions & 1 deletion cicd/cicd.sh
@@ -4,7 +4,6 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"

pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
4 changes: 2 additions & 2 deletions cicd/multigpu.py
@@ -1,6 +1,6 @@
"""
modal application to run axolotl gpu tests in Modal
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code

import os
6 changes: 3 additions & 3 deletions src/axolotl/cli/evaluate.py
@@ -3,7 +3,7 @@
"""
import logging
from pathlib import Path
from typing import Union
from typing import Dict, Union

import fire
from dotenv import load_dotenv
@@ -23,7 +23,7 @@
LOG = logging.getLogger("axolotl.cli.evaluate")


def do_evaluate(cfg, cli_args) -> None:
def do_evaluate(cfg, cli_args) -> Dict[str, float]:
# pylint: disable=duplicate-code
print_axolotl_text_art()
check_accelerate_default_config()
@@ -34,7 +34,7 @@ def do_evaluate(cfg, cli_args) -> None:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
return evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
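With this change, do_evaluate returns the computed metrics instead of None, so callers can consume them directly. A minimal sketch of programmatic use, assuming EvaluateCliArgs can be constructed with its defaults; the config path is illustrative:

    from axolotl.cli import load_cfg
    from axolotl.cli.evaluate import do_evaluate
    from axolotl.common.cli import EvaluateCliArgs

    cfg = load_cfg("path/to/config.yml")  # illustrative config path
    metrics = do_evaluate(cfg, EvaluateCliArgs())  # Dict[str, float] of evaluation metrics
    print(metrics)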
Empty file.
208 changes: 208 additions & 0 deletions src/axolotl/cli/integrations/convert_diff_transformer.py
@@ -0,0 +1,208 @@
"""CLI to convert a transformers model's attention layers to differential attention layers."""

import logging
import warnings
from pathlib import Path
from time import time
from typing import Union

import fire
import torch
import yaml
from colorama import Fore
from dotenv import load_dotenv
from transformers import HfArgumentParser

from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
LlamaDifferentialConfig,
LlamaDifferentialForCausalLM,
)
from axolotl.utils.yaml import dump_yaml_preserved_order

LOG = logging.getLogger(__name__)


def test_inference(model, tokenizer, prompt="The quick brown fox"):
"""Run test inference and return generation time"""
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()}

start = time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=20,
num_beams=1,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
)
elapsed = time() - start

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
LOG.info("Prompt: %s", prompt)
LOG.info("Generated: %s", generated_text)
LOG.info("Generation time: %.2fs", elapsed)

return elapsed, generated_text


def convert_diff_transformer(cfg, cli_args, config_path):
debug_info = {}

# Load model and tokenizer
with warnings.catch_warnings():
warnings.simplefilter("ignore")
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
model.to(cfg.device, dtype=cfg.torch_dtype)

# Log original model info
LOG.info(
"Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
model.config.hidden_size,
model.config.num_attention_heads,
)

# Test original model
if cli_args.debug:
LOG.info("Testing original model...")
debug_info["orig_time"], debug_info["orig_text"] = test_inference(
model, tokenizer
)

# Convert attention
LOG.info("Converting to differential attention...")
if cli_args.split_heads and cli_args.zero_init:
LOG.warning(
Fore.YELLOW
+ "Warning: Using split_heads with zero_init is not recommended; "
+ "split_heads will preclude the effects of zero_init"
+ Fore.RESET
)
try:
model = LlamaDifferentialForCausalLM.from_llama(
model,
LlamaDifferentialConfig(
**model.config.__dict__,
zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm,
split_heads=cli_args.split_heads,
),
)
model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
raise

# Test converted model
if cli_args.debug:
LOG.info("Testing converted model...")
debug_info["conv_time"], debug_info["conv_text"] = test_inference(
model, tokenizer
)

# Save if requested
if cfg.output_dir:
# Save model and tokenizer
LOG.info("Saving converted model to %s", cfg.output_dir)
model.save_pretrained(cfg.output_dir)
tokenizer.save_pretrained(cfg.output_dir)

# Modify config to reflect new path / differential attention
output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
LOG.info("Saving updated config to %s", output_config_path)

with open(config_path, "r", encoding="utf-8") as file:
modified_cfg = yaml.safe_load(file) or {}

modified_cfg["base_model"] = cfg.output_dir
modified_cfg["diff_attention"] = True
plugin_class = (
"axolotl.integrations.diff_transformer.DifferentialTransformerPlugin"
)
if "plugins" in modified_cfg:
modified_cfg["plugins"].append(plugin_class)
else:
modified_cfg["plugins"] = [plugin_class]

# Write out the updated axolotl config while preserving original ordering / formatting
dump_yaml_preserved_order(
data=modified_cfg,
reference_yaml_path=config_path,
output_path=output_config_path,
)
else:
LOG.info("Not saving converted model to disk")
LOG.info("Pass --output-dir path/to/save to save model")

if cli_args.debug:
LOG.info(
Fore.GREEN
+ "Conversion successful!\n"
+ f"Original generation time: {debug_info['orig_time']:.2f}s\n"
+ f"Converted generation time: {debug_info['conv_time']:.2f}s"
+ Fore.RESET
)

if debug_info["orig_text"] == debug_info["conv_text"]:
LOG.info(
Fore.GREEN
+ "Generations match!\n"
+ "Model generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['orig_text']}\n"
+ "*" * 50
+ "\n"
+ Fore.RESET
)
debug_info["generations_match"] = True
else:
message = (
"Generations do not match.\n"
+ "Original generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['orig_text']}\n"
+ "*" * 50
+ "\n"
+ "Converted generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['conv_text']}\n"
+ "*" * 50
+ "\n"
)
debug_info["generations_match"] = False

if cli_args.zero_init and not cli_args.sublayer_norm:
LOG.info(Fore.RED + message + Fore.RESET)
debug_info["match_expected"] = True
else:
LOG.info(
Fore.YELLOW
+ message
+ "However, this is expected since --zero-init"
+ " and --no-sublayer-norm were not passed."
+ Fore.RESET
)
debug_info["match_expected"] = False

return model, debug_info


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
print_axolotl_text_art()

cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)

convert_diff_transformer(cfg, cli_args, config)


if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)
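Beyond the CLI entry point, the conversion codepath can also be driven programmatically. A minimal sketch, assuming the ConvertDiffTransformerCliArgs fields carry defaults and that load_cfg accepts config overrides as keyword arguments (as the CLI codepath does); the config path is illustrative:

    from axolotl.cli import load_cfg
    from axolotl.cli.integrations.convert_diff_transformer import (
        convert_diff_transformer,
    )
    from axolotl.common.cli import ConvertDiffTransformerCliArgs

    config_path = "path/to/config.yml"  # illustrative; an axolotl config for a Llama base model
    cfg = load_cfg(config_path, output_dir="./converted-model")

    # Mirrors the CLI flags: debug runs test generations before and after conversion;
    # zero_init=True with sublayer_norm=False is the setting under which the converted
    # model is expected to reproduce the original generations.
    cli_args = ConvertDiffTransformerCliArgs(
        debug=True, zero_init=True, sublayer_norm=False
    )

    model, debug_info = convert_diff_transformer(cfg, cli_args, config_path)
    print(debug_info["generations_match"])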
23 changes: 22 additions & 1 deletion src/axolotl/cli/main.py
@@ -12,7 +12,12 @@
build_command,
fetch_from_github,
)
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.common.cli import (
ConvertDiffTransformerCliArgs,
EvaluateCliArgs,
PreprocessCliArgs,
TrainerCliArgs,
)
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig

@@ -77,6 +82,9 @@ def evaluate(config: str, accelerate: bool, **kwargs):
"""Evaluate a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}

# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()

if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
if config:
@@ -240,6 +248,19 @@ def merge_lora(
do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}

from axolotl.cli.integrations.convert_diff_transformer import do_cli

do_cli(config=config, **kwargs)


@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
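With the command registered on the click group, the dataclass booleans surface as paired flags (e.g. --zero-init/--no-zero-init). A quick smoke test using click's test runner, assuming click's default command naming (underscores become dashes, matching the PR title) and an illustrative config path:

    from click.testing import CliRunner

    from axolotl.cli.main import cli

    runner = CliRunner()
    result = runner.invoke(
        cli,
        [
            "convert-diff-transformer",
            "path/to/config.yml",  # illustrative; must exist (click.Path(exists=True))
            "--debug",
            "--zero-init",
            "--no-sublayer-norm",
        ],
    )
    print(result.exit_code)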
12 changes: 10 additions & 2 deletions src/axolotl/cli/utils.py
@@ -22,7 +22,6 @@ def decorator(function):
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = field.type

if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
@@ -44,6 +43,7 @@ def decorator(function):
default=field.default,
help=field.metadata.get("description"),
)(function)

return function

return decorator
@@ -55,7 +55,14 @@ def add_options_from_config(config_class: Type[BaseModel]):
def decorator(function):
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool:
field_type = field.annotation
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)

# NOTE: defaults are handled by the pydantic model config classes.
if field_type == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
@@ -66,6 +73,7 @@ def decorator(function):
function = click.option(
option_name, default=None, help=field.description
)(function)

return function

return decorator
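For context, the new unwrapping step lets Optional[bool] fields on the pydantic config model fall through to the same --flag/--no-flag handling as plain bool fields. A standalone illustration of the unwrapping (assuming Python 3.10+ for types.NoneType):

    from types import NoneType
    from typing import Optional, Union, get_args, get_origin

    field_type = Optional[bool]  # e.g. a pydantic field annotated Optional[bool]

    # Optional[bool] is Union[bool, None]; strip NoneType to recover the concrete type.
    if get_origin(field_type) is Union and type(None) in get_args(field_type):
        field_type = next(t for t in get_args(field_type) if t is not NoneType)

    # A bool field is then rendered as a paired flag, e.g. --debug/--no-debug.
    assert field_type is bool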