diff --git a/examples/pytorch/text-classification/run_classification.py b/examples/pytorch/text-classification/run_classification.py index 5decef3656b9a4..ff05b78cb538ec 100755 --- a/examples/pytorch/text-classification/run_classification.py +++ b/examples/pytorch/text-classification/run_classification.py @@ -133,6 +133,10 @@ class DataTrainingArguments: ) }, ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} ) @@ -573,6 +577,7 @@ def preprocess_function(examples): raw_datasets = raw_datasets.map( preprocess_function, batched=True, + num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, desc="Running tokenizer on dataset", )