Skip to content

Commit

Permalink
fix: handle sharegpt dataset missing (#2035)
Browse files Browse the repository at this point in the history
* fix: handle sharegpt dataset missing

* fix: explanation

* feat: add test
  • Loading branch information
NanoCode012 authored Nov 12, 2024
1 parent 3931a42 commit 9f1cf9b
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,12 @@ def deprecate_sharegpt_datasets(cls, datasets):
if not ds_cfg.get("type"):
continue

if ds_cfg["type"].startswith("sharegpt"):
ds_type = ds_cfg["type"]
# skip if it's a dict (for custom user instruction prompt)
if isinstance(ds_type, dict):
continue

if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
raise ValueError(
"`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
)
Expand Down
56 changes: 56 additions & 0 deletions tests/test_validation_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,59 @@ def _check_config():
)

_check_config()

def test_dataset_sharegpt_deprecation(self, minimal_cfg):
    """
    Check which dataset `type` values trigger the sharegpt deprecation error.

    A plain-string `sharegpt` type must make `validate_config` raise a
    `ValueError`; a dict-valued type and an unrelated string type must
    both pass validation without raising.
    """
    sharegpt_dataset = {
        "path": "LDJnr/Puffin",
        "type": "sharegpt",
        "conversation": "chatml",
    }
    deprecated_cfg = DictDefault(
        minimal_cfg
        | {
            "chat_template": "chatml",
            "datasets": [sharegpt_dataset],
        }
    )

    # A string type starting with "sharegpt" has to be rejected
    with pytest.raises(ValueError, match=r".*type: sharegpt.*` is deprecated.*"):
        validate_config(deprecated_cfg)

    # Dict-valued type: a user-defined instruction prompt, not a string
    custom_prompt_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": {
            "field_instruction": "instruction",
            "field_output": "output",
            "field_system": "system",
            "format": "<|user|> {instruction} {input} <|model|>",
            "no_input_format": "<|user|> {instruction} <|model|>",
            "system_prompt": "",
        },
    }
    # String type that is not a sharegpt variant
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }

    # Neither of these may raise; validated in the same order as listed
    for accepted_dataset in (custom_prompt_dataset, alpaca_dataset):
        accepted_cfg = DictDefault(minimal_cfg | {"datasets": [accepted_dataset]})
        validate_config(accepted_cfg)

0 comments on commit 9f1cf9b

Please sign in to comment.