diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 915267786f..b9f588c284 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -515,6 +515,22 @@ def is_valid_ift_example(
     return True
 
 
+def _get_num_processes() -> int:
+    """Get the number of processes to use for dataset processing."""
+    detected_cpu_count = os.cpu_count() or 1
+    detected_cpus_with_margin = detected_cpu_count - 8
+    num_proc = max(1, detected_cpus_with_margin)
+
+    # Check if the user has set the MAX_NUM_PROC environment variable
+    # which caps the number of processes used for dataset processing.
+    if 'MAX_NUM_PROC' in os.environ:
+        max_num_proc_env = int(os.environ['MAX_NUM_PROC'])
+        if max_num_proc_env < num_proc:
+            num_proc = max_num_proc_env
+
+    return num_proc
+
+
 class StreamingFinetuningDataset(StreamingDataset):
     """Finetuning dataset with flexible tokenization using StreamingDataset.
 
@@ -960,18 +976,16 @@ def dataset_mapper(example: dict):
             )
             return mapping_fn(example, tokenizer)
 
-        detected_cpu_count = os.cpu_count() or 1
-        detected_cpus_with_margin = detected_cpu_count - 8
-        num_cpus_to_use = max(1, detected_cpus_with_margin)
-        if len(dataset) < num_cpus_to_use:
-            num_cpus_to_use = 1
+        num_proc = _get_num_processes()
+        if len(dataset) < num_proc:
+            num_proc = 1
 
         columns_to_remove = list(dataset[0].keys())
         tokenized_dataset = dataset.map(
             dataset_mapper,
             batched=False,
            remove_columns=columns_to_remove,
-            num_proc=num_cpus_to_use,
+            num_proc=num_proc,
             desc='Tokenizing dataset',
         )
 
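
For illustration only, and not part of the patch: a minimal sketch of how the new _get_num_processes helper resolves the worker count, assuming a hypothetical 32-CPU machine and a user-set MAX_NUM_PROC. The logic mirrors the added function above, with the cap written as a min() for brevity.

    import os

    # Hypothetical setup: the user caps dataset-processing workers at 8.
    os.environ['MAX_NUM_PROC'] = '8'

    detected_cpu_count = os.cpu_count() or 1       # e.g. 32 on the assumed machine
    num_proc = max(1, detected_cpu_count - 8)      # leave an 8-CPU margin -> 24
    if 'MAX_NUM_PROC' in os.environ:
        # Cap at the user-provided value when it is smaller than the detected margin.
        num_proc = min(num_proc, int(os.environ['MAX_NUM_PROC']))  # -> 8
    print(num_proc)                                # 8 under these assumptions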