From 9f1cf9b17c3761ea806720ba44dac44def3c1451 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Tue, 12 Nov 2024 12:51:37 +0700
Subject: [PATCH] fix: handle sharegpt dataset missing (#2035)

* fix: handle sharegpt dataset missing

* fix: explanation

* feat: add test
---
 .../config/models/input/v0_4_1/__init__.py    |  7 ++-
 tests/test_validation_dataset.py              | 56 +++++++++++++++++++
 2 files changed, 62 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 8d2065bdb8..dc7693f06c 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -789,7 +789,12 @@ def deprecate_sharegpt_datasets(cls, datasets):
             if not ds_cfg.get("type"):
                 continue
 
-            if ds_cfg["type"].startswith("sharegpt"):
+            ds_type = ds_cfg["type"]
+            # skip if it's a dict (for custom user instruction prompt)
+            if isinstance(ds_type, dict):
+                continue
+
+            if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
                 raise ValueError(
                     "`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
                 )
diff --git a/tests/test_validation_dataset.py b/tests/test_validation_dataset.py
index 7e288f8165..14f9d34627 100644
--- a/tests/test_validation_dataset.py
+++ b/tests/test_validation_dataset.py
@@ -234,3 +234,59 @@ def _check_config():
             )
 
         _check_config()
+
+    def test_dataset_sharegpt_deprecation(self, minimal_cfg):
+        cfg = DictDefault(
+            minimal_cfg
+            | {
+                "chat_template": "chatml",
+                "datasets": [
+                    {
+                        "path": "LDJnr/Puffin",
+                        "type": "sharegpt",
+                        "conversation": "chatml",
+                    }
+                ],
+            }
+        )
+
+        # Check sharegpt deprecation is raised
+        with pytest.raises(ValueError, match=r".*type: sharegpt.*` is deprecated.*"):
+            validate_config(cfg)
+
+        # Check that deprecation is not thrown for non-str type
+        cfg = DictDefault(
+            minimal_cfg
+            | {
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": {
+                            "field_instruction": "instruction",
+                            "field_output": "output",
+                            "field_system": "system",
+                            "format": "<|user|> {instruction} {input} <|model|>",
+                            "no_input_format": "<|user|> {instruction} <|model|>",
+                            "system_prompt": "",
+                        },
+                    }
+                ],
+            }
+        )
+
+        validate_config(cfg)
+
+        # Check that deprecation is not thrown for non-sharegpt type
+        cfg = DictDefault(
+            minimal_cfg
+            | {
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    }
+                ],
+            }
+        )
+
+        validate_config(cfg)