Skip to content

Commit

Permalink
Move cloud flags into xmanager/xm_flags.py.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615590502
Change-Id: I909704fee714d76fcf48ffd619ec2ac193dba720
GitOrigin-RevId: 9881c6789e369ab012364b00f742bce95d3f67a0
  • Loading branch information
fionalang authored and alpiccioni committed Dec 4, 2024
1 parent f225592 commit ad5c133
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 96 deletions.
16 changes: 3 additions & 13 deletions xmanager/cloud/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,16 @@

import functools
import os
from typing import Any, Iterable, Dict
from typing import Any, Dict, Iterable

from absl import flags
from google import auth
from googleapiclient import discovery
from googleapiclient import errors
from xmanager import xm_flags

_DEFAULT_SCOPES = ('https://www.googleapis.com/auth/cloud-platform',)

_GCP_SERVICE_ACCOUNT_NAME = flags.DEFINE_string(
'xm_gcp_service_account_name',
'xmanager',
(
'Specifies the user-managed service account name to be used by XManager'
'Note that user-managed service accounts have the following format: '
'`{service-account-name}@{project-id}.iam.gserviceaccount.com`, so only'
'the part before @ is required'
),
)


def get_project_name() -> str:
"""Gets the Project ID of the GCP Project."""
Expand Down Expand Up @@ -95,7 +85,7 @@ def get_service_account() -> str:
HttpError: if the response was not a 2xx or 403.
"""

service_account_name = _GCP_SERVICE_ACCOUNT_NAME.value
service_account_name = xm_flags.GCP_SERVICE_ACCOUNT_NAME.value
service_account = (
f'{service_account_name}@{get_project_name()}.iam.gserviceaccount.com'
)
Expand Down
43 changes: 5 additions & 38 deletions xmanager/cloud/build_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,48 +18,15 @@
import tempfile
from typing import Dict, List, Optional

from absl import flags
from docker.utils import utils as docker_utils

from xmanager import xm
from xmanager import xm_flags
from xmanager.cloud import auth
from xmanager.cloud import cloud_build
from xmanager.cloud import docker_lib
from xmanager.docker import docker_adapter
from xmanager.xm import utils

_BUILD_IMAGE_LOCALLY = flags.DEFINE_boolean(
'xm_build_image_locally',
True,
(
'Use local Docker to build images instead of remote Google Cloud Build.'
' This is usually a lot faster but requires docker to be installed.'
),
)
_USE_DOCKER_COMMAND = flags.DEFINE_boolean(
'xm_use_docker_command',
True,
(
'Call "docker build" in a subprocess rather than using Python docker '
'client library when building the docker image locally. This provies a '
'much nicer output for interactive use.'
),
)
_SHOW_DOCKER_COMMAND_PROGRESS = flags.DEFINE_boolean(
'xm_show_docker_command_progress',
False,
'Show container output during the "docker build".',
)
_WRAP_LATE_BINDINGS = flags.DEFINE_boolean(
'xm_wrap_late_bindings',
False,
(
'Feature flag to wrap and unwrap late bindings for network addresses. '
'ONLY works with PythonContainer with default instructions or simple '
'instructions that do not modify the file directory. '
'REQUIRES ./entrypoint.sh to be the ENTRYPOINT.'
),
)

# TODO: Find a master image that is compatible with every
# combination (TF, Torch, JAX) X (CPU, GPU, TPU).
Expand Down Expand Up @@ -126,7 +93,7 @@ def build(
python_path = py_executable.path

with tempfile.TemporaryDirectory() as wrapped_directory:
if _WRAP_LATE_BINDINGS.value:
if xm_flags.WRAP_LATE_BINDINGS.value:
_wrap_late_bindings(wrapped_directory, python_path, dockerfile)
python_path = wrapped_directory
dockerfile = os.path.join(python_path, 'Dockerfile')
Expand Down Expand Up @@ -167,15 +134,15 @@ def build_by_dockerfile(
The name of the built image.
"""
print('Building Docker image, please wait...')
if _BUILD_IMAGE_LOCALLY.value:
if xm_flags.BUILD_IMAGE_LOCALLY.value:
if docker_lib.is_docker_installed():
# TODO: Improve out-of-disk space handling.
return docker_lib.build_docker_image(
image_name,
path,
dockerfile,
use_docker_command=_USE_DOCKER_COMMAND.value,
show_docker_command_progress=_SHOW_DOCKER_COMMAND_PROGRESS.value,
use_docker_command=xm_flags.USE_DOCKER_COMMAND.value,
show_docker_command_progress=xm_flags.SHOW_DOCKER_COMMAND_PROGRESS.value,
)
print('Falling back to CloudBuild. See INFO log for details.')

