Skip to content

Commit

Permalink
Remove max_mds_writer_workers arg. It is not necessary for writing to…
Browse files Browse the repository at this point in the history
… local folder.
  • Loading branch information
irenedea committed Sep 15, 2023
1 parent a8a2645 commit 6342902
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 19 deletions.
20 changes: 2 additions & 18 deletions scripts/data_prep/convert_text_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@ def parse_args() -> Namespace:
description=
'Convert text files into MDS format, optionally concatenating and tokenizing',
)
parser.add_argument(
'--max_mds_writer_workers',
type=int,
default=64,
help='The maximum number of workers to use for MDS writing',
)
parser.add_argument(
'--output_folder',
type=str,
Expand Down Expand Up @@ -158,7 +152,6 @@ def get_task_args(
bos_text: str,
no_wrap: bool,
compression: str,
max_mds_writer_workers: int,
) -> Iterable:
"""Get download_and_convert arguments split across n_groups.
Expand All @@ -175,7 +168,6 @@ def get_task_args(
bos_text (str): Text to prepend to each example to separate concatenated samples
no_wrap: (bool): Whether to let text examples wrap across multiple training examples
compression (str): The compression algorithm to use for MDS writing
max_mds_writer_workers (int): number of workers for MDS writing
"""
num_objects = len(object_names)
objs_per_group = math.ceil(num_objects / n_groups)
Expand All @@ -191,7 +183,6 @@ def get_task_args(
bos_text,
no_wrap,
compression,
max_mds_writer_workers,
)


Expand All @@ -213,7 +204,6 @@ def download_and_convert(
bos_text: str,
no_wrap: bool,
compression: str,
max_mds_writer_workers: int,
):
"""Downloads and converts text fies to MDS format.
Expand All @@ -227,7 +217,6 @@ def download_and_convert(
bos_text (str): Text to prepend to each example to separate concatenated samples
no_wrap: (bool): Whether to let text examples wrap across multiple training examples
compression (str): The compression algorithm to use for MDS writing
max_mds_writer_workers (int): number of workers for MDS writing
"""
object_store = maybe_create_object_store_from_uri(input_folder)

Expand Down Expand Up @@ -255,7 +244,6 @@ def download_and_convert(
log.info('Converting to MDS format...')
with MDSWriter(out=output_folder,
columns=columns,
max_workers=max_mds_writer_workers,
compression=compression) as out:
for sample in tqdm(dataset):
out.write(sample)
Expand Down Expand Up @@ -342,7 +330,6 @@ def convert_text_to_mds(
eos_text: str,
bos_text: str,
no_wrap: bool,
max_mds_writer_workers: int,
compression: str,
processes: int,
args_str: str,
Expand All @@ -358,7 +345,6 @@ def convert_text_to_mds(
eos_text (str): Textend to append to each example to separate concatenated samples
bos_text (str): Text to prepend to each example to separate concatenated samples
no_wrap: (bool): Whether to let text examples wrap across multiple training examples
max_mds_writer_workers (int): number of workers for MDS writing
compression (str): The compression algorithm to use for MDS writing
processes (int): The number of processes to use.
args_str (str): String representation of the arguments
Expand Down Expand Up @@ -388,8 +374,7 @@ def convert_text_to_mds(
# Download and convert the text files in parallel
args = get_task_args(object_names, local_output_folder, input_folder,
processes, tokenizer_name, concat_tokens, eos_text,
bos_text, no_wrap, compression,
max_mds_writer_workers)
bos_text, no_wrap, compression)
with ProcessPoolExecutor(max_workers=processes) as executor:
list(executor.map(download_and_convert_starargs, args))

Expand All @@ -398,7 +383,7 @@ def convert_text_to_mds(
else:
download_and_convert(object_names, local_output_folder, input_folder,
tokenizer_name, concat_tokens, eos_text, bos_text,
no_wrap, compression, max_mds_writer_workers)
no_wrap, compression)

# Write a done file with the args and object names
write_done_file(local_output_folder, args_str, object_names)
Expand Down Expand Up @@ -449,7 +434,6 @@ def _args_str(original_args: Namespace) -> str:
eos_text=args.eos_text,
bos_text=args.bos_text,
no_wrap=args.no_wrap,
max_mds_writer_workers=args.max_mds_writer_workers,
compression=args.compression,
processes=args.processes,
reprocess=args.reprocess,
Expand Down
1 change: 0 additions & 1 deletion tests/test_convert_text_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def _call_convert_text_to_mds(processes: int, tokenizer_name: str,
eos_text='',
bos_text='',
no_wrap=False,
max_mds_writer_workers=1,
compression='zstd',
processes=processes,
args_str='Namespace()',
Expand Down

0 comments on commit 6342902

Please sign in to comment.