From 3901d2234dc8f64cec3db6d72873c81f82193302 Mon Sep 17 00:00:00 2001 From: Kshitij Aranke Date: Tue, 19 Sep 2023 15:50:07 +0100 Subject: [PATCH] use mutex to enforce mutually exclusive options --- core/dbt/cli/flags.py | 25 ++----------------------- core/dbt/cli/main.py | 4 ++-- core/dbt/cli/params.py | 34 +++++++++++++++++++++++++++++++++- 3 files changed, 37 insertions(+), 26 deletions(-) diff --git a/core/dbt/cli/flags.py b/core/dbt/cli/flags.py index 863db6ed0e4..3a42aabd715 100644 --- a/core/dbt/cli/flags.py +++ b/core/dbt/cli/flags.py @@ -8,13 +8,14 @@ from click import Context, get_current_context, Parameter from click.core import Command as ClickCommand, Group, ParameterSource + from dbt.cli.exceptions import DbtUsageException from dbt.cli.resolvers import default_log_path, default_project_dir from dbt.cli.types import Command as CliCommand from dbt.config.profile import read_user_config from dbt.contracts.project import UserConfig -from dbt.exceptions import DbtInternalError from dbt.deprecations import renamed_env_var +from dbt.exceptions import DbtInternalError from dbt.helper_types import WarnErrorOptions if os.name != "nt": @@ -243,11 +244,6 @@ def _assign_params( if os.getenv("DO_NOT_TRACK", "").lower() in ("1", "t", "true", "y", "yes"): object.__setattr__(self, "SEND_ANONYMOUS_USAGE_STATS", False) - # Check mutual exclusivity once all flags are set. - self._assert_mutually_exclusive( - params_assigned_from_default, ["WARN_ERROR", "WARN_ERROR_OPTIONS"] - ) - # Support lower cased access for legacy code. params = set( x for x in dir(self) if not callable(getattr(self, x)) and not x.startswith("__") @@ -263,23 +259,6 @@ def _override_if_set(self, lead: str, follow: str, defaulted: Set[str]) -> None: if lead.lower() not in defaulted and follow.lower() in defaulted: object.__setattr__(self, follow.upper(), getattr(self, lead.upper(), None)) - def _assert_mutually_exclusive( - self, params_assigned_from_default: Set[str], group: List[str] - ) -> None: - """ - Ensure no elements from group are simultaneously provided by a user, as inferred from params_assigned_from_default. - Raises click.UsageError if any two elements from group are simultaneously provided by a user. - """ - set_flag = None - for flag in group: - flag_set_by_user = flag.lower() not in params_assigned_from_default - if flag_set_by_user and set_flag: - raise DbtUsageException( - f"{flag.lower()}: not allowed with argument {set_flag.lower()}" - ) - elif flag_set_by_user: - set_flag = flag - def fire_deprecations(self): """Fires events for deprecated env_var usage.""" [dep_fn() for dep_fn in self.deprecated_env_var_warnings] diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index ab501e015f4..0346c5a34a8 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -149,6 +149,8 @@ def global_flags(func): @p.use_experimental_parser @p.version @p.version_check + @p.warn_error + @p.warn_error_options @p.write_json @functools.wraps(func) def wrapper(*args, **kwargs): @@ -166,8 +168,6 @@ def wrapper(*args, **kwargs): ) @click.pass_context @global_flags -@p.warn_error -@p.warn_error_options @p.log_format def cli(ctx, **kwargs): """An ELT tool for managing your SQL transformations and data models. diff --git a/core/dbt/cli/params.py b/core/dbt/cli/params.py index b8231058531..78140bb9c42 100644 --- a/core/dbt/cli/params.py +++ b/core/dbt/cli/params.py @@ -1,11 +1,39 @@ from pathlib import Path import click -from dbt.cli.options import MultiOption + +from dbt.cli.exceptions import DbtUsageException from dbt.cli.option_types import YAML, ChoiceTuple, WarnErrorOptionsType +from dbt.cli.options import MultiOption from dbt.cli.resolvers import default_project_dir, default_profiles_dir from dbt.version import get_version_information + +# Copied from https://github.com/pallets/click/issues/257#issuecomment-403312784 +class Mutex(click.Option): + def __init__(self, *args, **kwargs): + self.not_required_if: list = kwargs.pop("not_required_if") + + assert self.not_required_if, "'not_required_if' parameter required" + kwargs["help"] = ( + kwargs.get("help", "") + + "Option is mutually exclusive with " + + ", ".join(self.not_required_if) + + "." + ).strip() + super(Mutex, self).__init__(*args, **kwargs) + + def handle_parse_result(self, ctx, opts, args): + current_opt: bool = self.name in opts + for mutex_opt in self.not_required_if: + if mutex_opt in opts: + if current_opt: + raise DbtUsageException(f"{self.name}: not allowed with argument {mutex_opt}") + else: + self.prompt = None + return super(Mutex, self).handle_parse_result(ctx, opts, args) + + args = click.option( "--args", envvar=None, @@ -589,6 +617,8 @@ def _version_callback(ctx, _param, value): help="If dbt would normally warn, instead raise an exception. Examples include --select that selects nothing, deprecations, configurations with no associated models, invalid test configurations, and missing sources/refs in tests.", default=None, is_flag=True, + cls=Mutex, + not_required_if=["warn_error_options"], ) warn_error_options = click.option( @@ -598,6 +628,8 @@ def _version_callback(ctx, _param, value): help="""If dbt would normally warn, instead raise an exception based on include/exclude configuration. Examples include --select that selects nothing, deprecations, configurations with no associated models, invalid test configurations, and missing sources/refs in tests. This argument should be a YAML string, with keys 'include' or 'exclude'. eg. '{"include": "all", "exclude": ["NoNodesForSelectionCriteria"]}'""", type=WarnErrorOptionsType(), + cls=Mutex, + not_required_if=["warn_error"], ) write_json = click.option(