diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 23929f9b85..7aff80e50f 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -801,6 +801,7 @@ def build_from_hf(
     split: str,
     safe_load: bool = False,
     max_seq_len: int = 2048,
+    mapping_fn: Callable = tokenize_formatted_example,
     preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]] = None,
     tokenizer: Optional[PreTrainedTokenizerBase] = None,
     target_prompts: str = DEFAULT_TARGET_PROMPTS,
@@ -824,6 +825,8 @@ def build_from_hf(
         max_seq_len (int): The maximum length of sequences in the batch. See
             :class:`Seq2SeqFinetuningCollator` docstring for details.
+        mapping_fn (Callable): The mapping function to use for mapping the data
+            examples.
         preprocessing_fn (Callable, optional): The preprocessing function to use for
             formatting the data examples.
         tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for tokenizing
@@ -930,11 +933,11 @@ def build_from_hf(
 
     def dataset_mapper(example: dict):
         if preprocessing_fn is not None:
-            return tokenize_formatted_example(
+            return mapping_fn(
                 preprocessing_fn(example),
                 tokenizer,
             )
-        return tokenize_formatted_example(example, tokenizer)
+        return mapping_fn(example, tokenizer)
 
     detected_cpu_count = os.cpu_count() or 1
     detected_cpus_with_margin = detected_cpu_count - 8
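
For context, a minimal sketch of how the new `mapping_fn` hook could be exercised. The wrapper below (`tokenize_and_log_example`) is a hypothetical example and not part of this diff; it only assumes the `(example, tokenizer)` call shape that `dataset_mapper` forwards to `mapping_fn` above.

```python
from typing import Any

from transformers import PreTrainedTokenizerBase

from llmfoundry.data.finetuning.tasks import tokenize_formatted_example


def tokenize_and_log_example(
    example: dict[str, Any],
    tokenizer: PreTrainedTokenizerBase,
) -> dict[str, Any]:
    """Hypothetical custom mapping_fn (illustration only, not in this PR).

    It takes the same (example, tokenizer) arguments that dataset_mapper
    passes to mapping_fn, so it can slot in as a drop-in replacement for
    the default tokenize_formatted_example.
    """
    # Any per-example bookkeeping or validation could happen here before
    # delegating to the default tokenization path.
    return tokenize_formatted_example(example, tokenizer)
```

Under this sketch, the custom mapper would be supplied as `build_from_hf(..., mapping_fn=tokenize_and_log_example)`, with all other arguments unchanged; omitting `mapping_fn` keeps the previous behavior because the default remains `tokenize_formatted_example`.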