From 942cf3fd39ca85c4cf4c0705d51d72458735eae7 Mon Sep 17 00:00:00 2001 From: DeepMind Team Date: Tue, 27 Feb 2024 15:18:49 -0800 Subject: [PATCH] Move xm_local flags into xmanager/xm_flags.py. PiperOrigin-RevId: 610891136 Change-Id: Ia93693077ce121eb5ba9b78448deae9ea6c7f4dd GitOrigin-RevId: 5cee75d2d1d7799af02a8e92c6a8f9c5be7df045 --- xm_flags.py | 41 --------------------- xmanager/xm_local/packaging/bazel_tools.py | 12 ++++-- xmanager/xm_local/storage/database.py | 43 ++++++++++++++++++++-- 3 files changed, 47 insertions(+), 49 deletions(-) diff --git a/xm_flags.py b/xm_flags.py index a10904d..fe05d7d 100644 --- a/xm_flags.py +++ b/xm_flags.py @@ -14,44 +14,3 @@ """XManager Flags.""" from absl import flags - -# -------------------- xm_local -------------------- - -DB_YAML_CONFIG_PATH = flags.DEFINE_string( - 'xm_db_yaml_config_path', - None, - """ - Path of YAML config file containing DB connection details. - - A valid config file contains two main entries: - `sql_connector`: must be one of [`sqlite`, `generic`, `cloudsql`] - - `sql_connection_settings`: contains details about the connection URL. - These match the interface of `SqlConnectionSettings` and their - combination must form a valid `sqlalchemy` connection URL. Possible - fields are: - - backend, e.g. 'mysql', 'postgresql' - - db_name - - driver, e.g. 'pymysql', 'pg8000' - - username - - password - - host (instance connection name when using CloudSql) - - port - """, -) - -UPGRADE_DB = flags.DEFINE_boolean( - 'xm_upgrade_db', - False, - """ - Specifies if XManager should update the database to the latest version. - It's recommended to take a back-up of the database before updating, since - migrations can fail/have errors. This is especially true - for non-transactional DDLs, where partial migrations can occur on - failure, leaving the database in a not well-defined state. - """, -) - -BAZEL_COMMAND = flags.DEFINE_string( - 'xm_bazel_command', 'bazel', 'A command that runs Bazel.' -) diff --git a/xmanager/xm_local/packaging/bazel_tools.py b/xmanager/xm_local/packaging/bazel_tools.py index f55f9b3..19adefa 100644 --- a/xmanager/xm_local/packaging/bazel_tools.py +++ b/xmanager/xm_local/packaging/bazel_tools.py @@ -20,14 +20,18 @@ import subprocess from typing import Dict, List, Optional, Sequence, Tuple +from absl import flags from xmanager import xm -from xmanager import xm_flags from xmanager.bazel import client from xmanager.bazel import file_utils from google.protobuf.internal.decoder import _DecodeVarint32 from xmanager.generated import build_event_stream_pb2 as bes_pb2 +_BAZEL_COMMAND = flags.DEFINE_string( + 'xm_bazel_command', 'bazel', 'A command that runs Bazel.' +) + def _get_important_outputs( events: Sequence[bes_pb2.BuildEvent], labels: Sequence[str] @@ -99,7 +103,7 @@ def _root_absolute_path() -> str: return ( os.getenv('BUILD_WORKSPACE_DIRECTORY') or subprocess.run( - [xm_flags.BAZEL_COMMAND.value, 'info', 'workspace'], + [_BAZEL_COMMAND.value, 'info', 'workspace'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, @@ -126,7 +130,7 @@ def _build_multiple_targets( with file_utils.TemporaryFilePath() as bep_path: subprocess.run( [ - xm_flags.BAZEL_COMMAND.value, + _BAZEL_COMMAND.value, 'build', f'--build_event_binary_file={bep_path}', # Forces a GC at the end of the build and publishes value to BEP. @@ -200,7 +204,7 @@ def fetch_kinds(self, labels: Sequence[str]) -> List[str]: # https://docs.bazel.build/versions/main/query.html#output-label_kind. stdout = subprocess.run( [ - xm_flags.BAZEL_COMMAND.value, + _BAZEL_COMMAND.value, 'query', f"'{' union '.join(labels)}'", '--output', diff --git a/xmanager/xm_local/storage/database.py b/xmanager/xm_local/storage/database.py index cdc673d..a1a2d05 100644 --- a/xmanager/xm_local/storage/database.py +++ b/xmanager/xm_local/storage/database.py @@ -18,12 +18,12 @@ import tempfile from typing import Any, Dict, List, Optional, Type, TypeVar +from absl import flags import alembic from alembic.config import Config import attr import sqlalchemy from xmanager import xm -from xmanager import xm_flags from xmanager.generated import data_pb2 import yaml @@ -31,6 +31,41 @@ from google.cloud.sql.connector import Connector, IPTypes +_DB_YAML_CONFIG_PATH = flags.DEFINE_string( + 'xm_db_yaml_config_path', + None, + """ + Path of YAML config file containing DB connection details. + + A valid config file contains two main entries: + `sql_connector`: must be one of [`sqlite`, `generic`, `cloudsql`] + + `sql_connection_settings`: contains details about the connection URL. + These match the interface of `SqlConnectionSettings` and their + combination must form a valid `sqlalchemy` connection URL. Possible + fields are: + - backend, e.g. 'mysql', 'postgresql' + - db_name + - driver, e.g. 'pymysql', 'pg8000' + - username + - password + - host (instance connection name when using CloudSql) + - port + """, +) + +_UPGRADE_DB = flags.DEFINE_boolean( + 'xm_upgrade_db', + False, + """ + Specifies if XManager should update the database to the latest version. + It's recommended to take a back-up of the database before updating, since + migrations can fail/have errors. This is especially true + for non-transactional DDLs, where partial migrations can occur on + failure, leaving the database in a not well-defined state. + """, +) + @attr.s(auto_attribs=True) class WorkUnitResult: @@ -221,7 +256,7 @@ def maybe_migrate_database_version(self): need_to_update = ( db_version != self.latest_version_available() and db_version ) or legacy_sqlite_db - if need_to_update and not xm_flags.UPGRADE_DB.value: + if need_to_update and not _UPGRADE_DB.value: raise RuntimeError( f'Database is not up to date: current={self.database_version()}, ' f'latest={self.latest_version_available()}. Take a backup of the ' @@ -369,9 +404,9 @@ def _validate_db_config(config: Dict[str, Any]) -> None: @functools.lru_cache() def _db_config() -> Dict[str, Any]: """Parses and validates YAML DB config file to a dict.""" - if xm_flags.DB_YAML_CONFIG_PATH.value is not None: + if _DB_YAML_CONFIG_PATH.value is not None: db_config_file = xm.utils.resolve_path_relative_to_launcher( - xm_flags.DB_YAML_CONFIG_PATH.value + _DB_YAML_CONFIG_PATH.value ) with open(db_config_file, 'r') as f: config = yaml.safe_load(f)