Skip to content

Commit

Permalink
Move contrib flags into xmanager/xm_flags.py.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615542078
Change-Id: Ibf6138f84b2f173b2f2ed5d797cbbd08138db80e
GitOrigin-RevId: 7360f21cdc116933ba616d2c7b079f2bcbcfce9a
  • Loading branch information
fionalang authored and alpiccioni committed Dec 4, 2024
1 parent 8daf228 commit ec291b3
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 44 deletions.
45 changes: 14 additions & 31 deletions xmanager/contrib/executor_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,40 +47,20 @@
executor = executor_fn(**kwargs)
"""

import enum
from typing import Callable, List, Optional, Union

from absl import flags
from xmanager import xm
from xmanager import xm_flags
from xmanager import xm_local


class XMLaunchMode(enum.Enum):
"""Specifies an executor to run an experiment."""

VERTEX = 'vertex'
LOCAL = 'local'
INTERACTIVE = 'interactive'


_XM_LAUNCH_MODE = flags.DEFINE_enum_class(
'xm_launch_mode',
XMLaunchMode.VERTEX,
XMLaunchMode,
'How to launch the experiment. Supports local and interactive execution, '
+ 'launch on '
+
'Vertex.',
)


def launch_mode() -> XMLaunchMode:
return _XM_LAUNCH_MODE.value
def launch_mode() -> xm_flags.XMLaunchMode:
return xm_flags.XM_LAUNCH_MODE.value


def create_experiment(
experiment_title: Optional[str] = None,
mode: Optional[XMLaunchMode] = None,
mode: Optional[xm_flags.XMLaunchMode] = None,
) -> xm.Experiment:
"""Creates an experiment depending on the launch mode.
Expand All @@ -99,9 +79,9 @@ def create_experiment(
mode = launch_mode()

if mode in (
XMLaunchMode.LOCAL,
XMLaunchMode.INTERACTIVE,
XMLaunchMode.VERTEX,
xm_flags.XMLaunchMode.LOCAL,
xm_flags.XMLaunchMode.INTERACTIVE,
xm_flags.XMLaunchMode.VERTEX,
):
# TODO: add import here?
return xm_local.create_experiment(experiment_title)
Expand All @@ -128,7 +108,7 @@ def setup_local(*args, **kwargs_in) -> xm_local.Local:


def get_executor(
mode: Optional[XMLaunchMode] = None,
mode: Optional[xm_flags.XMLaunchMode] = None,
) -> Callable[..., xm.Executor]:
"""Select an `xm.Executor` specialization depending on the launch mode.
Expand All @@ -145,8 +125,11 @@ def get_executor(
if mode is None:
mode = launch_mode()

if mode == XMLaunchMode.VERTEX:
if mode == xm_flags.XMLaunchMode.VERTEX:
return xm_local.Caip
if mode == XMLaunchMode.LOCAL or mode == XMLaunchMode.INTERACTIVE:
return _local_executor(mode == XMLaunchMode.INTERACTIVE)
if (
mode == xm_flags.XMLaunchMode.LOCAL
or mode == xm_flags.XMLaunchMode.INTERACTIVE
):
return _local_executor(mode == xm_flags.XMLaunchMode.INTERACTIVE)
raise ValueError(f'Unknown launch mode: {mode}')
17 changes: 4 additions & 13 deletions xmanager/contrib/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,7 @@
import os

from absl import app
from absl import flags

_GCS_PATH = flags.DEFINE_string(
'xm_gcs_path',
None,
(
'A GCS directory within a bucket to store output '
'(in gs://bucket/directory format).'
),
)
from xmanager import xm_flags

_GS_PREFIX = 'gs://'
_GCS_PREFIX = '/gcs/'
Expand Down Expand Up @@ -64,17 +55,17 @@ def get_gcs_path_or_fail(project_name: str) -> str:
If the --xm_gcs_path flag is empty, or contains invalid value, raise an
error. Otherwise, returns a flag value.
"""
if not _GCS_PATH.value:
if not xm_flags.GCS_PATH.value:
raise app.UsageError(
'--xm_gcs_path is missing. Suggestion: '
+ f'--xm_gcs_path={suggestion(project_name)}'
)
elif not is_gcs_path(_GCS_PATH.value):
elif not is_gcs_path(xm_flags.GCS_PATH.value):
raise app.UsageError(
'--xm_gcs_path not in gs://bucket/directory or /gcs/path format. '
+ f'Suggestion: --xm_gcs_path={suggestion(project_name)}'
)
return str(_GCS_PATH.value)
return str(xm_flags.GCS_PATH.value)


def is_gs_path(path: str) -> bool:
Expand Down
31 changes: 31 additions & 0 deletions xmanager/xm_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""XManager Flags."""

import enum
from absl import flags

# -------------------- xm_local --------------------
Expand Down Expand Up @@ -55,3 +56,33 @@
BAZEL_COMMAND = flags.DEFINE_string(
'xm_bazel_command', 'bazel', 'A command that runs Bazel.',
)

# -------------------- contrib --------------------

GCS_PATH = flags.DEFINE_string(
'xm_gcs_path',
None,
(
'A GCS directory within a bucket to store output '
'(in gs://bucket/directory format).'
),
)


class XMLaunchMode(enum.Enum):
"""Specifies an executor to run an experiment."""

VERTEX = 'vertex'
LOCAL = 'local'
INTERACTIVE = 'interactive'


XM_LAUNCH_MODE = flags.DEFINE_enum_class(
'xm_launch_mode',
XMLaunchMode.VERTEX,
XMLaunchMode,
'How to launch the experiment. Supports local and interactive execution, '
+ 'launch on '
+
'Vertex.',
)

0 comments on commit ec291b3

Please sign in to comment.