Skip to content

Commit

Permalink
Move contrib flags into xmanager/xm_flags.py.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 610885145
Change-Id: I5dd96f3cd26d42501fb1c9dd2caecfd357cd2137
GitOrigin-RevId: 2e54439fba84c65c57b4e4dac5fb23b1bea2510a
  • Loading branch information
DeepMind Team authored and alpiccioni committed Dec 4, 2024
1 parent 455c637 commit 5e19287
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 49 deletions.
31 changes: 0 additions & 31 deletions xm_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
"""XManager Flags."""

import enum
from absl import flags

# -------------------- xm_local --------------------
Expand Down Expand Up @@ -56,33 +55,3 @@
BAZEL_COMMAND = flags.DEFINE_string(
'xm_bazel_command', 'bazel', 'A command that runs Bazel.'
)

# -------------------- contrib --------------------

GCS_PATH = flags.DEFINE_string(
'xm_gcs_path',
None,
(
'A GCS directory within a bucket to store output '
'(in gs://bucket/directory format).'
),
)


class XMLaunchMode(enum.Enum):
"""Specifies an executor to run an experiment."""

VERTEX = 'vertex'
LOCAL = 'local'
INTERACTIVE = 'interactive'


XM_LAUNCH_MODE = flags.DEFINE_enum_class(
'xm_launch_mode',
XMLaunchMode.VERTEX,
XMLaunchMode,
'How to launch the experiment. Supports local and interactive execution, '
+ 'launch on '
+
'Vertex.',
)
45 changes: 31 additions & 14 deletions xmanager/contrib/executor_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,40 @@
executor = executor_fn(**kwargs)
"""

import enum
from typing import Callable, List, Optional, Union

from absl import flags
from xmanager import xm
from xmanager import xm_flags
from xmanager import xm_local


def launch_mode() -> xm_flags.XMLaunchMode:
return xm_flags.XM_LAUNCH_MODE.value
class XMLaunchMode(enum.Enum):
"""Specifies an executor to run an experiment."""

VERTEX = 'vertex'
LOCAL = 'local'
INTERACTIVE = 'interactive'


_XM_LAUNCH_MODE = flags.DEFINE_enum_class(
'xm_launch_mode',
XMLaunchMode.VERTEX,
XMLaunchMode,
'How to launch the experiment. Supports local and interactive execution, '
+ 'launch on '
+
'Vertex.',
)


def launch_mode() -> XMLaunchMode:
return _XM_LAUNCH_MODE.value


def create_experiment(
experiment_title: Optional[str] = None,
mode: Optional[xm_flags.XMLaunchMode] = None,
mode: Optional[XMLaunchMode] = None,
) -> xm.Experiment:
"""Creates an experiment depending on the launch mode.
Expand All @@ -79,9 +99,9 @@ def create_experiment(
mode = launch_mode()

if mode in (
xm_flags.XMLaunchMode.LOCAL,
xm_flags.XMLaunchMode.INTERACTIVE,
xm_flags.XMLaunchMode.VERTEX,
XMLaunchMode.LOCAL,
XMLaunchMode.INTERACTIVE,
XMLaunchMode.VERTEX,
):
# TODO: add import here?
return xm_local.create_experiment(experiment_title)
Expand All @@ -108,7 +128,7 @@ def setup_local(*args, **kwargs_in) -> xm_local.Local:


def get_executor(
mode: Optional[xm_flags.XMLaunchMode] = None,
mode: Optional[XMLaunchMode] = None,
) -> Callable[..., xm.Executor]:
"""Select an `xm.Executor` specialization depending on the launch mode.
Expand All @@ -125,11 +145,8 @@ def get_executor(
if mode is None:
mode = launch_mode()

if mode == xm_flags.XMLaunchMode.VERTEX:
if mode == XMLaunchMode.VERTEX:
return xm_local.Caip
if (
mode == xm_flags.XMLaunchMode.LOCAL
or mode == xm_flags.XMLaunchMode.INTERACTIVE
):
return _local_executor(mode == xm_flags.XMLaunchMode.INTERACTIVE)
if mode == XMLaunchMode.LOCAL or mode == XMLaunchMode.INTERACTIVE:
return _local_executor(mode == XMLaunchMode.INTERACTIVE)
raise ValueError(f'Unknown launch mode: {mode}')
17 changes: 13 additions & 4 deletions xmanager/contrib/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,16 @@
import os

from absl import app
from xmanager import xm_flags
from absl import flags

_GCS_PATH = flags.DEFINE_string(
'xm_gcs_path',
None,
(
'A GCS directory within a bucket to store output '
'(in gs://bucket/directory format).'
),
)

_GS_PREFIX = 'gs://'
_GCS_PREFIX = '/gcs/'
Expand Down Expand Up @@ -55,17 +64,17 @@ def get_gcs_path_or_fail(project_name: str) -> str:
If the --xm_gcs_path flag is empty, or contains invalid value, raise an
error. Otherwise, returns a flag value.
"""
if not xm_flags.GCS_PATH.value:
if not _GCS_PATH.value:
raise app.UsageError(
'--xm_gcs_path is missing. Suggestion: '
+ f'--xm_gcs_path={suggestion(project_name)}'
)
elif not is_gcs_path(xm_flags.GCS_PATH.value):
elif not is_gcs_path(_GCS_PATH.value):
raise app.UsageError(
'--xm_gcs_path not in gs://bucket/directory or /gcs/path format. '
+ f'Suggestion: --xm_gcs_path={suggestion(project_name)}'
)
return str(xm_flags.GCS_PATH.value)
return str(_GCS_PATH.value)


def is_gs_path(path: str) -> bool:
Expand Down

0 comments on commit 5e19287

Please sign in to comment.