Expand Down
38 changes: 5 additions & 33 deletions xmanager/cloud/cloud_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,41 +20,13 @@
from typing import Any, Dict, Optional
import warnings

from absl import flags
from docker.utils import utils as docker_utils
from google.cloud import storage
from googleapiclient import discovery
import termcolor

from xmanager import xm_flags
from xmanager.cloud import auth

_CLOUD_BUILD_TIMEOUT_SECONDS = flags.DEFINE_integer(
'xm_cloud_build_timeout_seconds',
1200,
(
'The amount of time that builds should be allowed to run, '
'to second granularity.'
),
)
_USE_CLOUD_BUILD_CACHE = flags.DEFINE_boolean(
'xm_use_cloud_build_cache',
False,
( # pylint:disable=g-line-too-long
'Use Cloud Build cache to speed up the Docker build. '
'An image with the same name tagged as :latest should exist.'
'More details at'
' https://cloud.google.com/cloud-build/docs/speeding-up-builds#using_a_cached_docker_image'
),
)

_USE_KANIKO = flags.DEFINE_boolean(
'xm_use_kaniko',
False,
'Use kaniko backend for Cloud Build and enable caching.',
)
_KANIKO_CACHE_TTL = flags.DEFINE_string(
'xm_kaniko_cache_ttl', '336h', 'Cache ttl to use for kaniko builds.'
)

