From d4b77860da1dd668532a0928f75c485343a2ac53 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 22 Dec 2024 14:21:40 +0700 Subject: [PATCH 1/2] fix: drop long seq even if not sample packing --- src/axolotl/utils/data/sft.py | 7 ++++-- src/axolotl/utils/trainer.py | 46 +++++++++++++++-------------------- 2 files changed, 24 insertions(+), 29 deletions(-) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 286e5f2d70..70543d7c80 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -1,10 +1,10 @@ """data handling specific to SFT""" import functools -import logging from pathlib import Path from typing import List, Optional, Tuple, Union +from accelerate.logging import get_logger from datasets import ( Dataset, DatasetDict, @@ -51,10 +51,11 @@ from axolotl.utils.distributed import is_local_main_process, zero_first from axolotl.utils.trainer import ( calculate_total_num_steps, + drop_long_seq_in_dataset, process_datasets_for_packing, ) -LOG = logging.getLogger("axolotl") +LOG = get_logger("axolotl") @retry_on_request_exceptions(max_retries=3, delay=5) @@ -482,6 +483,8 @@ def for_d_in_datasets(dataset_configs): else: LOG.debug("NOT shuffling merged datasets") + dataset = drop_long_seq_in_dataset(dataset, cfg) + if cfg.sample_packing and not cfg.skip_prepare_dataset: dataset, _ = process_datasets_for_packing(cfg, dataset, None) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 32e54c9a86..77a232129d 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -178,18 +178,33 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2): ) -def process_datasets_for_packing(cfg, train_dataset, eval_dataset): +def drop_long_seq_in_dataset(dataset, cfg): drop_long = partial( drop_long_seq, sequence_len=cfg.sequence_len, min_sequence_len=cfg.min_sample_len or 2, ) - min_input_len = 
np.min(get_dataset_lengths(train_dataset)) + min_input_len = np.min(get_dataset_lengths(dataset)) LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True) - max_input_len = np.max(get_dataset_lengths(train_dataset)) + max_input_len = np.max(get_dataset_lengths(dataset)) LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True) + prior_len = len(dataset) + dataset = dataset.filter( + drop_long, + num_proc=cfg.dataset_processes, + load_from_cache_file=not cfg.is_preprocess, + desc="Dropping Long Sequences", + ) + dropped = prior_len - len(dataset) + if dropped: + LOG.warning(f"Dropped {dropped} long samples from dataset") + + return dataset + + +def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if cfg.model_config_type == "mamba": LOG.info("dropping attention_mask column") train_dataset = train_dataset.remove_columns("attention_mask") @@ -203,29 +218,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if eval_dataset and "token_type_ids" in eval_dataset.column_names: eval_dataset = eval_dataset.remove_columns("token_type_ids") - prior_len = len(train_dataset) - train_dataset = train_dataset.filter( - drop_long, - num_proc=cfg.dataset_processes, - load_from_cache_file=not cfg.is_preprocess, - desc="Dropping Long Sequences", - ) - dropped = prior_len - len(train_dataset) - if dropped: - LOG.warning(f"Dropped {dropped} long samples from train dataset") - - if eval_dataset: - prior_len = len(eval_dataset) - eval_dataset = eval_dataset.filter( - drop_long, - num_proc=cfg.dataset_processes, - load_from_cache_file=not cfg.is_preprocess, - desc="Dropping Long Sequences", - ) - dropped = prior_len - len(eval_dataset) - if dropped: - LOG.warning(f"Dropped {dropped} long samples from eval dataset") - # drop samples with where the number of elements with labels not equal to -100 is zero def drop_no_trainable_tokens(sample): return np.sum(np.array(sample["labels"]) != -100) > 0 From 
0a83b887d3e550c09cd6c3f4461390f879442cd6 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 22 Dec 2024 23:30:37 +0700 Subject: [PATCH 2/2] fix: logging import --- src/axolotl/utils/data/sft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 70543d7c80..adfd1bd021 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -1,10 +1,10 @@ """data handling specific to SFT""" import functools +import logging from pathlib import Path from typing import List, Optional, Tuple, Union -from accelerate.logging import get_logger from datasets import ( Dataset, DatasetDict, @@ -55,7 +55,7 @@ process_datasets_for_packing, ) -LOG = get_logger("axolotl") +LOG = logging.getLogger("axolotl.utils.data.sft") @retry_on_request_exceptions(max_retries=3, delay=5)