diff --git a/xmanager/contrib/executor_selector.py b/xmanager/contrib/executor_selector.py index cd2abd9..9b7b6ce 100644 --- a/xmanager/contrib/executor_selector.py +++ b/xmanager/contrib/executor_selector.py @@ -47,40 +47,20 @@ executor = executor_fn(**kwargs) """ -import enum from typing import Callable, List, Optional, Union -from absl import flags from xmanager import xm +from xmanager import xm_flags from xmanager import xm_local -class XMLaunchMode(enum.Enum): - """Specifies an executor to run an experiment.""" - - VERTEX = 'vertex' - LOCAL = 'local' - INTERACTIVE = 'interactive' - - -_XM_LAUNCH_MODE = flags.DEFINE_enum_class( - 'xm_launch_mode', - XMLaunchMode.VERTEX, - XMLaunchMode, - 'How to launch the experiment. Supports local and interactive execution, ' - + 'launch on ' - + - 'Vertex.', -) - - -def launch_mode() -> XMLaunchMode: - return _XM_LAUNCH_MODE.value +def launch_mode() -> xm_flags.XMLaunchMode: + return xm_flags.XM_LAUNCH_MODE.value def create_experiment( experiment_title: Optional[str] = None, - mode: Optional[XMLaunchMode] = None, + mode: Optional[xm_flags.XMLaunchMode] = None, ) -> xm.Experiment: """Creates an experiment depending on the launch mode. @@ -99,9 +79,9 @@ def create_experiment( mode = launch_mode() if mode in ( - XMLaunchMode.LOCAL, - XMLaunchMode.INTERACTIVE, - XMLaunchMode.VERTEX, + xm_flags.XMLaunchMode.LOCAL, + xm_flags.XMLaunchMode.INTERACTIVE, + xm_flags.XMLaunchMode.VERTEX, ): # TODO: add import here? return xm_local.create_experiment(experiment_title) @@ -128,7 +108,7 @@ def setup_local(*args, **kwargs_in) -> xm_local.Local: def get_executor( - mode: Optional[XMLaunchMode] = None, + mode: Optional[xm_flags.XMLaunchMode] = None, ) -> Callable[..., xm.Executor]: """Select an `xm.Executor` specialization depending on the launch mode. @@ -145,8 +125,11 @@ def get_executor( if mode is None: mode = launch_mode() - if mode == XMLaunchMode.VERTEX: + if mode == xm_flags.XMLaunchMode.VERTEX: return xm_local.Caip - if mode == XMLaunchMode.LOCAL or mode == XMLaunchMode.INTERACTIVE: - return _local_executor(mode == XMLaunchMode.INTERACTIVE) + if ( + mode == xm_flags.XMLaunchMode.LOCAL + or mode == xm_flags.XMLaunchMode.INTERACTIVE + ): + return _local_executor(mode == xm_flags.XMLaunchMode.INTERACTIVE) raise ValueError(f'Unknown launch mode: {mode}') diff --git a/xmanager/contrib/gcs.py b/xmanager/contrib/gcs.py index 87616ae..9dcd56e 100644 --- a/xmanager/contrib/gcs.py +++ b/xmanager/contrib/gcs.py @@ -26,16 +26,7 @@ import os from absl import app -from absl import flags - -_GCS_PATH = flags.DEFINE_string( - 'xm_gcs_path', - None, - ( - 'A GCS directory within a bucket to store output ' - '(in gs://bucket/directory format).' - ), -) +from xmanager import xm_flags _GS_PREFIX = 'gs://' _GCS_PREFIX = '/gcs/' @@ -64,17 +55,17 @@ def get_gcs_path_or_fail(project_name: str) -> str: If the --xm_gcs_path flag is empty, or contains invalid value, raise an error. Otherwise, returns a flag value. """ - if not _GCS_PATH.value: + if not xm_flags.GCS_PATH.value: raise app.UsageError( '--xm_gcs_path is missing. Suggestion: ' + f'--xm_gcs_path={suggestion(project_name)}' ) - elif not is_gcs_path(_GCS_PATH.value): + elif not is_gcs_path(xm_flags.GCS_PATH.value): raise app.UsageError( '--xm_gcs_path not in gs://bucket/directory or /gcs/path format. ' + f'Suggestion: --xm_gcs_path={suggestion(project_name)}' ) - return str(_GCS_PATH.value) + return str(xm_flags.GCS_PATH.value) def is_gs_path(path: str) -> bool: diff --git a/xmanager/xm_flags.py b/xmanager/xm_flags.py index 35852e5..5a18417 100644 --- a/xmanager/xm_flags.py +++ b/xmanager/xm_flags.py @@ -13,6 +13,7 @@ # limitations under the License. """XManager Flags.""" +import enum from absl import flags # -------------------- xm_local -------------------- @@ -55,3 +56,33 @@ BAZEL_COMMAND = flags.DEFINE_string( 'xm_bazel_command', 'bazel', 'A command that runs Bazel.', ) + +# -------------------- contrib -------------------- + +GCS_PATH = flags.DEFINE_string( + 'xm_gcs_path', + None, + ( + 'A GCS directory within a bucket to store output ' + '(in gs://bucket/directory format).' + ), +) + + +class XMLaunchMode(enum.Enum): + """Specifies an executor to run an experiment.""" + + VERTEX = 'vertex' + LOCAL = 'local' + INTERACTIVE = 'interactive' + + +XM_LAUNCH_MODE = flags.DEFINE_enum_class( + 'xm_launch_mode', + XMLaunchMode.VERTEX, + XMLaunchMode, + 'How to launch the experiment. Supports local and interactive execution, ' + + 'launch on ' + + + 'Vertex.', +)