_CLOUD_SDK_CREDENTIALS_WARNING = """\
Your application has authenticated using end user credentials from Google \
Expand Down Expand Up @@ -104,16 +76,16 @@ def __init__(
self.bucket = bucket or auth.get_bucket()
self.credentials = credentials or auth.get_creds()
if cloud_build_timeout_seconds is None:
cloud_build_timeout_seconds = _CLOUD_BUILD_TIMEOUT_SECONDS.value
cloud_build_timeout_seconds = xm_flags.CLOUD_BUILD_TIMEOUT_SECONDS.value
self.cloud_build_timeout_seconds = cloud_build_timeout_seconds
if use_cloud_build_cache is None:
use_cloud_build_cache = _USE_CLOUD_BUILD_CACHE.value
use_cloud_build_cache = xm_flags.USE_CLOUD_BUILD_CACHE.value
self.use_cloud_build_cache = use_cloud_build_cache
if use_kaniko is None:
use_kaniko = _USE_KANIKO.value
use_kaniko = xm_flags.USE_KANIKO.value
self.use_kaniko = use_kaniko
if kaniko_cache_ttl is None:
kaniko_cache_ttl = _KANIKO_CACHE_TTL.value
kaniko_cache_ttl = xm_flags.KANIKO_CACHE_TTL.value
self.kaniko_cache_ttl = kaniko_cache_ttl
self.cloudbuild_api = None # discovery CloudBuild v1 client

Expand Down
14 changes: 2 additions & 12 deletions xmanager/cloud/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,18 @@
import functools
from typing import Callable, Dict, List, Optional, Sequence

from absl import flags
import attr
from kubernetes import client as k8s_client
from kubernetes import config as k8s_config
from xmanager import xm
from xmanager import xm_flags
from xmanager.xm import utils
from xmanager.xm_local import executables as local_executables
from xmanager.xm_local import execution as local_execution
from xmanager.xm_local import executors as local_executors
from xmanager.xm_local import status as local_status


_K8S_SERVICE_ACCOUNT_NAME = flags.DEFINE_string(
'xm_k8s_service_account_name',
'default',
(
'Specifies the Kubernetes Service Account name to be used by XManager'
' inthe pod specifications.'
),
)


@functools.lru_cache()
def client():
# Global singleton defers client creation until an actual launch.
Expand Down Expand Up @@ -127,7 +117,7 @@ def launch(
annotations=annotations_from_executor(executor),
),
spec=k8s_client.V1PodSpec(
service_account=_K8S_SERVICE_ACCOUNT_NAME.value,
service_account=xm_flags.K8S_SERVICE_ACCOUNT_NAME.value,
hostname=job_name,
subdomain=service,
restart_policy='Never',
Expand Down
89 changes: 89 additions & 0 deletions xmanager/xm_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,95 @@
import enum
from absl import flags

# --------------------- cloud ---------------------

# Read by xmanager/cloud/build_image.py: when true, the Dockerfile build is
# attempted with local Docker before falling back to Cloud Build.
BUILD_IMAGE_LOCALLY = flags.DEFINE_boolean(
    name='xm_build_image_locally',
    default=True,
    help=(
        'Use local Docker to build images instead of remote Google Cloud '
        'Build. This is usually a lot faster but requires docker to be '
        'installed.'
    ),
)

# Forwarded as `use_docker_command` to docker_lib.build_docker_image by
# xmanager/cloud/build_image.py when building locally.
USE_DOCKER_COMMAND = flags.DEFINE_boolean(
    name='xm_use_docker_command',
    default=True,
    help=(
        # Help text is user-facing (shown by --help); typo "provies" fixed.
        'Call "docker build" in a subprocess rather than using Python docker '
        'client library when building the docker image locally. This '
        'provides a much nicer output for interactive use.'
    ),
)

# Forwarded as `show_docker_command_progress` to
# docker_lib.build_docker_image by xmanager/cloud/build_image.py.
SHOW_DOCKER_COMMAND_PROGRESS = flags.DEFINE_boolean(
    name='xm_show_docker_command_progress',
    default=False,
    help='Show container output during the "docker build".',
)

# Gates the `_wrap_late_bindings` step in xmanager/cloud/build_image.py; when
# set, the image is built from the wrapped temporary directory instead.
WRAP_LATE_BINDINGS = flags.DEFINE_boolean(
    name='xm_wrap_late_bindings',
    default=False,
    help=(
        'Feature flag to wrap and unwrap late bindings for network '
        'addresses. ONLY works with PythonContainer with default '
        'instructions or simple instructions that do not modify the file '
        'directory. REQUIRES ./entrypoint.sh to be the ENTRYPOINT.'
    ),
)

# Fallback used by the constructor in xmanager/cloud/cloud_build.py when its
# `cloud_build_timeout_seconds` argument is None.
CLOUD_BUILD_TIMEOUT_SECONDS = flags.DEFINE_integer(
    name='xm_cloud_build_timeout_seconds',
    default=1200,
    help=(
        'The amount of time that builds should be allowed to run, to second '
        'granularity.'
    ),
)

# Fallback used by the constructor in xmanager/cloud/cloud_build.py when its
# `use_cloud_build_cache` argument is None.
USE_CLOUD_BUILD_CACHE = flags.DEFINE_boolean(
    name='xm_use_cloud_build_cache',
    default=False,
    help=(  # pylint:disable=g-line-too-long
        'Use Cloud Build cache to speed up the Docker build. An image with '
        'the same name tagged as :latest should exist. More details at '
        'https://cloud.google.com/cloud-build/docs/speeding-up-builds#using_a_cached_docker_image'
    ),
)

# Fallback used by the constructor in xmanager/cloud/cloud_build.py when its
# `use_kaniko` argument is None.
USE_KANIKO = flags.DEFINE_boolean(
    name='xm_use_kaniko',
    default=False,
    help='Use kaniko backend for Cloud Build and enable caching.',
)

# Fallback used by the constructor in xmanager/cloud/cloud_build.py when its
# `kaniko_cache_ttl` argument is None.
KANIKO_CACHE_TTL = flags.DEFINE_string(
    name='xm_kaniko_cache_ttl',
    default='336h',
    help='Cache ttl to use for kaniko builds.',
)

# Read by get_service_account() in xmanager/cloud/auth.py, which appends
# `@<project>.iam.gserviceaccount.com` to this value.
GCP_SERVICE_ACCOUNT_NAME = flags.DEFINE_string(
    name='xm_gcp_service_account_name',
    default='xmanager',
    help=(
        'Specifies the user-managed service account name to be used by '
        'XManager. Note that user-managed service accounts have the '
        'following format: '
        '`{service-account-name}@{project-id}.iam.gserviceaccount.com`, so '
        'only the part before @ is required'
    ),
)

# Passed as `service_account` of the V1PodSpec built in
# xmanager/cloud/kubernetes.py.
K8S_SERVICE_ACCOUNT_NAME = flags.DEFINE_string(
    name='xm_k8s_service_account_name',
    default='default',
    help=(
        'Specifies the Kubernetes Service Account name to be used by '
        'XManager in the pod specifications.'
    ),
)

# -------------------- xm_local --------------------

DB_YAML_CONFIG_PATH = flags.DEFINE_string(
Expand Down

0 comments on commit ad5c133

Please sign in to comment.