diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py index d91278fed..360832dbb 100644 --- a/src/axolotl/cli/integrations/convert_diff_transformer.py +++ b/src/axolotl/cli/integrations/convert_diff_transformer.py @@ -1,4 +1,5 @@ -"""CLI to convert a transformers model's attns to diff attns.""" +"""CLI to convert a transformers model's attention layers to differential attention layers.""" + import logging import warnings from pathlib import Path @@ -127,6 +128,7 @@ def convert_diff_transformer(cfg, cli_args, config_path): else: modified_cfg["plugins"] = [plugin_class] + # Write out the updated axolotl config while preserving original ordering / formatting dump_yaml_preserved_order( data=modified_cfg, reference_yaml_path=config_path, diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index ea3b91c0c..ebe098ca6 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -12,14 +12,12 @@ from axolotl.utils.models import load_model, load_tokenizer configure_logging() -LOG = logging.getLogger("axolotl.common.cli") +LOG = logging.getLogger(__name__) @dataclass class PreprocessCliArgs: - """ - dataclass with arguments for preprocessing only - """ + """dataclass with arguments for preprocessing only""" debug: bool = field(default=False) debug_text_only: bool = field(default=False) @@ -30,9 +28,7 @@ class PreprocessCliArgs: @dataclass class TrainerCliArgs: - """ - dataclass with various non-training arguments - """ + """dataclass with various non-training arguments""" debug: bool = field(default=False) debug_text_only: bool = field(default=False) @@ -45,9 +41,7 @@ class TrainerCliArgs: @dataclass class EvaluateCliArgs: - """ - dataclass with various evaluation arguments - """ + """dataclass with various evaluation arguments""" debug: bool = field(default=False) debug_text_only: bool = field(default=False) @@ -56,9 +50,7 @@ class EvaluateCliArgs: @dataclass class 
ConvertDiffTransformerCliArgs: - """ - dataclass with arguments for convert-diff-transformer CLI - """ + """dataclass with arguments for convert-diff-transformer CLI""" debug: bool = field(default=False) zero_init: bool = field(default=False) diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py index 5c10f2137..d942567d5 100644 --- a/src/axolotl/integrations/diff_transformer/convert.py +++ b/src/axolotl/integrations/diff_transformer/convert.py @@ -98,9 +98,13 @@ def convert_module(module): # Iterate through module children, convert any attn layers to diff attn for name, child in module.named_children(): - if isinstance(child, tuple(ATTENTION_MAPPING.keys())): - # Choose appropriate differential attention class - attention_class = ATTENTION_MAPPING[type(child)] + child_class_name = type(child).__name__ + if child_class_name in [k.__name__ for k in ATTENTION_MAPPING]: + # Find matching attention class by name + for orig_class, diff_class in ATTENTION_MAPPING.items(): + if orig_class.__name__ == child_class_name: + attention_class = diff_class + break layer_type = type(child).__name__ logger.info( diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py index edf532c41..a8d7536dd 100644 --- a/src/axolotl/integrations/diff_transformer/diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/diff_attn.py @@ -21,7 +21,6 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" batch_size, n_kv_heads, slen, head_dim = x.shape if n_rep == 1: return x @@ -249,6 +248,7 @@ def forward( class LlamaDifferentialSdpaAttention(DifferentialAttentionBase): """SDPA-based implementation of differential attention.""" + # pylint: disable=duplicate-code def forward( self, hidden_states: torch.Tensor, @@ -312,6 +312,7 @@ def forward( class 
LlamaDifferentialFlashAttention2(DifferentialAttentionBase): """Flash Attention 2-based implementation of differential attention.""" + # pylint: disable=duplicate-code def forward( self, hidden_states: torch.Tensor, diff --git a/src/axolotl/utils/yaml.py b/src/axolotl/utils/yaml.py index 107afafcf..c5c9e74ae 100644 --- a/src/axolotl/utils/yaml.py +++ b/src/axolotl/utils/yaml.py @@ -84,6 +84,11 @@ class OrderedDumper(yaml.SafeDumper): """Custom YAML dumper that maintains dictionary order.""" +def represent_none(self, _): + """Represent None values as empty fields.""" + return self.represent_scalar("tag:yaml.org,2002:null", "") + + def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any: """Custom representer for dictionaries that maintains order.""" return dumper.represent_mapping("tag:yaml.org,2002:map", data.items()) @@ -121,7 +126,8 @@ def dump_yaml_preserved_order( # Reorder the data ordered_data = reorder_dict(data, tracker.structure) - # Register the custom representer + # Register the custom representers + OrderedDumper.add_representer(type(None), represent_none) OrderedDumper.add_representer(dict, ordered_dict_representer) OrderedDumper.add_representer(OrderedDict, ordered_dict_representer) diff --git a/tests/e2e/integrations/convert_diff_transformer/conftest.py b/tests/e2e/integrations/convert_diff_transformer/conftest.py index d4ffeb759..3964df052 100644 --- a/tests/e2e/integrations/convert_diff_transformer/conftest.py +++ b/tests/e2e/integrations/convert_diff_transformer/conftest.py @@ -4,7 +4,7 @@ from click.testing import CliRunner -@pytest.fixture() +@pytest.fixture(scope="class") def base_config(): """Basic config for testing.""" return { @@ -26,6 +26,6 @@ def base_config(): } -@pytest.fixture +@pytest.fixture(scope="class") def cli_runner(): return CliRunner() diff --git a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py 
@pytest.mark.usefixtures("base_config", "cli_runner")
class TestDiffTransformer:
    """Tests for the convert-diff-transformer CLI command.

    The ``base_config`` fixture is class-scoped in this suite's conftest, so
    every test in this class receives the *same* dict object. Tests therefore
    never mutate it: per-test settings (``output_dir``, the attention flag) are
    layered onto a fresh copy before the config is written to disk. Mutating
    the shared dict would let e.g. ``eager_attention`` leak into the
    ``sdp_attention`` and ``flash_attention`` parametrizations.
    """

    # Artifacts every conversion run is expected to leave in its output dir.
    _EXPECTED_OUTPUTS = ("model.safetensors", "config.json", "axolotl_config.yml")

    @staticmethod
    def _write_config(tmp_path: Path, config) -> Path:
        """Dump ``config`` as YAML to ``tmp_path / "config.yml"``; return the path."""
        config_path = tmp_path / "config.yml"
        with open(config_path, "w", encoding="utf-8") as file:
            yaml.dump(config, file)
        return config_path

    def _convert(self, tmp_path: Path, base_config, cli_args, **overrides):
        """Run ``convert_diff_transformer`` on a per-test copy of ``base_config``.

        Args:
            tmp_path: pytest's per-test temporary directory.
            base_config: the shared fixture dict; it is copied, never modified.
            cli_args: ``ConvertDiffTransformerCliArgs`` for this run.
            **overrides: extra config keys for this test only.

        Returns:
            ``(debug_info, output_dir)`` where ``output_dir`` is
            ``tmp_path / "converted"``.
        """
        output_dir = tmp_path / "converted"
        # Build a new dict so the shared fixture stays pristine for other tests.
        config = {**base_config, **overrides, "output_dir": str(output_dir)}
        config_path = self._write_config(tmp_path, config)

        cfg = load_cfg(str(config_path))
        _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
        return debug_info, output_dir

    def _assert_outputs_exist(self, output_dir: Path):
        """Assert that every expected conversion artifact exists in ``output_dir``."""
        for filename in self._EXPECTED_OUTPUTS:
            assert (output_dir / filename).exists()

    def test_cli_validation(self, cli_runner):
        # Test missing config file
        result = cli_runner.invoke(cli, ["convert-diff-transformer"])
        assert result.exit_code != 0
        assert "Error: Missing argument 'CONFIG'." in result.output

        # Test non-existent config file
        result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
        assert result.exit_code != 0
        assert "Error: Invalid value for 'CONFIG'" in result.output

    def test_basic_execution(self, cli_runner, tmp_path: Path, base_config):
        config_path = self._write_config(tmp_path, base_config)

        # do_cli is mocked: only the CLI wiring is under test here.
        with patch(
            "axolotl.cli.integrations.convert_diff_transformer.do_cli"
        ) as mock_do_cli:
            result = cli_runner.invoke(
                cli, ["convert-diff-transformer", str(config_path)]
            )
            assert result.exit_code == 0

        mock_do_cli.assert_called_once()
        assert mock_do_cli.call_args.kwargs["config"] == str(config_path)

    def test_conversion_cli_basic(self, tmp_path: Path, base_config):
        debug_info, output_dir = self._convert(
            tmp_path, base_config, ConvertDiffTransformerCliArgs()
        )

        assert not debug_info
        self._assert_outputs_exist(output_dir)

    def test_conversion_cli_debug(self, tmp_path: Path, base_config):
        debug_info, output_dir = self._convert(
            tmp_path, base_config, ConvertDiffTransformerCliArgs(debug=True)
        )

        assert not debug_info["generations_match"]
        assert not debug_info["match_expected"]
        self._assert_outputs_exist(output_dir)

    def test_conversion_cli_reproduce(self, tmp_path: Path, base_config):
        cli_args = ConvertDiffTransformerCliArgs(
            debug=True, zero_init=True, sublayer_norm=False
        )
        debug_info, output_dir = self._convert(tmp_path, base_config, cli_args)

        assert debug_info["generations_match"] is True
        self._assert_outputs_exist(output_dir)

    @pytest.mark.parametrize(
        "attention", ["eager_attention", "sdp_attention", "flash_attention"]
    )
    def test_conversion_cli_repoduce_attentions(
        self, tmp_path: Path, base_config, attention: Optional[str]
    ):
        cli_args = ConvertDiffTransformerCliArgs(
            debug=True, zero_init=True, sublayer_norm=False
        )
        # Only the attention flag under test is enabled for this run.
        debug_info, output_dir = self._convert(
            tmp_path, base_config, cli_args, **{attention: True}
        )

        assert debug_info["generations_match"] is True
        self._assert_outputs_exist(output_dir)

    @pytest.mark.parametrize(
        "attention", ["eager_attention", "sdp_attention", "flash_attention"]
    )
    def test_conversion_cli_split_heads(
        self, tmp_path: Path, base_config, attention: str
    ):
        cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
        # Only the attention flag under test is enabled for this run.
        debug_info, output_dir = self._convert(
            tmp_path, base_config, cli_args, **{attention: True}
        )

        assert debug_info["generations_match"] is False
        self._assert_outputs_exist(output_dir)