From 63429028d562a5c955acba05c200d6cdfb71c2fa Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Fri, 15 Sep 2023 12:00:20 -0700 Subject: [PATCH] Remove max_mds_writer_workers arg. It is not necessary for writing to local folder. --- scripts/data_prep/convert_text_to_mds.py | 20 ++------------------ tests/test_convert_text_to_mds.py | 1 - 2 files changed, 2 insertions(+), 19 deletions(-) diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index b5d95498a6..5e37da639a 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -30,12 +30,6 @@ def parse_args() -> Namespace: description= 'Convert text files into MDS format, optionally concatenating and tokenizing', ) - parser.add_argument( - '--max_mds_writer_workers', - type=int, - default=64, - help='The maximum number of workers to use for MDS writing', - ) parser.add_argument( '--output_folder', type=str, @@ -158,7 +152,6 @@ def get_task_args( bos_text: str, no_wrap: bool, compression: str, - max_mds_writer_workers: int, ) -> Iterable: """Get download_and_convert arguments split across n_groups. @@ -175,7 +168,6 @@ def get_task_args( bos_text (str): Text to prepend to each example to separate concatenated samples no_wrap: (bool): Whether to let text examples wrap across multiple training examples compression (str): The compression algorithm to use for MDS writing - max_mds_writer_workers (int): number of workers for MDS writing """ num_objects = len(object_names) objs_per_group = math.ceil(num_objects / n_groups) @@ -191,7 +183,6 @@ def get_task_args( bos_text, no_wrap, compression, - max_mds_writer_workers, ) @@ -213,7 +204,6 @@ def download_and_convert( bos_text: str, no_wrap: bool, compression: str, - max_mds_writer_workers: int, ): """Downloads and converts text fies to MDS format. @@ -227,7 +217,6 @@ def download_and_convert( bos_text (str): Text to prepend to each example to separate concatenated samples no_wrap: (bool): Whether to let text examples wrap across multiple training examples compression (str): The compression algorithm to use for MDS writing - max_mds_writer_workers (int): number of workers for MDS writing """ object_store = maybe_create_object_store_from_uri(input_folder) @@ -255,7 +244,6 @@ def download_and_convert( log.info('Converting to MDS format...') with MDSWriter(out=output_folder, columns=columns, - max_workers=max_mds_writer_workers, compression=compression) as out: for sample in tqdm(dataset): out.write(sample) @@ -342,7 +330,6 @@ def convert_text_to_mds( eos_text: str, bos_text: str, no_wrap: bool, - max_mds_writer_workers: int, compression: str, processes: int, args_str: str, @@ -358,7 +345,6 @@ def convert_text_to_mds( eos_text (str): Textend to append to each example to separate concatenated samples bos_text (str): Text to prepend to each example to separate concatenated samples no_wrap: (bool): Whether to let text examples wrap across multiple training examples - max_mds_writer_workers (int): number of workers for MDS writing compression (str): The compression algorithm to use for MDS writing processes (int): The number of processes to use. args_str (str): String representation of the arguments @@ -388,8 +374,7 @@ def convert_text_to_mds( # Download and convert the text files in parallel args = get_task_args(object_names, local_output_folder, input_folder, processes, tokenizer_name, concat_tokens, eos_text, - bos_text, no_wrap, compression, - max_mds_writer_workers) + bos_text, no_wrap, compression) with ProcessPoolExecutor(max_workers=processes) as executor: list(executor.map(download_and_convert_starargs, args)) @@ -398,7 +383,7 @@ def convert_text_to_mds( else: download_and_convert(object_names, local_output_folder, input_folder, tokenizer_name, concat_tokens, eos_text, bos_text, - no_wrap, compression, max_mds_writer_workers) + no_wrap, compression) # Write a done file with the args and object names write_done_file(local_output_folder, args_str, object_names) @@ -449,7 +434,6 @@ def _args_str(original_args: Namespace) -> str: eos_text=args.eos_text, bos_text=args.bos_text, no_wrap=args.no_wrap, - max_mds_writer_workers=args.max_mds_writer_workers, compression=args.compression, processes=args.processes, reprocess=args.reprocess, diff --git a/tests/test_convert_text_to_mds.py b/tests/test_convert_text_to_mds.py index 82e177b92d..2d4878ebbb 100644 --- a/tests/test_convert_text_to_mds.py +++ b/tests/test_convert_text_to_mds.py @@ -71,7 +71,6 @@ def _call_convert_text_to_mds(processes: int, tokenizer_name: str, eos_text='', bos_text='', no_wrap=False, - max_mds_writer_workers=1, compression='zstd', processes=processes, args_str='Namespace()',