From d37fc25d568bbcfb36772d9eae4539d8deac1bd2 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Fri, 15 Dec 2023 15:31:46 -0800 Subject: [PATCH] Refactor launcher user arg parsing (#4824) Splitting work from #4769 because we are still debugging transformers integration issues. Parsing was broken for user arguments (see #4795). Additionally, parsing of user arguments is tricky and there are lots of edge cases. For example: #4660, #4716, #3967. I've attempted to accommodate all of the possible types of string inputs and added unit tests. --- deepspeed/launcher/multinode_runner.py | 14 ++++-- deepspeed/launcher/runner.py | 4 -- tests/unit/launcher/test_user_args.py | 64 ++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 7 deletions(-) create mode 100644 tests/unit/launcher/test_user_args.py diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py index 730146f5bcd2..07c70f3cc6c5 100644 --- a/deepspeed/launcher/multinode_runner.py +++ b/deepspeed/launcher/multinode_runner.py @@ -56,13 +56,21 @@ def __init__(self, args, world_info_base64): def backend_exists(self): return shutil.which('pdsh') + def parse_user_args(self): + processed_args = [] + for arg in self.args.user_args: + # With pdsh, if we are passing a string as an argument, it will get + # split on whitespace. To avoid this and support strings that + # contain '"', we do this extra processing step: + if " " in arg: + arg = '"{}"'.format(arg.replace('"', '\\"')) + processed_args.append(arg) + return processed_args + @property def name(self): return "pdsh" - def parse_user_args(self): - return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args)) - def get_cmd(self, environment, active_resources): environment['PDSH_RCMD_TYPE'] = 'ssh' if self.args.ssh_port is not None: # only specify ssh port if it is specified diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index a7fa2b5053e5..99ebc9771e41 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -12,7 +12,6 @@ import os import re import sys -import shlex import json import base64 import argparse @@ -389,9 +388,6 @@ def parse_num_nodes(str_num_nodes: str, elastic_training: bool): def main(args=None): args = parse_args(args) - # For when argparse interprets remaining args as a single string - args.user_args = shlex.split(" ".join(list(map(lambda x: x if x.startswith("-") else f'"{x}"', args.user_args)))) - if args.elastic_training: assert args.master_addr != "", "Master Addr is required when elastic training is enabled" diff --git a/tests/unit/launcher/test_user_args.py b/tests/unit/launcher/test_user_args.py new file mode 100644 index 000000000000..99afd0f2cfa7 --- /dev/null +++ b/tests/unit/launcher/test_user_args.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import subprocess + +from deepspeed.accelerator import get_accelerator + +if not get_accelerator().is_available(): + pytest.skip("only supported in accelerator environments.", allow_module_level=True) + +user_arg_test_script = """import argparse +parser = argparse.ArgumentParser() +parser.add_argument("--prompt", type=str) +parser.add_argument("--local_rank", type=int, default=0) +parser.add_argument("--world_size", type=int, default=1) +args = parser.parse_args() +print("ARG PARSE SUCCESS") +""" + + +@pytest.fixture(scope="function") +def user_script_fp(tmpdir): + script_fp = tmpdir.join("user_arg_test.py") + with open(script_fp, "w") as f: + f.write(user_arg_test_script) + return script_fp + + +@pytest.fixture(scope="function") +def cmd(user_script_fp, prompt, multi_node): + if multi_node: + cmd = ("deepspeed", "--force_multi", "--num_nodes", "1", "--num_gpus", "1", user_script_fp, "--prompt", prompt) + else: + cmd = ("deepspeed", "--num_nodes", "1", "--num_gpus", "1", user_script_fp, "--prompt", prompt) + return cmd + + +@pytest.mark.parametrize("prompt", [ + '''"I am 6' tall"''', """'I am 72" tall'""", """'"translate English to Romanian: "'""", + '''I'm going to tell them "DeepSpeed is the best"''' +]) +@pytest.mark.parametrize("multi_node", [True, False]) +def test_user_args(cmd): + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = p.communicate() + assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}" + + +def test_bash_string_args(tmpdir, user_script_fp): + bash_script = f""" + ARGS="--prompt 'DeepSpeed is the best'" + echo ${{ARGS}}|xargs deepspeed --num_nodes 1 --num_gpus 1 {user_script_fp} + """ + + bash_fp = tmpdir.join("bash_script.sh") + with open(bash_fp, "w") as f: + f.write(bash_script) + + p = subprocess.Popen(["bash", bash_fp], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = p.communicate() + assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}"