Skip to content

Commit

Permalink
Add env var for configuring the maximum number of processes to use
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Oct 22, 2024
1 parent 6448e4e commit 7898e8c
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,22 @@ def is_valid_ift_example(
return True


def _get_num_processes() -> int:
    """Get the number of processes to use for dataset processing.

    The default is the detected CPU count minus a margin of 8, with a floor
    of 1. If the ``MAX_NUM_PROC`` environment variable is set, it caps that
    default; it can lower the process count but never raise it.

    Returns:
        int: The number of processes to use. Always at least 1.

    Raises:
        ValueError: If ``MAX_NUM_PROC`` is set but is not a valid integer.
    """
    detected_cpu_count = os.cpu_count() or 1  # cpu_count() may return None
    detected_cpus_with_margin = detected_cpu_count - 8
    num_proc = max(1, detected_cpus_with_margin)

    # Check if the user has set the MAX_NUM_PROC environment variable
    # which caps the number of processes used for dataset processing.
    max_num_proc_raw = os.environ.get('MAX_NUM_PROC')
    if max_num_proc_raw is not None:
        try:
            max_num_proc_env = int(max_num_proc_raw)
        except ValueError as err:
            raise ValueError(
                'MAX_NUM_PROC must be an integer, got: ' +
                repr(max_num_proc_raw),
            ) from err
        # A cap of zero or below would otherwise yield an unusable process
        # count, so never go below a single process.
        num_proc = max(1, min(num_proc, max_num_proc_env))

    return num_proc


class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.
Expand Down Expand Up @@ -960,18 +976,16 @@ def dataset_mapper(example: dict):
)
return mapping_fn(example, tokenizer)

detected_cpu_count = os.cpu_count() or 1
detected_cpus_with_margin = detected_cpu_count - 8
num_cpus_to_use = max(1, detected_cpus_with_margin)
if len(dataset) < num_cpus_to_use:
num_cpus_to_use = 1
num_proc = _get_num_processes()
if len(dataset) < num_proc:
num_proc = 1

columns_to_remove = list(dataset[0].keys())
tokenized_dataset = dataset.map(
dataset_mapper,
batched=False,
remove_columns=columns_to_remove,
num_proc=num_cpus_to_use,
num_proc=num_proc,
desc='Tokenizing dataset',
)

Expand Down

0 comments on commit 7898e8c

Please sign in to comment.