From 74b1a38de317b6da9e598f60962c96c64a7dd817 Mon Sep 17 00:00:00 2001 From: Adam Stus Date: Mon, 4 Mar 2024 13:37:45 +0100 Subject: [PATCH 01/16] Revert "Bump tomlkit from 0.12.3 to 0.12.4 (#848)" (#863) This reverts commit 80467b3d2b50c7996cfd068ed510476e64a714e0. mraba/app-factory: create app with factory (#860) * mraba/app-factory: create app with factory SNOW-1043081: Adding support for qualified names for image repositories. (#823) * SNOW-1043081: Adding support for qualified image repository names * SNOW-1043081: Fixing test imports * SNOW-1043081: Adding tests for getting image repository url without db or schema SNOW-1011771: Added generic REPLACE, IF EXISTS, IF NOT EXISTS flags (#826) * SNOW-1011771: Adding generic OR REPLACE, IF EXISTS, IF NOT EXISTS flags to flags.py * SNOW-1011771: Using generic ReplaceOption in snowpark deploy and streamlit deploy * SNOW-1011771: Using generic IfNotExistsOption in compute pool create and updating unit tests. * SNOW-1011771: Using generic IfNotExistsOption in service create and updating unit tests * SNOW-1011771: Using generic ReplaceOption and IfNotExistsOption in image-repository create. * SNOW-1011771: Fixup * SNOW-1011771: Update release notes * SNOW-1011771: Update test_help_messages * SNOW-1011771: precommit * SNOW-1011771: Adding validation that only one create mode option can be set at once * fixup * SNOW-1011771: Updating tests for REPLACE AND IF NOT EXISTS case on image-repository create to throw error * SNOW-1011771: Adding snapshots * SNOW-1011771: Adding a new mutually_exclusive field to OverrideableOption * formatting * SNOW-1011771: Adding tests for OverrideableOption * SNOW-1011771: Fixing test failures due to improperly quoted string Add snow --help to test_help_messages (#821) * Add snow --help to test_help_messages * update snapshot Avoid plain print, make sure silent is eager flag (#871) [NADE] Update CODEOWNERS to use NADE team id. 
(#873) update to using nade team in codeowners New workflow to stop running workflows after new commit (#862) * new workflow * new workflow * new workflow * new workflow * typo fix * typo fix * import fix * import fix * import fix * import fix * import fix * import fix * import fix * new approach * new approach * new approach * new approach * new approach * New approach * added to test * Added to more workflows * Dummy commit Schemas adjusting native apps to streamlit fixing streamlit fixies after unit tests fixies after unit tests fixing for snowflake fixing for snowflake Fixes after review Fixes after review Fixes after review --- .github/CODEOWNERS | 10 +- .github/workflows/build.yaml | 4 + .github/workflows/e2e_test.yaml | 4 + .github/workflows/integration_test.yaml | 4 + .github/workflows/performance_test.yaml | 4 + .pre-commit-config.yaml | 29 +++ RELEASE-NOTES.md | 3 + pyproject.toml | 3 +- src/snowflake/cli/api/commands/flags.py | 161 ++++++++++++--- src/snowflake/cli/api/commands/snow_typer.py | 8 +- src/snowflake/cli/api/project/definition.py | 44 ++-- .../cli/api/project/definition_manager.py | 3 +- .../cli/api/project/schemas/native_app.py | 43 ---- .../project/schemas/native_app/__init__.py | 0 .../project/schemas/native_app/application.py | 31 +++ .../project/schemas/native_app/native_app.py | 39 ++++ .../api/project/schemas/native_app/package.py | 40 ++++ .../project/schemas/native_app/path_maping.py | 10 + .../api/project/schemas/project_definition.py | 40 ++-- .../cli/api/project/schemas/relaxed_map.py | 44 ---- .../cli/api/project/schemas/snowpark.py | 47 ----- .../api/project/schemas/snowpark/__init__.py | 0 .../api/project/schemas/snowpark/argument.py | 12 ++ .../api/project/schemas/snowpark/callable.py | 66 ++++++ .../api/project/schemas/snowpark/snowpark.py | 22 ++ .../cli/api/project/schemas/streamlit.py | 20 -- .../api/project/schemas/streamlit/__init__.py | 0 .../project/schemas/streamlit/streamlit.py | 28 +++ 
.../api/project/schemas/updatable_model.py | 20 ++ src/snowflake/cli/api/project/util.py | 1 + src/snowflake/cli/api/secure_path.py | 2 +- src/snowflake/cli/api/sql_execution.py | 101 +++++---- src/snowflake/cli/api/utils/naming_utils.py | 27 +++ src/snowflake/cli/app/__main__.py | 3 +- src/snowflake/cli/app/cli_app.py | 190 ++++++++--------- src/snowflake/cli/app/loggers.py | 2 +- src/snowflake/cli/app/main_typer.py | 4 +- .../cli/plugins/connection/commands.py | 4 +- .../cli/plugins/nativeapp/artifacts.py | 5 +- .../cli/plugins/nativeapp/manager.py | 74 ++++--- .../cli/plugins/nativeapp/run_processor.py | 5 +- .../nativeapp/version/version_processor.py | 3 +- src/snowflake/cli/plugins/object/common.py | 11 +- .../cli/plugins/snowpark/commands.py | 76 +++---- src/snowflake/cli/plugins/snowpark/common.py | 23 ++- src/snowflake/cli/plugins/spcs/common.py | 6 +- .../cli/plugins/spcs/compute_pool/commands.py | 6 +- .../cli/plugins/spcs/compute_pool/manager.py | 6 +- .../plugins/spcs/image_repository/commands.py | 14 +- .../plugins/spcs/image_repository/manager.py | 38 ++-- .../cli/plugins/spcs/services/commands.py | 6 +- .../cli/plugins/spcs/services/manager.py | 7 +- .../cli/plugins/streamlit/commands.py | 26 +-- tests/__snapshots__/test_help_messages.ambr | 44 +++- tests/__snapshots__/test_snow_connector.ambr | 4 +- .../commands/__snapshots__/test_flags.ambr | 28 +++ tests/api/commands/test_flags.py | 192 +++++++++++++++++- tests/api/commands/test_snow_typer.py | 6 +- tests/api/utils/test_naming_utils.py | 15 ++ tests/conftest.py | 5 +- tests/nativeapp/test_artifacts.py | 6 +- tests/nativeapp/test_manager.py | 2 +- tests/nativeapp/test_package_scripts.py | 2 +- tests/nativeapp/test_run_processor.py | 4 +- tests/nativeapp/test_teardown_processor.py | 2 +- .../test_version_create_processor.py | 2 +- .../nativeapp/test_version_drop_processor.py | 2 +- tests/project/test_config.py | 52 ++--- tests/project/test_pydantic_schemas.py | 0 tests/snowpark/test_function.py | 10 
+- tests/snowpark/test_procedure.py | 4 +- .../__snapshots__/test_image_repository.ambr | 19 ++ tests/spcs/test_compute_pool.py | 42 +++- tests/spcs/test_image_repository.py | 97 +++++---- tests/spcs/test_jobs.py | 5 + tests/spcs/test_services.py | 44 +++- tests/streamlit/test_config.py | 2 +- tests/test_help_messages.py | 1 + tests/test_sql.py | 79 +++++++ tests/testing_utils/fixtures.py | 8 +- tests_integration/conftest.py | 5 +- tests_integration/test_external_plugins.py | 42 ++-- 82 files changed, 1468 insertions(+), 635 deletions(-) delete mode 100644 src/snowflake/cli/api/project/schemas/native_app.py create mode 100644 src/snowflake/cli/api/project/schemas/native_app/__init__.py create mode 100644 src/snowflake/cli/api/project/schemas/native_app/application.py create mode 100644 src/snowflake/cli/api/project/schemas/native_app/native_app.py create mode 100644 src/snowflake/cli/api/project/schemas/native_app/package.py create mode 100644 src/snowflake/cli/api/project/schemas/native_app/path_maping.py delete mode 100644 src/snowflake/cli/api/project/schemas/relaxed_map.py delete mode 100644 src/snowflake/cli/api/project/schemas/snowpark.py create mode 100644 src/snowflake/cli/api/project/schemas/snowpark/__init__.py create mode 100644 src/snowflake/cli/api/project/schemas/snowpark/argument.py create mode 100644 src/snowflake/cli/api/project/schemas/snowpark/callable.py create mode 100644 src/snowflake/cli/api/project/schemas/snowpark/snowpark.py delete mode 100644 src/snowflake/cli/api/project/schemas/streamlit.py create mode 100644 src/snowflake/cli/api/project/schemas/streamlit/__init__.py create mode 100644 src/snowflake/cli/api/project/schemas/streamlit/streamlit.py create mode 100644 src/snowflake/cli/api/project/schemas/updatable_model.py create mode 100644 src/snowflake/cli/api/utils/naming_utils.py create mode 100644 tests/api/commands/__snapshots__/test_flags.ambr create mode 100644 tests/api/utils/test_naming_utils.py create mode 100644 
tests/project/test_pydantic_schemas.py create mode 100644 tests/spcs/__snapshots__/test_image_repository.ambr diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index f383ceefa2..78117f4c78 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,10 +1,10 @@ * @sfc-gh-turbaszek @sfc-gh-pjob @sfc-gh-jsikorski @sfc-gh-astus @sfc-gh-mraba @sfc-gh-pczajka # Native Apps Owners -src/snowflake/cli/plugins/nativeapp/ @sfc-gh-bgoel @sfc-gh-cgorrie @sfc-gh-bdufour @sfc-gh-melnacouzi -tests/nativeapp/ @sfc-gh-bgoel @sfc-gh-cgorrie @sfc-gh-bdufour @sfc-gh-melnacouzi -tests_integration/test_nativeapp.py @sfc-gh-bgoel @sfc-gh-cgorrie @sfc-gh-bdufour @sfc-gh-melnacouzi +src/snowflake/cli/plugins/nativeapp/ @snowflakedb/nade +tests/nativeapp/ @sfc-gh-bgoel @snowflakedb/nade +tests_integration/test_nativeapp.py @snowflakedb/nade # Project Definition Owners -src/snowflake/cli/api/project/schemas/native_app.py @sfc-gh-bgoel @sfc-gh-cgorrie @sfc-gh-bdufour @sfc-gh-melnacouzi -tests/project/ @sfc-gh-turbaszek @sfc-gh-pjob @sfc-gh-jsikorski @sfc-gh-astus @sfc-gh-mraba @sfc-gh-pczajka @sfc-gh-bgoel @sfc-gh-cgorrie @sfc-gh-bdufour @sfc-gh-melnacouzi +src/snowflake/cli/api/project/schemas/native_app.py @snowflakedb/nade +tests/project/ @sfc-gh-turbaszek @sfc-gh-pjob @sfc-gh-jsikorski @sfc-gh-astus @sfc-gh-mraba @sfc-gh-pczajka @snowflakedb/nade diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 69448b37ab..5a02787584 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -10,6 +10,10 @@ on: branches: - main +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: lint: diff --git a/.github/workflows/e2e_test.yaml b/.github/workflows/e2e_test.yaml index 0d86a21010..63fe83b63d 100644 --- a/.github/workflows/e2e_test.yaml +++ b/.github/workflows/e2e_test.yaml @@ -10,6 +10,10 @@ on: branches: - main +concurrency: + group: ${{ github.workflow }}-${{ 
github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: test: runs-on: ubuntu-latest diff --git a/.github/workflows/integration_test.yaml b/.github/workflows/integration_test.yaml index 346e5e87a0..06a697e3d3 100644 --- a/.github/workflows/integration_test.yaml +++ b/.github/workflows/integration_test.yaml @@ -10,6 +10,10 @@ on: branches: - main +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: test: runs-on: ${{ matrix.os }} diff --git a/.github/workflows/performance_test.yaml b/.github/workflows/performance_test.yaml index dfe48fdeb9..5b7e85f446 100644 --- a/.github/workflows/performance_test.yaml +++ b/.github/workflows/performance_test.yaml @@ -8,6 +8,10 @@ on: branches: - main +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: test: runs-on: ubuntu-latest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6fe4066565..eda7e65d96 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,3 +25,32 @@ repos: hooks: - id: mypy additional_dependencies: [types-all] + - repo: local + hooks: + - id: check-print-in-code + language: pygrep + name: "Check for print statements" + entry: "print\\(|echo\\(" + pass_filenames: true + files: ^src/snowflake/.*\.py$ + exclude: > + (?x) + ^src/snowflake/cli/api/console/.*$| + ^src/snowflake/cli/app/printing.py$| + ^src/snowflake/cli/app/dev/.*$| + ^src/snowflake/cli/templates/.*$| + ^src/snowflake/cli/api/utils/rendering.py$| + ^src/snowflake/cli/plugins/spcs/common.py$ + - id: check-app-imports-in-api + language: pygrep + name: "No top level cli.app imports in cli.api" + entry: "^from snowflake\\.cli\\.app" + pass_filenames: true + files: ^src/snowflake/cli/api/.*\.py$ + - id: avoid-snowcli + language: pygrep + name: "Prefer snowflake CLI over snowcli" + entry: "snowcli" + pass_filenames: true + files: 
^src/.*\.py$ + exclude: ^src/snowflake/cli/app/telemetry.py$ diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index f2119d22fa..dc9d351f31 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -4,10 +4,13 @@ ## New additions * Added support for fully qualified name (`database.schema.name`) in `name` parameter in streamlit project definition +* Added support for fully qualified image repository names in `spcs image-repository` commands. +* Added `--if-not-exists` option to `create` commands for `service`, and `compute-pool`. Added `--replace` and `--if-not-exists` options for `image-repository create`. ## Fixes and improvements * Adding `--image-name` option for image name argument in `spcs image-repository list-tags` for consistency with other commands. * Fixed errors during `spcs image-registry login` not being formatted correctly. +* Project definition no longer accept extra fields. Any extra field will cause an error. # v2.1.0 diff --git a/pyproject.toml b/pyproject.toml index b90d636551..00af57691e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,10 +21,11 @@ dependencies = [ "setuptools==69.1.1", "snowflake-connector-python[secure-local-storage]==3.7.1", "strictyaml==1.7.3", - "tomlkit==0.12.4", + "tomlkit==0.12.3", "typer==0.9.0", "urllib3>=1.21.1,<2.3", "GitPython==3.1.42", + "pydantic==2.6.3" ] classifiers = [ "Development Status :: 3 - Alpha", diff --git a/src/snowflake/cli/api/commands/flags.py b/src/snowflake/cli/api/commands/flags.py index df241f3e69..693c613bbd 100644 --- a/src/snowflake/cli/api/commands/flags.py +++ b/src/snowflake/cli/api/commands/flags.py @@ -1,10 +1,13 @@ from __future__ import annotations -from typing import Any, Callable, Optional +from inspect import signature +from typing import Any, Callable, List, Optional, Tuple import click import typer +from click import ClickException from snowflake.cli.api.cli_global_context import cli_context_manager +from snowflake.cli.api.console import cli_console from 
snowflake.cli.api.output.formats import OutputFormat DEFAULT_CONTEXT_SETTINGS = {"help_option_names": ["--help", "-h"]} @@ -13,6 +16,103 @@ _CLI_BEHAVIOUR = "Global configuration" +class OverrideableOption: + """ + Class that allows you to generate instances of typer.models.OptionInfo with some default properties while allowing + specific values to be overriden. + + Custom parameters: + - mutually_exclusive (Tuple[str]|List[str]): A list of parameter names that this Option is not compatible with. If this Option has + a truthy value and any of the other parameters in the mutually_exclusive list has a truthy value, a + ClickException will be thrown. Note that mutually_exclusive can contain an option's own name but does not require + it. + """ + + def __init__( + self, + default: Any, + *param_decls: str, + mutually_exclusive: Optional[List[str] | Tuple[str]] = None, + **kwargs, + ): + self.default = default + self.param_decls = param_decls + self.mutually_exclusive = mutually_exclusive + self.kwargs = kwargs + + def __call__(self, **kwargs) -> typer.models.OptionInfo: + """ + Returns a typer.models.OptionInfo instance initialized with the specified default values along with any overrides + from kwargs. Note that if you are overriding param_decls, you must pass an iterable of strings, you cannot use + positional arguments like you can with typer.Option. Does not modify the original instance. 
+ """ + default = kwargs.get("default", self.default) + param_decls = kwargs.get("param_decls", self.param_decls) + mutually_exclusive = kwargs.get("mutually_exclusive", self.mutually_exclusive) + if not isinstance(param_decls, list) and not isinstance(param_decls, tuple): + raise TypeError("param_decls must be a list or tuple") + passed_kwargs = self.kwargs.copy() + passed_kwargs.update(kwargs) + if passed_kwargs.get("callback", None) or mutually_exclusive: + passed_kwargs["callback"] = self._callback_factory( + passed_kwargs.get("callback", None), mutually_exclusive + ) + for non_kwarg in ["default", "param_decls", "mutually_exclusive"]: + passed_kwargs.pop(non_kwarg, None) + return typer.Option(default, *param_decls, **passed_kwargs) + + class InvalidCallbackSignature(ClickException): + def __init__(self, callback): + super().__init__( + f"Signature {signature(callback)} is not valid for an OverrideableOption callback function. Must have at most one parameter with each of the following types: (typer.Context, typer.CallbackParam, Any Other Type)" + ) + + def _callback_factory( + self, callback, mutually_exclusive: Optional[List[str] | Tuple[str]] + ): + callback = callback if callback else lambda x: x + + # inspect existing_callback to make sure signature is valid + existing_params = signature(callback).parameters + # at most one parameter with each type in [typer.Context, typer.CallbackParam, any other type] + limits = [ + lambda x: x == typer.Context, + lambda x: x == typer.CallbackParam, + lambda x: x != typer.Context and x != typer.CallbackParam, + ] + for limit in limits: + if len([v for v in existing_params.values() if limit(v.annotation)]) > 1: + raise self.InvalidCallbackSignature(callback) + + def generated_callback(ctx: typer.Context, param: typer.CallbackParam, value): + if mutually_exclusive: + for name in mutually_exclusive: + if value and ctx.params.get( + name, False + ): # if the current parameter is set to True and a previous parameter is also 
Truthy + curr_opt = param.opts[0] + other_opt = [x for x in ctx.command.params if x.name == name][ + 0 + ].opts[0] + raise click.ClickException( + f"Options '{curr_opt}' and '{other_opt}' are incompatible." + ) + + # pass args to existing callback based on its signature (this is how Typer infers callback args) + passed_params = {} + for existing_param in existing_params: + annotation = existing_params[existing_param].annotation + if annotation == typer.Context: + passed_params[existing_param] = ctx + elif annotation == typer.CallbackParam: + passed_params[existing_param] = param + else: + passed_params[existing_param] = value + return callback(**passed_params) + + return generated_callback + + def _callback(provide_setter: Callable[[], Callable[[Any], Any]]): def callback(value): set_value = provide_setter() @@ -73,7 +173,7 @@ def callback(value): def _password_callback(value: str): if value: - click.echo(PLAIN_PASSWORD_MSG) + cli_console.message(PLAIN_PASSWORD_MSG) return _callback(lambda: cli_context_manager.connection_context.set_password)(value) @@ -181,6 +281,7 @@ def _password_callback(value: str): callback=_callback(lambda: cli_context_manager.set_silent), is_flag=True, rich_help_panel=_CLI_BEHAVIOUR, + is_eager=True, ) VerboseOption = typer.Option( @@ -209,6 +310,32 @@ def _password_callback(value: str): help='Regular expression for filtering objects by name. For example, `list --like "my%"` lists all objects that begin with “my”.', ) +# If IfExistsOption, IfNotExistsOption, or ReplaceOption are used with names other than those in CREATE_MODE_OPTION_NAMES, +# you must also override mutually_exclusive if you want to retain the validation that at most one of these flags is +# passed. 
+CREATE_MODE_OPTION_NAMES = ["if_exists", "if_not_exists", "replace"] + +IfExistsOption = OverrideableOption( + False, + "--if-exists", + help="Only apply this operation if the specified object exists.", + mutually_exclusive=CREATE_MODE_OPTION_NAMES, +) + +IfNotExistsOption = OverrideableOption( + False, + "--if-not-exists", + help="Only apply this operation if the specified object does not already exist.", + mutually_exclusive=CREATE_MODE_OPTION_NAMES, +) + +ReplaceOption = OverrideableOption( + False, + "--replace", + help="Replace this object if it already exists.", + mutually_exclusive=CREATE_MODE_OPTION_NAMES, +) + def experimental_option( experimental_behaviour_description: Optional[str] = None, @@ -247,7 +374,7 @@ def project_definition_option(project_name: str): def _callback(project_path: Optional[str]): dm = DefinitionManager(project_path) - project_definition = dm.project_definition.get(project_name) + project_definition = getattr(dm.project_definition, project_name, None) project_root = dm.project_root if not project_definition: @@ -268,31 +395,3 @@ def _callback(project_path: Optional[str]): callback=_callback, show_default=False, ) - - -class OverrideableOption: - """ - Class that allows you to generate instances of typer.models.OptionInfo with some default properties while allowing specific values to be overriden. - """ - - def __init__(self, default: Any, *param_decls: str, **kwargs): - self.default = default - self.param_decls = param_decls - self.kwargs = kwargs - - def __call__(self, **kwargs) -> typer.models.OptionInfo: - """ - Returns a typer.models.OptionInfo instance initialized with the specified default values along with any overrides - from kwargs.Note that if you are overriding param_decls, - you must pass an iterable of strings, you cannot use positional arguments like you can with typer.Option. - Does not modify the original instance. 
- """ - default = kwargs.get("default", self.default) - param_decls = kwargs.get("param_decls", self.param_decls) - if not isinstance(param_decls, list) and not isinstance(param_decls, tuple): - raise TypeError("param_decls must be a list or tuple") - passed_kwargs = self.kwargs.copy() - passed_kwargs.update(kwargs) - passed_kwargs.pop("default", None) - passed_kwargs.pop("param_decls", None) - return typer.Option(default, *param_decls, **passed_kwargs) diff --git a/src/snowflake/cli/api/commands/snow_typer.py b/src/snowflake/cli/api/commands/snow_typer.py index 1cfc5596ad..f0cf14afb1 100644 --- a/src/snowflake/cli/api/commands/snow_typer.py +++ b/src/snowflake/cli/api/commands/snow_typer.py @@ -12,8 +12,6 @@ from snowflake.cli.api.commands.flags import DEFAULT_CONTEXT_SETTINGS from snowflake.cli.api.exceptions import CommandReturnTypeError from snowflake.cli.api.output.types import CommandResult -from snowflake.cli.app.printing import print_result -from snowflake.cli.app.telemetry import flush_telemetry, log_command_usage log = logging.getLogger(__name__) @@ -73,12 +71,16 @@ def pre_execute(): Pay attention to make this method safe to use if performed operations are not necessary for executing the command in proper way. """ + from snowflake.cli.app.telemetry import log_command_usage + log.debug("Executing command pre execution callback") log_command_usage() @staticmethod def process_result(result): """Command result processor""" + from snowflake.cli.app.printing import print_result + # Because we still have commands like "logs" that do not return anything. # We should improve it in future. if not result: @@ -100,5 +102,7 @@ def post_execute(): Callback executed after running any command callable. Pay attention to make this method safe to use if performed operations are not necessary for executing the command in proper way. 
""" + from snowflake.cli.app.telemetry import flush_telemetry + log.debug("Executing command post execution callback") flush_telemetry() diff --git a/src/snowflake/cli/api/project/definition.py b/src/snowflake/cli/api/project/definition.py index ae5477b6f9..f6b5373d33 100644 --- a/src/snowflake/cli/api/project/definition.py +++ b/src/snowflake/cli/api/project/definition.py @@ -1,12 +1,10 @@ from pathlib import Path -from typing import Dict, List, Union +from typing import Dict, List +import yaml.loader from snowflake.cli.api.cli_global_context import cli_context from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB -from snowflake.cli.api.project.schemas.project_definition import ( - project_override_schema, - project_schema, -) +from snowflake.cli.api.project.schemas.project_definition import ProjectDefinition from snowflake.cli.api.project.util import ( append_to_identifier, clean_identifier, @@ -14,32 +12,25 @@ to_identifier, ) from snowflake.cli.api.secure_path import SecurePath -from strictyaml import ( - YAML, - as_document, - load, -) +from yaml import load DEFAULT_USERNAME = "unknown_user" -def merge_left(target: Union[Dict, YAML], source: Union[Dict, YAML]) -> None: +def merge_left(target: Dict, source: Dict) -> None: """ Recursively merges key/value pairs from source into target. Modifies the original dict-like "target". """ for k, v in source.items(): - if k in target and ( - isinstance(v, dict) or (isinstance(v, YAML) and v.is_mapping()) - ): + if k in target and isinstance(target[k], dict): # assumption: all inputs have been validated. - assert isinstance(target[k], dict) or isinstance(target[k], YAML) merge_left(target[k], v) else: target[k] = v -def load_project_definition(paths: List[Path]) -> dict: +def load_project_definition(paths: List[Path]) -> ProjectDefinition: """ Loads project definition, optionally overriding values. Definition values are merged in left-to-right order (increasing precedence). 
@@ -49,22 +40,23 @@ def load_project_definition(paths: List[Path]) -> dict: raise ValueError("Need at least one definition file.") with spaths[0].open("r", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB) as base_yml: - definition = load(base_yml.read(), project_schema) + definition = load(base_yml.read(), Loader=yaml.loader.BaseLoader) for override_path in spaths[1:]: with override_path.open( "r", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB ) as override_yml: - overrides = load(override_yml.read(), project_override_schema) + overrides = load(override_yml.read(), Loader=yaml.loader.BaseLoader) merge_left(definition, overrides) # TODO: how to show good error messages here? - definition.revalidate(project_schema) - return definition.data + return ProjectDefinition(**definition) -def generate_local_override_yml(project: Union[Dict, YAML]) -> YAML: +def generate_local_override_yml( + project: ProjectDefinition, +) -> ProjectDefinition: """ Generates defaults for optional keys in the same YAML structure as the project schema. The returned YAML object can be saved directly to a file, if desired. 
@@ -76,8 +68,8 @@ def generate_local_override_yml(project: Union[Dict, YAML]) -> YAML: warehouse = conn.warehouse local: dict = {} - if "native_app" in project: - name = clean_identifier(project["native_app"]["name"]) + if project.native_app: + name = clean_identifier(project.native_app.name) app_identifier = to_identifier(name) user_app_identifier = append_to_identifier(app_identifier, f"_{user}") package_identifier = append_to_identifier(app_identifier, f"_pkg_{user}") @@ -90,8 +82,12 @@ def generate_local_override_yml(project: Union[Dict, YAML]) -> YAML: }, "package": {"name": package_identifier, "role": role}, } + # TODO: this is an ugly workaround, because pydantics BaseModel.model_copy(update=) doesn't work properly + # After fixing UpdatableModel.update_from_dict it should be used here + target_definition = project.model_dump() + merge_left(target_definition, local) - return as_document(local, project_override_schema) + return ProjectDefinition(**target_definition) def default_app_package(project_name: str): diff --git a/src/snowflake/cli/api/project/definition_manager.py b/src/snowflake/cli/api/project/definition_manager.py index 663e581563..dd90606163 100644 --- a/src/snowflake/cli/api/project/definition_manager.py +++ b/src/snowflake/cli/api/project/definition_manager.py @@ -7,6 +7,7 @@ from snowflake.cli.api.exceptions import MissingConfiguration from snowflake.cli.api.project.definition import load_project_definition +from snowflake.cli.api.project.schemas.project_definition import ProjectDefinition def _compat_is_mount(path: Path): @@ -100,5 +101,5 @@ def _user_definition_file_if_available(project_path: Path) -> Optional[Path]: ) @functools.cached_property - def project_definition(self) -> dict: + def project_definition(self) -> ProjectDefinition: return load_project_definition(self._project_config_paths) diff --git a/src/snowflake/cli/api/project/schemas/native_app.py b/src/snowflake/cli/api/project/schemas/native_app.py deleted file mode 100644 index 
0b6f1048c9..0000000000 --- a/src/snowflake/cli/api/project/schemas/native_app.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - -from snowflake.cli.api.project.schemas.relaxed_map import FilePath, Glob, RelaxedMap -from snowflake.cli.api.project.util import ( - IDENTIFIER, - SCHEMA_AND_NAME, -) -from strictyaml import Bool, Enum, Optional, Regex, Seq, Str, UniqueSeq - -PathMapping = RelaxedMap( - { - "src": Glob() | Seq(Glob()), - Optional("dest"): FilePath(), - } -) - -native_app_schema = RelaxedMap( - { - "name": Str(), - "artifacts": Seq(FilePath() | PathMapping), - Optional("deploy_root", default="output/deploy/"): FilePath(), - Optional("source_stage", default="app_src.stage"): Regex(SCHEMA_AND_NAME), - Optional("package"): RelaxedMap( - { - Optional("scripts", default=None): UniqueSeq(FilePath()), - Optional("role"): Regex(IDENTIFIER), - Optional("name"): Regex(IDENTIFIER), - Optional("warehouse"): Regex(IDENTIFIER), - Optional("distribution", default="internal"): Enum( - ["internal", "external", "INTERNAL", "EXTERNAL"] - ), - } - ), - Optional("application"): RelaxedMap( - { - Optional("role"): Regex(IDENTIFIER), - Optional("name"): Regex(IDENTIFIER), - Optional("warehouse"): Regex(IDENTIFIER), - Optional("debug", default=True): Bool(), - } - ), - } -) diff --git a/src/snowflake/cli/api/project/schemas/native_app/__init__.py b/src/snowflake/cli/api/project/schemas/native_app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/snowflake/cli/api/project/schemas/native_app/application.py b/src/snowflake/cli/api/project/schemas/native_app/application.py new file mode 100644 index 0000000000..623815dc61 --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/native_app/application.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import Literal, Optional + +from pydantic import Field +from snowflake.cli.api.project.schemas.updatable_model import ( + IdentifierField, + UpdatableModel, +) + + 
+class Application(UpdatableModel): + role: Optional[str] = Field( + title="Role to use when creating the application instance and consumer-side objects", + default=None, + ) + name: Optional[str] = Field( + title="Name of the application created when you run the snow app run command", + default=None, + ) + warehouse: Optional[str] = IdentifierField( + title="Name of the application created when you run the snow app run command", + default=None, + ) + debug: Optional[bool] = Field( + title="Whether to enable debug mode when using a named stage to create an application", + default=True, + ) + + +DistributionOptions = Literal["internal", "external", "INTERNAL", "EXTERNAL"] diff --git a/src/snowflake/cli/api/project/schemas/native_app/native_app.py b/src/snowflake/cli/api/project/schemas/native_app/native_app.py new file mode 100644 index 0000000000..eea465d3d0 --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/native_app/native_app.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import re +from typing import List, Optional, Union + +from pydantic import Field, field_validator +from snowflake.cli.api.project.schemas.native_app.application import Application +from snowflake.cli.api.project.schemas.native_app.package import Package +from snowflake.cli.api.project.schemas.native_app.path_maping import PathMapping +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel +from snowflake.cli.api.project.util import ( + SCHEMA_AND_NAME, +) + + +class NativeApp(UpdatableModel): + name: str = Field( + title="Project identifier", + ) + artifacts: List[Union[PathMapping, str]] = Field( + title="List of file source and destination pairs to add to the deploy root", + ) + deploy_root: Optional[str] = Field( + title="Folder at the root of your project where the build step copies the artifacts.", + default="output/deploy/", + ) + source_stage: Optional[str] = Field( + title="Identifier of the stage that stores the application artifacts.", + 
default="app_src.stage", + ) + package: Optional[Package] = Field(title="PackageSchema", default=None) + application: Optional[Application] = Field(title="Application info", default=None) + + @field_validator("source_stage") + @classmethod + def validate_source_stage(cls, input_value: str): + if not re.match(SCHEMA_AND_NAME, input_value): + raise ValueError("Incorrect value for Native Apps source stage value") + return input_value diff --git a/src/snowflake/cli/api/project/schemas/native_app/package.py b/src/snowflake/cli/api/project/schemas/native_app/package.py new file mode 100644 index 0000000000..3934209a55 --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/native_app/package.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field, field_validator +from snowflake.cli.api.project.schemas.native_app.application import DistributionOptions +from snowflake.cli.api.project.schemas.updatable_model import ( + IdentifierField, + UpdatableModel, +) + + +class Package(UpdatableModel): + scripts: Optional[List[str]] = Field( + title="List of SQL file paths relative to the project root", default=None + ) + role: Optional[str] = IdentifierField( + title="Role to use when creating the application package and provider-side objects", + default=None, + ) + name: Optional[str] = IdentifierField( + title="Name of the application created when you run the snow app run command", # TODO: this description seems duplicated, is it ok? 
+ default=None, + ) + warehouse: Optional[str] = IdentifierField( + title="Warehouse used to run the scripts", default=None + ) + distribution: Optional[DistributionOptions] = Field( + title="Distribution of the application package created by the Snowflake CLI", + default="internal", + ) + + @field_validator("scripts") + @classmethod + def validate_scripts(cls, input_list): + if len(input_list) != len(set(input_list)): + raise ValueError( + "Scripts field should contain unique values. Check the list for duplicates and try again" + ) + return input_list diff --git a/src/snowflake/cli/api/project/schemas/native_app/path_maping.py b/src/snowflake/cli/api/project/schemas/native_app/path_maping.py new file mode 100644 index 0000000000..61d2520afa --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/native_app/path_maping.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from typing import Optional + +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel + + +class PathMapping(UpdatableModel): + src: str + dest: Optional[str] = None diff --git a/src/snowflake/cli/api/project/schemas/project_definition.py b/src/snowflake/cli/api/project/schemas/project_definition.py index beb93af268..4e2792e065 100644 --- a/src/snowflake/cli/api/project/schemas/project_definition.py +++ b/src/snowflake/cli/api/project/schemas/project_definition.py @@ -1,23 +1,25 @@ from __future__ import annotations -from snowflake.cli.api.project.schemas import ( - native_app, - snowpark, - streamlit, -) -from snowflake.cli.api.project.schemas.relaxed_map import RelaxedMap -from strictyaml import ( - Int, - Optional, -) +from typing import Optional -project_schema = RelaxedMap( - { - "definition_version": Int(), - Optional("native_app"): native_app.native_app_schema, - Optional("snowpark"): snowpark.snowpark_schema, - Optional("streamlit"): streamlit.streamlit_schema, - } -) +from pydantic import Field +from snowflake.cli.api.project.schemas.native_app.native_app 
import NativeApp +from snowflake.cli.api.project.schemas.snowpark.snowpark import Snowpark +from snowflake.cli.api.project.schemas.streamlit.streamlit import Streamlit +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel -project_override_schema = project_schema.as_fully_optional() + +class ProjectDefinition(UpdatableModel): + definition_version: int = Field( + title="Version of the project definition schema, which is currently 1" + ) + native_app: Optional[NativeApp] = Field( + title="Native app definitions for the project", default=None + ) + snowpark: Optional[Snowpark] = Field( + title="Snowpark functions and procedures definitions for the project", + default=None, + ) + streamlit: Optional[Streamlit] = Field( + title="Native app definitions for the project", default=None + ) diff --git a/src/snowflake/cli/api/project/schemas/relaxed_map.py b/src/snowflake/cli/api/project/schemas/relaxed_map.py deleted file mode 100644 index bbfe99dfd7..0000000000 --- a/src/snowflake/cli/api/project/schemas/relaxed_map.py +++ /dev/null @@ -1,44 +0,0 @@ -from __future__ import annotations - -from strictyaml import ( - Any, - Bool, - Decimal, - Int, - MapCombined, - Optional, - Str, -) - -# TODO: use the util regexes to validate paths + globs -FilePath = Str -Glob = Str - - -class RelaxedMap(MapCombined): - """ - A version of a Map that allows any number of unknown key/value pairs. - """ - - def __init__(self, map_validator): - super().__init__( - map_validator, - Str(), - # moves through value validators left-to-right until one matches - Bool() | Decimal() | Int() | Any(), - ) - - def as_fully_optional(self) -> RelaxedMap: - """ - Returns a copy of this schema with all its keys optional, recursing into other - RelaxedMaps we find inside the schema. For existing optional keys, we strip out - the default value and ensure we don't create any new keys. 
- """ - validator = {} - for key, value in self._validator_dict.items(): - validator[Optional(key)] = ( - value - if not isinstance(value, RelaxedMap) - else value.as_fully_optional() - ) - return RelaxedMap(validator) diff --git a/src/snowflake/cli/api/project/schemas/snowpark.py b/src/snowflake/cli/api/project/schemas/snowpark.py deleted file mode 100644 index ed8d756390..0000000000 --- a/src/snowflake/cli/api/project/schemas/snowpark.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -from snowflake.cli.api.project.schemas.relaxed_map import RelaxedMap -from snowflake.cli.api.project.util import IDENTIFIER -from strictyaml import ( - Bool, - EmptyList, - MapPattern, - Optional, - Regex, - Seq, - Str, -) - -Argument = RelaxedMap({"name": Str(), "type": Str(), Optional("default"): Str()}) - -_callable_mapping = { - "name": Str(), - Optional("database", default=None): Regex(IDENTIFIER), - Optional("schema", default=None): Regex(IDENTIFIER), - "handler": Str(), - "returns": Str(), - "signature": Seq(Argument) | EmptyList(), - Optional("runtime"): Str(), - Optional("external_access_integration"): Seq(Str()), - Optional("secrets"): MapPattern(Str(), Str()), - Optional("imports"): Seq(Str()), -} - -function_schema = RelaxedMap(_callable_mapping) - -procedure_schema = RelaxedMap( - { - **_callable_mapping, - Optional("execute_as_caller"): Bool(), - } -) - -snowpark_schema = RelaxedMap( - { - "project_name": Str(), - "stage_name": Str(), - "src": Str(), - Optional("functions"): Seq(function_schema), - Optional("procedures"): Seq(procedure_schema), - } -) diff --git a/src/snowflake/cli/api/project/schemas/snowpark/__init__.py b/src/snowflake/cli/api/project/schemas/snowpark/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/snowflake/cli/api/project/schemas/snowpark/argument.py b/src/snowflake/cli/api/project/schemas/snowpark/argument.py new file mode 100644 index 0000000000..521925950c --- /dev/null +++ 
b/src/snowflake/cli/api/project/schemas/snowpark/argument.py @@ -0,0 +1,12 @@ +from typing import Optional + +from pydantic import Field +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel + + +class Argument(UpdatableModel): + name: str = Field(title="Name of the argument") + arg_type: str = Field( + title="Type of the argument", alias="type" + ) # TODO: consider introducing literal/enum here + default: Optional[str] = Field(title="Default value for an argument", default=None) diff --git a/src/snowflake/cli/api/project/schemas/snowpark/callable.py b/src/snowflake/cli/api/project/schemas/snowpark/callable.py new file mode 100644 index 0000000000..ccc36995bc --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/snowpark/callable.py @@ -0,0 +1,66 @@ +from typing import Dict, List, Optional, Union + +from pydantic import Field, field_validator +from snowflake.cli.api.project.schemas.snowpark.argument import Argument +from snowflake.cli.api.project.schemas.updatable_model import ( + IdentifierField, + UpdatableModel, +) + + +class Callable(UpdatableModel): + name: str = Field( + title="Object identifier" + ) # TODO: implement validator. 
If a name is fully qualified, database and schema cannot be specified + database: Optional[str] = IdentifierField( + title="Name of the database for the function or procedure", default=None + ) + + schema_name: Optional[str] = IdentifierField( + title="Name of the schema for the function or procedure", + default=None, + alias="schema", + ) + handler: str = Field( + title="Function’s or procedure’s implementation of the object inside source module", + examples=["functions.hello_function"], + ) + returns: str = Field( + title="Type of the result" + ) # TODO: again, consider Literal/Enum + signature: Union[str, List[Argument]] = Field( + title="The signature parameter describes consecutive arguments passed to the object" + ) + runtime: Optional[str | float] = Field( + title="Python version to use when executing ", default=None + ) + external_access_integrations: Optional[List[str]] = Field( + title="Names of external access integrations needed for this procedure’s handler code to access external networks", + default=[], + ) + secrets: Optional[Dict[str, str]] = Field( + title="Assigns the names of secrets to variables so that you can use the variables to reference the secrets", + default=[], + ) + imports: Optional[List[str]] = Field( + title="Stage and path to previously uploaded files you want to import", + default=[], + ) + + @field_validator("runtime") + @classmethod + def convert_runtime(cls, runtime_input: str | float) -> str: + if isinstance(runtime_input, float): + return str(runtime_input) + return runtime_input + + +class FunctionSchema(Callable): + pass + + +class ProcedureSchema(Callable): + execute_as_caller: Optional[bool] = Field( + title="Determine whether the procedure is executed with the privileges of the owner (you) or with the privileges of the caller", + default=False, + ) diff --git a/src/snowflake/cli/api/project/schemas/snowpark/snowpark.py b/src/snowflake/cli/api/project/schemas/snowpark/snowpark.py new file mode 100644 index 
0000000000..0a3f668450 --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/snowpark/snowpark.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field +from snowflake.cli.api.project.schemas.snowpark.callable import ( + FunctionSchema, + ProcedureSchema, +) +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel + + +class Snowpark(UpdatableModel): + project_name: str = Field(title="Project identifier") + stage_name: str = Field(title="Stage in which project’s artifacts will be stored") + src: str = Field(title="Folder where your code should be located") + functions: Optional[List[FunctionSchema]] = Field( + title="List of functions defined in the project", default=[] + ) + procedures: Optional[List[ProcedureSchema]] = Field( + title="List of procedures defined in the project", default=[] + ) diff --git a/src/snowflake/cli/api/project/schemas/streamlit.py b/src/snowflake/cli/api/project/schemas/streamlit.py deleted file mode 100644 index 8283ead836..0000000000 --- a/src/snowflake/cli/api/project/schemas/streamlit.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -from snowflake.cli.api.project.schemas.relaxed_map import FilePath, RelaxedMap -from strictyaml import ( - Optional, - Seq, - Str, -) - -streamlit_schema = RelaxedMap( - { - "name": Str(), - Optional("stage", default="streamlit"): Str(), - "query_warehouse": Str(), - Optional("main_file", default="streamlit_app.py"): FilePath(), - Optional("env_file"): FilePath(), - Optional("pages_dir"): FilePath(), - Optional("additional_source_files"): Seq(FilePath()), - } -) diff --git a/src/snowflake/cli/api/project/schemas/streamlit/__init__.py b/src/snowflake/cli/api/project/schemas/streamlit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/snowflake/cli/api/project/schemas/streamlit/streamlit.py b/src/snowflake/cli/api/project/schemas/streamlit/streamlit.py new file mode 
100644 index 0000000000..ce6b5e08c9 --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/streamlit/streamlit.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel + + +class Streamlit(UpdatableModel): + name: str = Field(title="App identifier") + stage: Optional[str] = Field( + title="Stage in which the app’s artifacts will be stored", default="streamlit" + ) + query_warehouse: str = Field( + title="Snowflake warehouse to host the app", default="streamlit" + ) + main_file: Optional[str] = Field( + title="Entrypoint file of the streamlit app", default="streamlit_app.py" + ) + env_file: Optional[str] = Field( + title="File defining additional configurations for the app, such as external dependencies", + default=None, + ) + pages_dir: Optional[str] = Field(title="Streamlit pages", default=None) + additional_source_files: Optional[List[str]] = Field( + title="List of additional files which should be included into deployment artifacts", + default=None, + ) diff --git a/src/snowflake/cli/api/project/schemas/updatable_model.py b/src/snowflake/cli/api/project/schemas/updatable_model.py new file mode 100644 index 0000000000..12d8551a9a --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/updatable_model.py @@ -0,0 +1,20 @@ +from typing import Any, Dict + +from pydantic import BaseModel, ConfigDict, Field +from snowflake.cli.api.project.util import IDENTIFIER_NO_LENGTH + + +class UpdatableModel(BaseModel): + model_config = ConfigDict(validate_assignment=True, extra="forbid") + + def update_from_dict( + self, update_values: Dict[str, Any] + ): # this method works wrong for optional fields set to None + for field, value in update_values.items(): # do we even need this? 
+ if getattr(self, field, None): + setattr(self, field, value) + return self + + +def IdentifierField(*args, **kwargs): # noqa + return Field(max_length=254, pattern=IDENTIFIER_NO_LENGTH, *args, **kwargs) diff --git a/src/snowflake/cli/api/project/util.py b/src/snowflake/cli/api/project/util.py index 61d5109eab..d10646edae 100644 --- a/src/snowflake/cli/api/project/util.py +++ b/src/snowflake/cli/api/project/util.py @@ -4,6 +4,7 @@ from typing import Optional IDENTIFIER = r'((?:"[^"]*(?:""[^"]*)*")|(?:[A-Za-z_][\w$]{0,254}))' +IDENTIFIER_NO_LENGTH = r'((?:"[^"]*(?:""[^"]*)*")|(?:[A-Za-z_][\w$]*))' DB_SCHEMA_AND_NAME = f"{IDENTIFIER}[.]{IDENTIFIER}[.]{IDENTIFIER}" SCHEMA_AND_NAME = f"{IDENTIFIER}[.]{IDENTIFIER}" GLOB_REGEX = r"^[a-zA-Z0-9_\-./*?**\p{L}\p{N}]+$" diff --git a/src/snowflake/cli/api/secure_path.py b/src/snowflake/cli/api/secure_path.py index 11d8d5b6b7..5287fc6dea 100644 --- a/src/snowflake/cli/api/secure_path.py +++ b/src/snowflake/cli/api/secure_path.py @@ -265,7 +265,7 @@ def temporary_directory(cls): Works similarly to tempfile.TemporaryDirectory """ - with tempfile.TemporaryDirectory(prefix="snowcli") as tmpdir: + with tempfile.TemporaryDirectory(prefix="snowflake-cli") as tmpdir: log.info("Created temporary directory %s", tmpdir) yield SecurePath(tmpdir) log.info("Removing temporary directory %s", tmpdir) diff --git a/src/snowflake/cli/api/sql_execution.py b/src/snowflake/cli/api/sql_execution.py index 248c3fa21f..3b2a04fcba 100644 --- a/src/snowflake/cli/api/sql_execution.py +++ b/src/snowflake/cli/api/sql_execution.py @@ -5,9 +5,8 @@ from functools import cached_property from io import StringIO from textwrap import dedent -from typing import Iterable, Optional +from typing import Iterable, Optional, Tuple -from click import ClickException from snowflake.cli.api.cli_global_context import cli_context from snowflake.cli.api.exceptions import ( DatabaseNotProvidedError, @@ -19,6 +18,7 @@ unquote_identifier, ) from snowflake.cli.api.utils.cursor 
import find_first_row +from snowflake.cli.api.utils.naming_utils import from_qualified_name from snowflake.connector.cursor import DictCursor, SnowflakeCursor from snowflake.connector.errors import ProgrammingError @@ -82,44 +82,27 @@ def use_role(self, new_role: str): if is_different_role: self._execute_query(f"use role {prev_role}") - def _execute_schema_query(self, query: str, **kwargs): - self.check_database_and_schema() - return self._execute_query(query, **kwargs) - - def check_database_and_schema(self) -> None: + def _execute_schema_query(self, query: str, name: Optional[str] = None, **kwargs): """ - Checks if the connection database and schema are set and that they actually exist in Snowflake. + Check that a database and schema are provided before executing the query. Useful for operating on schema level objects. """ - self.check_schema_exists(self._conn.database, self._conn.schema) + self.check_database_and_schema_provided(name) + return self._execute_query(query, **kwargs) - def check_database_exists(self, database: str) -> None: + def check_database_and_schema_provided(self, name: Optional[str] = None) -> None: """ - Checks that database is provided and that it is a valid database in - Snowflake. Note that this could fail for a variety of reasons, - including not authorized to use database, database doesn't exist, - database is not a valid identifier, and more. + Checks if a database and schema are provided, either through the connection context or a qualified name. 
""" + if name: + _, schema, database = from_qualified_name(name) + else: + schema, database = None, None + schema = schema or self._conn.schema + database = database or self._conn.database if not database: raise DatabaseNotProvidedError() - try: - self._execute_query(f"USE DATABASE {database}") - except ProgrammingError as e: - raise ClickException(f"Exception occurred: {e}.") from e - - def check_schema_exists(self, database: str, schema: str) -> None: - """ - Checks that schema is provided and that it is a valid schema in Snowflake. - Note that this could fail for a variety of reasons, - including not authorized to use schema, schema doesn't exist, - schema is not a valid identifier, and more. - """ - self.check_database_exists(database) if not schema: raise SchemaNotProvidedError() - try: - self._execute_query(f"USE {database}.{schema}") - except ProgrammingError as e: - raise ClickException(f"Exception occurred: {e}.") from e def to_fully_qualified_name( self, name: str, database: Optional[str] = None, schema: Optional[str] = None @@ -131,9 +114,7 @@ def to_fully_qualified_name( if not database: if not self._conn.database: - raise ClickException( - "Default database not specified in connection details." - ) + raise DatabaseNotProvidedError() database = self._conn.database if len(current_parts) == 2: @@ -150,29 +131,65 @@ def get_name_from_fully_qualified_name(name): Returns name of the object from the fully-qualified name. 
Assumes that [name] is in format [[database.]schema.]name """ - return name.split(".")[-1] + return from_qualified_name(name)[0] + + @staticmethod + def _qualified_name_to_in_clause(name: str) -> Tuple[str, Optional[str]]: + unqualified_name, schema, database = from_qualified_name(name) + if database: + in_clause = f"in schema {database}.{schema}" + elif schema: + in_clause = f"in schema {schema}" + else: + in_clause = None + return unqualified_name, in_clause + + class InClauseWithQualifiedNameError(ValueError): + def __init__(self): + super().__init__("non-empty 'in_clause' passed with qualified 'name'") def show_specific_object( self, object_type_plural: str, - unqualified_name: str, + name: str, name_col: str = "name", in_clause: str = "", check_schema: bool = False, ) -> Optional[dict]: """ Executes a "show like" query for a particular entity with a - given (unqualified) name. This command is useful when the corresponding + given (optionally qualified) name. This command is useful when the corresponding "describe " query does not provide the information you seek. + + Note that this command is analogous to describe and should only return a single row. + If the target object type is a schema level object, then check_schema should be set to True + so that the function will verify that a database and schema are provided, either through + the connection or a qualified name, before executing the query. 
""" - if check_schema: - self.check_database_and_schema() + + unqualified_name, name_in_clause = self._qualified_name_to_in_clause(name) + if in_clause and name_in_clause: + raise self.InClauseWithQualifiedNameError() + elif name_in_clause: + in_clause = name_in_clause show_obj_query = f"show {object_type_plural} like {identifier_to_show_like_pattern(unqualified_name)} {in_clause}".strip() - show_obj_cursor = self._execute_query( # type: ignore - show_obj_query, cursor_class=DictCursor - ) + + if check_schema: + show_obj_cursor = self._execute_schema_query( # type: ignore + show_obj_query, name=name, cursor_class=DictCursor + ) + else: + show_obj_cursor = self._execute_query( # type: ignore + show_obj_query, cursor_class=DictCursor + ) + if show_obj_cursor.rowcount is None: raise SnowflakeSQLExecutionError(show_obj_query) + elif show_obj_cursor.rowcount > 1: + raise ProgrammingError( + f"Received multiple rows from result of SQL statement: {show_obj_query}. Usage of 'show_specific_object' may not be properly scoped." + ) + show_obj_row = find_first_row( show_obj_cursor, lambda row: row[name_col] == unquote_identifier(unqualified_name), diff --git a/src/snowflake/cli/api/utils/naming_utils.py b/src/snowflake/cli/api/utils/naming_utils.py new file mode 100644 index 0000000000..895698cc6b --- /dev/null +++ b/src/snowflake/cli/api/utils/naming_utils.py @@ -0,0 +1,27 @@ +import re +from typing import Optional, Tuple + +from snowflake.cli.api.project.util import ( + VALID_IDENTIFIER_REGEX, +) + + +def from_qualified_name(name: str) -> Tuple[str, Optional[str], Optional[str]]: + """ + Takes in an object name in the form [[database.]schema.]name. Returns a tuple (name, [schema], [database]) + """ + # TODO: Use regex to match object name to a valid identifier or valid identifier (args). 
Second case is for sprocs and UDFs + qualifier_pattern = rf"(?:(?P{VALID_IDENTIFIER_REGEX})\.)?(?:(?P{VALID_IDENTIFIER_REGEX})\.)?(?P.*)" + result = re.fullmatch(qualifier_pattern, name) + + if result is None: + raise ValueError(f"'{name}' is not a valid qualified name") + + unqualified_name = result.group("name") + if result.group("second_qualifier") is not None: + database = result.group("first_qualifier") + schema = result.group("second_qualifier") + else: + database = None + schema = result.group("first_qualifier") + return unqualified_name, schema, database diff --git a/src/snowflake/cli/app/__main__.py b/src/snowflake/cli/app/__main__.py index 68aadbe2f4..c187979075 100644 --- a/src/snowflake/cli/app/__main__.py +++ b/src/snowflake/cli/app/__main__.py @@ -2,10 +2,11 @@ import sys -from snowflake.cli.app.cli_app import app +from snowflake.cli.app.cli_app import app_factory def main(*args): + app = app_factory() app(*args) diff --git a/src/snowflake/cli/app/cli_app.py b/src/snowflake/cli/app/cli_app.py index cca66eb61c..a7e22e08e9 100644 --- a/src/snowflake/cli/app/cli_app.py +++ b/src/snowflake/cli/app/cli_app.py @@ -28,10 +28,9 @@ setup_pycharm_remote_debugger_if_provided, ) from snowflake.cli.app.main_typer import SnowCliMainTyper -from snowflake.cli.app.printing import print_result +from snowflake.cli.app.printing import MessageResult, print_result from snowflake.connector.config_manager import CONFIG_MANAGER -app: SnowCliMainTyper = SnowCliMainTyper() log = logging.getLogger(__name__) _api = Api(plugin_config_provider=PluginConfigProviderImpl()) @@ -104,7 +103,7 @@ def _commands_structure_callback(value: bool): @_do_not_execute_on_completion def _version_callback(value: bool): if value: - typer.echo(f"Snowflake CLI version: {__about__.VERSION}") + print_result(MessageResult(f"Snowflake CLI version: {__about__.VERSION}")) _exit_with_cleanup() @@ -126,93 +125,98 @@ def _info_callback(value: bool): _exit_with_cleanup() -@app.callback() -def default( - 
version: bool = typer.Option( - None, - "--version", - help="Shows version of the Snowflake CLI", - callback=_version_callback, - is_eager=True, - ), - docs: bool = typer.Option( - None, - "--docs", - hidden=True, - help="Generates Snowflake CLI documentation", - callback=_docs_callback, - is_eager=True, - ), - structure: bool = typer.Option( - None, - "--structure", - hidden=True, - help="Prints Snowflake CLI structure of commands", - callback=_commands_structure_callback, - is_eager=True, - ), - info: bool = typer.Option( - None, - "--info", - help="Shows information about the Snowflake CLI", - callback=_info_callback, - ), - configuration_file: Path = typer.Option( - None, - "--config-file", - help="Specifies Snowflake CLI configuration file that should be used", - exists=True, - dir_okay=False, - is_eager=True, - callback=_config_init_callback, - ), - pycharm_debug_library_path: str = typer.Option( - None, - "--pycharm-debug-library-path", - hidden=True, - ), - pycharm_debug_server_host: str = typer.Option( - "localhost", - "--pycharm-debug-server-host", - hidden=True, - ), - pycharm_debug_server_port: int = typer.Option( - 12345, - "--pycharm-debug-server-port", - hidden=True, - ), - disable_external_command_plugins: bool = typer.Option( - None, - "--disable-external-command-plugins", - help="Disable external command plugins", - callback=_disable_external_command_plugins_callback, - is_eager=True, - hidden=True, - ), - # THIS OPTION SHOULD BE THE LAST OPTION IN THE LIST! - # --- - # This is a hidden artificial option used only to guarantee execution of commands registration - # and make this guaranty not dependent on other callbacks. - # Commands registration is invoked as soon as all callbacks - # decorated with "_commands_registration.before" are executed - # but if there are no such callbacks (at the result of possible future changes) - # then we need to invoke commands registration manually. 
- # - # This option is also responsible for resetting registration state for test purposes. - commands_registration: bool = typer.Option( - True, - "--commands-registration", - help="Commands registration", - hidden=True, - is_eager=True, - callback=_commands_registration_callback, - ), -) -> None: - """ - Snowflake CLI tool for developers. - """ - setup_pycharm_remote_debugger_if_provided( - pycharm_debug_library_path=pycharm_debug_library_path, - pycharm_debug_server_host=pycharm_debug_server_host, - pycharm_debug_server_port=pycharm_debug_server_port, - ) +def app_factory() -> SnowCliMainTyper: + app = SnowCliMainTyper() + + @app.callback() + def default( + version: bool = typer.Option( + None, + "--version", + help="Shows version of the Snowflake CLI", + callback=_version_callback, + is_eager=True, + ), + docs: bool = typer.Option( + None, + "--docs", + hidden=True, + help="Generates Snowflake CLI documentation", + callback=_docs_callback, + is_eager=True, + ), + structure: bool = typer.Option( + None, + "--structure", + hidden=True, + help="Prints Snowflake CLI structure of commands", + callback=_commands_structure_callback, + is_eager=True, + ), + info: bool = typer.Option( + None, + "--info", + help="Shows information about the Snowflake CLI", + callback=_info_callback, + ), + configuration_file: Path = typer.Option( + None, + "--config-file", + help="Specifies Snowflake CLI configuration file that should be used", + exists=True, + dir_okay=False, + is_eager=True, + callback=_config_init_callback, + ), + pycharm_debug_library_path: str = typer.Option( + None, + "--pycharm-debug-library-path", + hidden=True, + ), + pycharm_debug_server_host: str = typer.Option( + "localhost", + "--pycharm-debug-server-host", + hidden=True, + ), + pycharm_debug_server_port: int = typer.Option( + 12345, + "--pycharm-debug-server-port", + hidden=True, + ), + disable_external_command_plugins: bool = typer.Option( + None, + "--disable-external-command-plugins", + help="Disable 
external command plugins", + callback=_disable_external_command_plugins_callback, + is_eager=True, + hidden=True, + ), + # THIS OPTION SHOULD BE THE LAST OPTION IN THE LIST! + # --- + # This is a hidden artificial option used only to guarantee execution of commands registration + # and make this guaranty not dependent on other callbacks. + # Commands registration is invoked as soon as all callbacks + # decorated with "_commands_registration.before" are executed + # but if there are no such callbacks (at the result of possible future changes) + # then we need to invoke commands registration manually. + # + # This option is also responsible for resetting registration state for test purposes. + commands_registration: bool = typer.Option( + True, + "--commands-registration", + help="Commands registration", + hidden=True, + is_eager=True, + callback=_commands_registration_callback, + ), + ) -> None: + """ + Snowflake CLI tool for developers. + """ + setup_pycharm_remote_debugger_if_provided( + pycharm_debug_library_path=pycharm_debug_library_path, + pycharm_debug_server_host=pycharm_debug_server_host, + pycharm_debug_server_port=pycharm_debug_server_port, + ) + + return app diff --git a/src/snowflake/cli/app/loggers.py b/src/snowflake/cli/app/loggers.py index c9c43c62b6..68b6ccb4f3 100644 --- a/src/snowflake/cli/app/loggers.py +++ b/src/snowflake/cli/app/loggers.py @@ -10,7 +10,7 @@ from snowflake.cli.api.exceptions import InvalidLogsConfiguration from snowflake.cli.api.secure_path import SecurePath -_DEFAULT_LOG_FILENAME = "snowcli.log" +_DEFAULT_LOG_FILENAME = "snowflake-cli.log" @dataclass diff --git a/src/snowflake/cli/app/main_typer.py b/src/snowflake/cli/app/main_typer.py index aeae62b2cc..3498d91f1c 100644 --- a/src/snowflake/cli/app/main_typer.py +++ b/src/snowflake/cli/app/main_typer.py @@ -3,16 +3,16 @@ import sys import typer -from rich import print as rich_print from snowflake.cli.api.cli_global_context import cli_context from 
snowflake.cli.api.commands.flags import DEFAULT_CONTEXT_SETTINGS, DebugOption +from snowflake.cli.api.console import cli_console def _handle_exception(exception: Exception): if cli_context.enable_tracebacks: raise exception else: - rich_print( + cli_console.warning( "\nAn unexpected exception occurred. Use --debug option to see the traceback. Exception message:\n\n" + exception.__str__() ) diff --git a/src/snowflake/cli/plugins/connection/commands.py b/src/snowflake/cli/plugins/connection/commands.py index a92f9fdbf8..14bfd54982 100644 --- a/src/snowflake/cli/plugins/connection/commands.py +++ b/src/snowflake/cli/plugins/connection/commands.py @@ -2,7 +2,6 @@ import logging -import click import typer from click import ClickException, Context, Parameter # type: ignore from click.core import ParameterSource # type: ignore @@ -21,6 +20,7 @@ get_connection, set_config_value, ) +from snowflake.cli.api.console import cli_console from snowflake.cli.api.constants import ObjectType from snowflake.cli.api.output.types import ( CollectionResult, @@ -81,7 +81,7 @@ def callback(value: str): def _password_callback(ctx: Context, param: Parameter, value: str): if value and ctx.get_parameter_source(param.name) == ParameterSource.COMMANDLINE: # type: ignore - click.echo(PLAIN_PASSWORD_MSG) + cli_console.warning(PLAIN_PASSWORD_MSG) return value diff --git a/src/snowflake/cli/plugins/nativeapp/artifacts.py b/src/snowflake/cli/plugins/nativeapp/artifacts.py index 282d11b105..12bdfe6e52 100644 --- a/src/snowflake/cli/plugins/nativeapp/artifacts.py +++ b/src/snowflake/cli/plugins/nativeapp/artifacts.py @@ -6,6 +6,7 @@ import strictyaml from click import ClickException from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB +from snowflake.cli.api.project.schemas.native_app.path_maping import PathMapping from snowflake.cli.api.secure_path import SecurePath @@ -153,8 +154,8 @@ def translate_artifact(item: Union[dict, str]) -> ArtifactMapping: Validation is done later when we actually 
resolve files / folders. """ - if isinstance(item, dict): - return ArtifactMapping(item["src"], item.get("dest", item["src"])) + if isinstance(item, PathMapping): + return ArtifactMapping(item.src, item.dest if item.dest else item.src) elif isinstance(item, str): return ArtifactMapping(item, item) diff --git a/src/snowflake/cli/plugins/nativeapp/manager.py b/src/snowflake/cli/plugins/nativeapp/manager.py index b6675ad14f..ff45918685 100644 --- a/src/snowflake/cli/plugins/nativeapp/manager.py +++ b/src/snowflake/cli/plugins/nativeapp/manager.py @@ -4,7 +4,7 @@ from functools import cached_property from pathlib import Path from textwrap import dedent -from typing import Dict, List, Optional +from typing import List, Optional from snowflake.cli.api.console import cli_console as cc from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError @@ -13,6 +13,7 @@ default_application, default_role, ) +from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp from snowflake.cli.api.project.util import ( extract_schema, to_identifier, @@ -99,7 +100,7 @@ class NativeAppManager(SqlExecutionMixin): Base class with frequently used functionality already implemented and ready to be used by related subclasses. 
""" - def __init__(self, project_definition: Dict, project_root: Path): + def __init__(self, project_definition: NativeApp, project_root: Path): super().__init__() self._project_root = project_root self._project_definition = project_definition @@ -109,27 +110,30 @@ def project_root(self) -> Path: return self._project_root @property - def definition(self) -> Dict: + def definition(self) -> NativeApp: return self._project_definition @cached_property def artifacts(self) -> List[ArtifactMapping]: - return [translate_artifact(item) for item in self.definition["artifacts"]] + return [translate_artifact(item) for item in self.definition.artifacts] @cached_property def deploy_root(self) -> Path: - return Path(self.project_root, self.definition["deploy_root"]) + return Path(self.project_root, self.definition.deploy_root) @cached_property def package_scripts(self) -> List[str]: """ Relative paths to package scripts from the project root. """ - return self.definition.get("package", {}).get("scripts", []) + if self.definition.package and self.definition.package.scripts: + return self.definition.package.scripts + else: + return [] @cached_property def stage_fqn(self) -> str: - return f'{self.package_name}.{self.definition["source_stage"]}' + return f"{self.package_name}.{self.definition.source_stage}" @cached_property def stage_schema(self) -> Optional[str]: @@ -137,55 +141,65 @@ def stage_schema(self) -> Optional[str]: @cached_property def package_warehouse(self) -> Optional[str]: - return self.definition.get("package", {}).get("warehouse", self._conn.warehouse) + if self.definition.package and self.definition.package.warehouse: + return self.definition.package.warehouse + else: + return self._conn.warehouse @cached_property def application_warehouse(self) -> Optional[str]: - return self.definition.get("application", {}).get( - "warehouse", self._conn.warehouse - ) + if self.definition.application and self.definition.application.warehouse: + return 
self.definition.application.warehouse + else: + return self._conn.warehouse @cached_property def project_identifier(self) -> str: # name is expected to be a valid Snowflake identifier, but PyYAML # will sometimes strip out double quotes so we try to get them back here. - return to_identifier(self.definition["name"]) + return to_identifier(self.definition.name) @cached_property def package_name(self) -> str: - return to_identifier( - self.definition.get("package", {}).get( - "name", default_app_package(self.project_identifier) - ) - ) + if self.definition.package and self.definition.package.name: + return to_identifier(self.definition.package.name) + else: + return to_identifier(default_app_package(self.project_identifier)) @cached_property def package_role(self) -> str: - return self.definition.get("package", {}).get("role", None) or default_role() + if self.definition.package and self.definition.package.role: + return self.definition.package.role + else: + return default_role() @cached_property def package_distribution(self) -> str: - return ( - self.definition.get("package", {}).get("distribution", "internal").lower() - ) + if self.definition.package and self.definition.package.distribution: + return self.definition.package.distribution.lower() + else: + return "internal" @cached_property def app_name(self) -> str: - return to_identifier( - self.definition.get("application", {}).get( - "name", default_application(self.project_identifier) - ) - ) + if self.definition.application and self.definition.application.name: + return to_identifier(self.definition.application.name) + else: + return to_identifier(default_application(self.project_identifier)) @cached_property def app_role(self) -> str: - return ( - self.definition.get("application", {}).get("role", None) or default_role() - ) + if self.definition.application and self.definition.application.role: + return self.definition.application.role + else: + return default_role() @cached_property def debug_mode(self) -> 
bool: - return self.definition.get("application", {}).get("debug", True) + if self.definition.application: + return self.definition.application.debug + else: + return True @cached_property def get_app_pkg_distribution_in_snowflake(self) -> str: diff --git a/src/snowflake/cli/plugins/nativeapp/run_processor.py b/src/snowflake/cli/plugins/nativeapp/run_processor.py index b514dceb33..8c13310ee0 100644 --- a/src/snowflake/cli/plugins/nativeapp/run_processor.py +++ b/src/snowflake/cli/plugins/nativeapp/run_processor.py @@ -1,12 +1,13 @@ from pathlib import Path from textwrap import dedent -from typing import Dict, Optional +from typing import Optional import jinja2 import typer from click import UsageError from snowflake.cli.api.console import cli_console as cc from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError +from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp from snowflake.cli.plugins.nativeapp.constants import ( ALLOWED_SPECIAL_COMMENTS, COMMENT_COL, @@ -38,7 +39,7 @@ class NativeAppRunProcessor(NativeAppManager, NativeAppCommandProcessor): - def __init__(self, project_definition: Dict, project_root: Path): + def __init__(self, project_definition: NativeApp, project_root: Path): super().__init__(project_definition, project_root) def create_app_package(self) -> None: diff --git a/src/snowflake/cli/plugins/nativeapp/version/version_processor.py b/src/snowflake/cli/plugins/nativeapp/version/version_processor.py index 399c337b7f..574fa8e7f3 100644 --- a/src/snowflake/cli/plugins/nativeapp/version/version_processor.py +++ b/src/snowflake/cli/plugins/nativeapp/version/version_processor.py @@ -6,6 +6,7 @@ from click import BadOptionUsage, ClickException from snowflake.cli.api.console import cli_console as cc from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError +from snowflake.cli.api.project.schemas.native_app.native_app import NativeApp from snowflake.cli.api.project.util import unquote_identifier from 
snowflake.cli.api.utils.cursor import ( find_all_rows, @@ -240,7 +241,7 @@ def process( class NativeAppVersionDropProcessor(NativeAppManager, NativeAppCommandProcessor): - def __init__(self, project_definition: Dict, project_root: Path): + def __init__(self, project_definition: NativeApp, project_root: Path): super().__init__(project_definition, project_root) def process( diff --git a/src/snowflake/cli/plugins/object/common.py b/src/snowflake/cli/plugins/object/common.py index b53ac670e8..43424f9bbc 100644 --- a/src/snowflake/cli/plugins/object/common.py +++ b/src/snowflake/cli/plugins/object/common.py @@ -4,8 +4,7 @@ from click import ClickException from snowflake.cli.api.commands.flags import OverrideableOption from snowflake.cli.api.project.util import ( - QUOTED_IDENTIFIER_REGEX, - UNQUOTED_IDENTIFIER_REGEX, + VALID_IDENTIFIER_REGEX, is_valid_identifier, to_string_literal, ) @@ -34,11 +33,9 @@ def __init__(self): def _parse_tag(tag: str) -> Tag: import re - identifier_pattern = re.compile( - f"(?P{UNQUOTED_IDENTIFIER_REGEX}|{QUOTED_IDENTIFIER_REGEX})" - ) - value_pattern = re.compile(f"(?P.+)") - result = re.fullmatch(f"{identifier_pattern.pattern}={value_pattern.pattern}", tag) + identifier_pattern = rf"(?P{VALID_IDENTIFIER_REGEX})" + value_pattern = r"(?P.+)" + result = re.fullmatch(rf"{identifier_pattern}={value_pattern}", tag) if result is not None: try: return Tag(result.group("tag_name"), result.group("tag_value")) diff --git a/src/snowflake/cli/plugins/snowpark/commands.py b/src/snowflake/cli/plugins/snowpark/commands.py index ec95b51c4b..09634c6b88 100644 --- a/src/snowflake/cli/plugins/snowpark/commands.py +++ b/src/snowflake/cli/plugins/snowpark/commands.py @@ -12,6 +12,7 @@ with_project_definition, ) from snowflake.cli.api.commands.flags import ( + ReplaceOption, execution_identifier_argument, ) from snowflake.cli.api.commands.project_initialisation import add_init_command @@ -26,6 +27,12 @@ MessageResult, SingleQueryResult, ) +from 
snowflake.cli.api.project.schemas.snowpark.callable import ( + Callable, + FunctionSchema, + ProcedureSchema, +) +from snowflake.cli.api.project.schemas.snowpark.snowpark import Snowpark from snowflake.cli.plugins.object.manager import ObjectManager from snowflake.cli.plugins.object.stage.manager import StageManager from snowflake.cli.plugins.snowpark.common import ( @@ -50,12 +57,6 @@ help="Manages procedures and functions.", ) -ReplaceOption = typer.Option( - False, - "--replace", - help="Replaces procedure or function, even if no detected changes to metadata", -) - ObjectTypeArgument = typer.Argument( help="Type of Snowpark object", case_sensitive=False, @@ -67,7 +68,9 @@ @app.command("deploy", requires_connection=True) @with_project_definition("snowpark") def deploy( - replace: bool = ReplaceOption, + replace: bool = ReplaceOption( + help="Replaces procedure or function, even if no detected changes to metadata" + ), **options, ) -> CommandResult: """ @@ -77,8 +80,8 @@ def deploy( """ snowpark = cli_context.project_definition - procedures = snowpark.get("procedures", []) - functions = snowpark.get("functions", []) + procedures = snowpark.procedures + functions = snowpark.functions if not procedures and not functions: raise ClickException( @@ -112,16 +115,16 @@ def deploy( raise ClickException(msg) # Create stage - stage_name = snowpark.get("stage_name", DEPLOYMENT_STAGE) + stage_name = snowpark.stage_name stage_manager = StageManager() stage_name = stage_manager.to_fully_qualified_name(stage_name) stage_manager.create( - stage_name=stage_name, comment="deployments managed by snowcli" + stage_name=stage_name, comment="deployments managed by Snowflake CLI" ) packages = get_snowflake_packages() - artifact_stage_directory = get_app_stage_path(stage_name, snowpark["project_name"]) + artifact_stage_directory = get_app_stage_path(stage_name, snowpark.project_name) artifact_stage_target = f"{artifact_stage_directory}/{build_artifact_path.name}" stage_manager.put( @@ 
-158,11 +161,13 @@ def deploy( return CollectionResult(deploy_status) -def _assert_object_definitions_are_correct(object_type, object_definitions): +def _assert_object_definitions_are_correct( + object_type, object_definitions: List[Callable] +): for definition in object_definitions: - database = definition.get("database") - schema = definition.get("schema") - name = definition["name"] + database = definition.database + schema = definition.schema_name + name = definition.name fqn_parts = len(name.split(".")) if fqn_parts == 3 and database: raise ClickException( @@ -196,7 +201,9 @@ def _find_existing_objects( def _check_if_all_defined_integrations_exists( - om: ObjectManager, functions: List[Dict], procedures: List[Dict] + om: ObjectManager, + functions: List[FunctionSchema], + procedures: List[ProcedureSchema], ): existing_integrations = { i["name"].lower() @@ -206,14 +213,12 @@ def _check_if_all_defined_integrations_exists( declared_integration: Set[str] = set() for object_definition in [*functions, *procedures]: external_access_integrations = { - s.lower() for s in object_definition.get("external_access_integrations", []) + s.lower() for s in object_definition.external_access_integrations } - secrets = [s.lower() for s in object_definition.get("secrets", [])] + secrets = [s.lower() for s in object_definition.secrets] if not external_access_integrations and secrets: - raise SecretsWithoutExternalAccessIntegrationError( - object_definition["name"] - ) + raise SecretsWithoutExternalAccessIntegrationError(object_definition.name) declared_integration = declared_integration | external_access_integrations @@ -232,7 +237,7 @@ def get_app_stage_path(stage_name: Optional[str], project_name: str) -> str: def _deploy_single_object( manager: FunctionManager | ProcedureManager, object_type: ObjectType, - object_definition: Dict, + object_definition: Callable, existing_objects: Dict[str, Dict], packages: List[str], stage_artifact_path: str, @@ -248,8 +253,8 @@ def 
_deploy_single_object( ) log.info("Deploying %s: %s", object_type, identifier_with_default_values) - handler = object_definition["handler"] - returns = object_definition["returns"] + handler = object_definition.handler + returns = object_definition.returns replace_object = False object_exists = identifier in existing_objects @@ -274,18 +279,15 @@ def _deploy_single_object( "return_type": returns, "artifact_file": stage_artifact_path, "packages": packages, - "runtime": object_definition.get("runtime"), - "external_access_integrations": object_definition.get( - "external_access_integrations" - ), - "secrets": object_definition.get("secrets"), - "imports": object_definition.get("imports", []), + "runtime": object_definition.runtime, + "external_access_integrations": object_definition.external_access_integrations, + "secrets": object_definition.secrets, + "imports": object_definition.imports, } if object_type == ObjectType.PROCEDURE: - create_or_replace_kwargs["execute_as_caller"] = object_definition.get( + create_or_replace_kwargs[ "execute_as_caller" - ) - + ] = object_definition.execute_as_caller manager.create_or_replace(**create_or_replace_kwargs) status = "created" if not object_exists else "definition updated" @@ -296,8 +298,8 @@ def _deploy_single_object( } -def _get_snowpark_artifact_path(snowpark_definition: Dict): - source = Path(snowpark_definition["src"]) +def _get_snowpark_artifact_path(snowpark_definition: Snowpark): + source = Path(snowpark_definition.src) artifact_file = Path.cwd() / (source.name + ".zip") return artifact_file @@ -315,7 +317,7 @@ def build( The archive is built using only the `src` directory specified in the project file. 
""" snowpark = cli_context.project_definition - source = Path(snowpark.get("src")) + source = Path(snowpark.src) artifact_file = _get_snowpark_artifact_path(snowpark) log.info("Building package using sources from: %s", source.resolve()) diff --git a/src/snowflake/cli/plugins/snowpark/common.py b/src/snowflake/cli/plugins/snowpark/common.py index a6d70af122..11bbbe503c 100644 --- a/src/snowflake/cli/plugins/snowpark/common.py +++ b/src/snowflake/cli/plugins/snowpark/common.py @@ -3,6 +3,7 @@ from typing import Dict, List, Optional from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB, ObjectType +from snowflake.cli.api.project.schemas.snowpark.argument import Argument from snowflake.cli.api.secure_path import SecurePath from snowflake.cli.api.sql_execution import SqlExecutionMixin from snowflake.cli.plugins.snowpark.package_utils import generate_deploy_stage_name @@ -173,26 +174,26 @@ def _is_signature_type_a_string(sig_type: str) -> bool: def build_udf_sproc_identifier( - udf_sproc_dict, + udf_sproc, slq_exec_mixin, include_parameter_names, include_default_values=False, ): - def format_arg(arg): - result = f"{arg['type']}" + def format_arg(arg: Argument): + result = f"{arg.arg_type}" if include_parameter_names: - result = f"{arg['name']} {result}" - if include_default_values and "default" in arg: - val = f"{arg['default']}" - if _is_signature_type_a_string(arg["type"]): + result = f"{arg.name} {result}" + if include_default_values and arg.default: + val = f"{arg.default}" + if _is_signature_type_a_string(arg.arg_type): val = f"'{val}'" result += f" default {val}" return result - arguments = ", ".join(format_arg(arg) for arg in udf_sproc_dict["signature"]) + arguments = ", ".join(format_arg(arg) for arg in udf_sproc.signature) name = slq_exec_mixin.to_fully_qualified_name( - udf_sproc_dict["name"], - database=udf_sproc_dict.get("database"), - schema=udf_sproc_dict.get("schema"), + udf_sproc.name, + database=udf_sproc.database, + schema=udf_sproc.schema_name, 
) return f"{name}({arguments})" diff --git a/src/snowflake/cli/plugins/spcs/common.py b/src/snowflake/cli/plugins/spcs/common.py index a4bb5acded..07d8a4ce2a 100644 --- a/src/snowflake/cli/plugins/spcs/common.py +++ b/src/snowflake/cli/plugins/spcs/common.py @@ -66,12 +66,16 @@ def validate_and_set_instances(min_instances, max_instances, instance_name): def handle_object_already_exists( - error: ProgrammingError, object_type: ObjectType, object_name: str + error: ProgrammingError, + object_type: ObjectType, + object_name: str, + replace_available: bool = False, ): if error.errno == 2002: raise ObjectAlreadyExistsError( object_type=object_type, name=unquote_identifier(object_name), + replace_available=replace_available, ) else: raise error diff --git a/src/snowflake/cli/plugins/spcs/compute_pool/commands.py b/src/snowflake/cli/plugins/spcs/compute_pool/commands.py index 05979e052f..f069f4a2c3 100644 --- a/src/snowflake/cli/plugins/spcs/compute_pool/commands.py +++ b/src/snowflake/cli/plugins/spcs/compute_pool/commands.py @@ -2,9 +2,7 @@ import typer from click import ClickException -from snowflake.cli.api.commands.flags import ( - OverrideableOption, -) +from snowflake.cli.api.commands.flags import IfNotExistsOption, OverrideableOption from snowflake.cli.api.commands.snow_typer import SnowTyper from snowflake.cli.api.output.types import CommandResult, SingleQueryResult from snowflake.cli.api.project.util import is_valid_object_name @@ -90,6 +88,7 @@ def create( ), auto_suspend_secs: int = AutoSuspendSecsOption(), comment: Optional[str] = CommentOption(help=_COMMENT_HELP), + if_not_exists: bool = IfNotExistsOption(), **options, ) -> CommandResult: """ @@ -105,6 +104,7 @@ def create( initially_suspended=initially_suspended, auto_suspend_secs=auto_suspend_secs, comment=comment, + if_not_exists=if_not_exists, ) return SingleQueryResult(cursor) diff --git a/src/snowflake/cli/plugins/spcs/compute_pool/manager.py b/src/snowflake/cli/plugins/spcs/compute_pool/manager.py 
index 416e16020a..4e7720e271 100644 --- a/src/snowflake/cli/plugins/spcs/compute_pool/manager.py +++ b/src/snowflake/cli/plugins/spcs/compute_pool/manager.py @@ -22,9 +22,13 @@ def create( initially_suspended: bool, auto_suspend_secs: int, comment: Optional[str], + if_not_exists: bool, ) -> SnowflakeCursor: + create_statement = "CREATE COMPUTE POOL" + if if_not_exists: + create_statement = f"{create_statement} IF NOT EXISTS" query = f"""\ - CREATE COMPUTE POOL {pool_name} + {create_statement} {pool_name} MIN_NODES = {min_nodes} MAX_NODES = {max_nodes} INSTANCE_FAMILY = {instance_family} diff --git a/src/snowflake/cli/plugins/spcs/image_repository/commands.py b/src/snowflake/cli/plugins/spcs/image_repository/commands.py index 301e77b2bc..b765ba142c 100644 --- a/src/snowflake/cli/plugins/spcs/image_repository/commands.py +++ b/src/snowflake/cli/plugins/spcs/image_repository/commands.py @@ -4,7 +4,9 @@ import requests import typer from click import ClickException +from snowflake.cli.api.commands.flags import IfNotExistsOption, ReplaceOption from snowflake.cli.api.commands.snow_typer import SnowTyper +from snowflake.cli.api.console import cli_console from snowflake.cli.api.output.types import ( CollectionResult, MessageResult, @@ -22,7 +24,7 @@ def _repo_name_callback(name: str): - if not is_valid_object_name(name, max_depth=0, allow_quoted=True): + if not is_valid_object_name(name, max_depth=2, allow_quoted=False): raise ClickException( f"'{name}' is not a valid image repository name. Note that image repository names must be unquoted identifiers. The same constraint also applies to database and schema names where you create an image repository." ) @@ -38,12 +40,18 @@ def _repo_name_callback(name: str): @app.command(requires_connection=True) def create( name: str = REPO_NAME_ARGUMENT, + replace: bool = ReplaceOption(), + if_not_exists: bool = IfNotExistsOption(), **options, ): """ Creates a new image repository in the current schema. 
""" - return SingleQueryResult(ImageRepositoryManager().create(name=name)) + return SingleQueryResult( + ImageRepositoryManager().create( + name=name, replace=replace, if_not_exists=if_not_exists + ) + ) @app.command("list-images", requires_connection=True) @@ -119,7 +127,7 @@ def list_tags( ) if response.status_code != 200: - print("Call to the registry failed", response.text) + cli_console.warning(f"Call to the registry failed {response.text}") data = json.loads(response.text) if "tags" in data: diff --git a/src/snowflake/cli/plugins/spcs/image_repository/manager.py b/src/snowflake/cli/plugins/spcs/image_repository/manager.py index f04cd52cac..4bb6ef4163 100644 --- a/src/snowflake/cli/plugins/spcs/image_repository/manager.py +++ b/src/snowflake/cli/plugins/spcs/image_repository/manager.py @@ -1,10 +1,6 @@ from urllib.parse import urlparse -from click import ClickException from snowflake.cli.api.constants import ObjectType -from snowflake.cli.api.project.util import ( - is_valid_unquoted_identifier, -) from snowflake.cli.api.sql_execution import SqlExecutionMixin from snowflake.cli.plugins.spcs.common import handle_object_already_exists from snowflake.connector.errors import ProgrammingError @@ -21,17 +17,13 @@ def get_role(self): return self._conn.role def get_repository_url(self, repo_name: str, with_scheme: bool = True): - if not is_valid_unquoted_identifier(repo_name): - raise ValueError( - f"repo_name '{repo_name}' is not a valid unquoted Snowflake identifier" - ) - # we explicitly do not allow this function to be used without connection database and schema set + repo_row = self.show_specific_object( "image repositories", repo_name, check_schema=True ) if repo_row is None: - raise ClickException( - f"Image repository '{repo_name}' does not exist in database '{self.get_database()}' and schema '{self.get_schema()}' or not authorized." + raise ProgrammingError( + f"Image repository '{self.to_fully_qualified_name(repo_name)}' does not exist or not authorized." 
) if with_scheme: return f"https://{repo_row['repository_url']}" @@ -51,8 +43,26 @@ def get_repository_api_url(self, repo_url): return f"{scheme}://{host}/v2{path}" - def create(self, name: str): + def create( + self, + name: str, + if_not_exists: bool, + replace: bool, + ): + if if_not_exists and replace: + raise ValueError( + "'replace' and 'if_not_exists' options are mutually exclusive for ImageRepositoryManager.create" + ) + elif replace: + create_statement = "create or replace image repository" + elif if_not_exists: + create_statement = "create image repository if not exists" + else: + create_statement = "create image repository" + try: - return self._execute_schema_query(f"create image repository {name}") + return self._execute_schema_query(f"{create_statement} {name}", name=name) except ProgrammingError as e: - handle_object_already_exists(e, ObjectType.IMAGE_REPOSITORY, name) + handle_object_already_exists( + e, ObjectType.IMAGE_REPOSITORY, name, replace_available=True + ) diff --git a/src/snowflake/cli/plugins/spcs/services/commands.py b/src/snowflake/cli/plugins/spcs/services/commands.py index b9acfe4959..976e9b0bd2 100644 --- a/src/snowflake/cli/plugins/spcs/services/commands.py +++ b/src/snowflake/cli/plugins/spcs/services/commands.py @@ -4,9 +4,7 @@ import typer from click import ClickException -from snowflake.cli.api.commands.flags import ( - OverrideableOption, -) +from snowflake.cli.api.commands.flags import IfNotExistsOption, OverrideableOption from snowflake.cli.api.commands.snow_typer import SnowTyper from snowflake.cli.api.output.types import ( CommandResult, @@ -95,6 +93,7 @@ def create( query_warehouse: Optional[str] = QueryWarehouseOption(), tags: Optional[List[Tag]] = TagOption(help="Tag for the service."), comment: Optional[str] = CommentOption(help=_COMMENT_HELP), + if_not_exists: bool = IfNotExistsOption(), **options, ) -> CommandResult: """ @@ -114,6 +113,7 @@ def create( query_warehouse=query_warehouse, tags=tags, comment=comment, + 
if_not_exists=if_not_exists, ) return SingleQueryResult(cursor) diff --git a/src/snowflake/cli/plugins/spcs/services/manager.py b/src/snowflake/cli/plugins/spcs/services/manager.py index 5e006a6834..4f16df4ca4 100644 --- a/src/snowflake/cli/plugins/spcs/services/manager.py +++ b/src/snowflake/cli/plugins/spcs/services/manager.py @@ -27,11 +27,14 @@ def create( query_warehouse: Optional[str], tags: Optional[List[Tag]], comment: Optional[str], + if_not_exists: bool, ) -> SnowflakeCursor: spec = self._read_yaml(spec_path) - + create_statement = "CREATE SERVICE" + if if_not_exists: + create_statement = f"{create_statement} IF NOT EXISTS" query = f"""\ - CREATE SERVICE {service_name} + {create_statement} {service_name} IN COMPUTE POOL {compute_pool} FROM SPECIFICATION $$ {spec} diff --git a/src/snowflake/cli/plugins/streamlit/commands.py b/src/snowflake/cli/plugins/streamlit/commands.py index 2e713917d1..3f018ed8d2 100644 --- a/src/snowflake/cli/plugins/streamlit/commands.py +++ b/src/snowflake/cli/plugins/streamlit/commands.py @@ -1,6 +1,5 @@ import logging from pathlib import Path -from typing import Optional import click import typer @@ -10,6 +9,7 @@ with_experimental_behaviour, with_project_definition, ) +from snowflake.cli.api.commands.flags import ReplaceOption from snowflake.cli.api.commands.project_initialisation import add_init_command from snowflake.cli.api.commands.snow_typer import SnowTyper from snowflake.cli.api.output.types import ( @@ -17,6 +17,7 @@ MessageResult, SingleQueryResult, ) +from snowflake.cli.api.project.schemas.streamlit.streamlit import Streamlit from snowflake.cli.plugins.streamlit.manager import StreamlitManager app = SnowTyper( @@ -70,12 +71,7 @@ def _check_file_exists_if_not_default(ctx: click.Context, value): @with_project_definition("streamlit") @with_experimental_behaviour() def streamlit_deploy( - replace: Optional[bool] = typer.Option( - False, - "--replace", - help="Replace the Streamlit if it already exists.", - is_flag=True, - 
), + replace: bool = ReplaceOption(help="Replace the Streamlit if it already exists."), open_: bool = typer.Option( False, "--open", help="Whether to open Streamlit in a browser.", is_flag=True ), @@ -86,31 +82,31 @@ def streamlit_deploy( upload environment.yml and pages/ folder if present. If stage name is not specified then 'streamlit' stage will be used. If stage does not exist it will be created by this command. """ - streamlit = cli_context.project_definition + streamlit: Streamlit = cli_context.project_definition if not streamlit: return MessageResult("No streamlit were specified in project definition.") - environment_file = streamlit.get("env_file", None) + environment_file = streamlit.env_file if environment_file and not Path(environment_file).exists(): raise ClickException(f"Provided file {environment_file} does not exist") elif environment_file is None: environment_file = "environment.yml" - pages_dir = streamlit.get("pages_dir", None) + pages_dir = streamlit.pages_dir if pages_dir and not Path(pages_dir).exists(): raise ClickException(f"Provided file {pages_dir} does not exist") elif pages_dir is None: pages_dir = "pages" url = StreamlitManager().deploy( - streamlit_name=streamlit["name"], + streamlit_name=streamlit.name, environment_file=Path(environment_file), pages_dir=Path(pages_dir), - stage_name=streamlit["stage"], - main_file=Path(streamlit["main_file"]), + stage_name=streamlit.stage, + main_file=Path(streamlit.main_file), replace=replace, - query_warehouse=streamlit["query_warehouse"], - additional_source_files=streamlit.get("additional_source_files"), + query_warehouse=streamlit.query_warehouse, + additional_source_files=streamlit.additional_source_files, **options, ) diff --git a/tests/__snapshots__/test_help_messages.ambr b/tests/__snapshots__/test_help_messages.ambr index 35d5ebd1c0..42f4795802 100644 --- a/tests/__snapshots__/test_help_messages.ambr +++ b/tests/__snapshots__/test_help_messages.ambr @@ -1,4 +1,33 @@ # serializer version: 1 +# 
name: test_help_messages[] + ''' + + Usage: default [OPTIONS] COMMAND [ARGS]... + + Snowflake CLI tool for developers. + + ╭─ Options ────────────────────────────────────────────────────────────────────╮ + │ --version Shows version of the Snowflake CLI │ + │ --info Shows information about the Snowflake CLI │ + │ --config-file FILE Specifies Snowflake CLI configuration file that │ + │ should be used │ + │ [default: None] │ + │ --help -h Show this message and exit. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + ╭─ Commands ───────────────────────────────────────────────────────────────────╮ + │ app Manages a Snowflake Native App │ + │ connection Manages connections to Snowflake. │ + │ object Manages Snowflake objects like warehouses and stages │ + │ snowpark Manages procedures and functions. │ + │ spcs Manages Snowpark Container Services compute pools, services, │ + │ image registries, and image repositories. │ + │ sql Executes Snowflake query. │ + │ streamlit Manages Streamlit in Snowflake. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + + + ''' +# --- # name: test_help_messages[app.bundle] ''' @@ -1888,6 +1917,11 @@ │ --comment TEXT Comment for the │ │ compute pool. │ │ [default: None] │ + │ --if-not-exists Only apply this │ + │ operation if the │ + │ specified object │ + │ does not already │ + │ exist. │ │ --help -h Show this │ │ message and │ │ exit. │ @@ -2609,7 +2643,10 @@ │ [required] │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Options ────────────────────────────────────────────────────────────────────╮ - │ --help -h Show this message and exit. │ + │ --replace Replace this object if it already exists. │ + │ --if-not-exists Only apply this operation if the specified object │ + │ does not already exist. │ + │ --help -h Show this message and exit. 
│ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Connection configuration ───────────────────────────────────────────────────╮ │ --connection,--environment -c TEXT Name of the connection, as defined │ @@ -3160,6 +3197,11 @@ │ --comment TEXT Comment for the │ │ service. │ │ [default: None] │ + │ --if-not-exists Only apply this │ + │ operation if the │ + │ specified object │ + │ does not already │ + │ exist. │ │ --help -h Show this │ │ message and │ │ exit. │ diff --git a/tests/__snapshots__/test_snow_connector.ambr b/tests/__snapshots__/test_snow_connector.ambr index 9a4cf332e9..c923550308 100644 --- a/tests/__snapshots__/test_snow_connector.ambr +++ b/tests/__snapshots__/test_snow_connector.ambr @@ -19,7 +19,7 @@ use schema schemaValue; - create stage if not exists namedStageValue comment='deployments managed by snowcli'; + create stage if not exists namedStageValue comment='deployments managed by Snowflake CLI'; put file://file_pathValue @namedStageValuepathValue auto_compress=false parallel=4 overwrite=overwriteValue; @@ -45,7 +45,7 @@ use schema schemaValue; - create stage if not exists snow://embeddedStageValue comment='deployments managed by snowcli'; + create stage if not exists snow://embeddedStageValue comment='deployments managed by Snowflake CLI'; put file://file_pathValue snow://embeddedStageValuepathValue auto_compress=false parallel=4 overwrite=overwriteValue; diff --git a/tests/api/commands/__snapshots__/test_flags.ambr b/tests/api/commands/__snapshots__/test_flags.ambr new file mode 100644 index 0000000000..6ffec5bafb --- /dev/null +++ b/tests/api/commands/__snapshots__/test_flags.ambr @@ -0,0 +1,28 @@ +# serializer version: 1 +# name: test_format + ''' + Usage: default object stage list [OPTIONS] STAGE_NAME + Try 'default object stage list --help' for help. 
+ ╭─ Error ──────────────────────────────────────────────────────────────────────╮ + │ Invalid value for '--format': 'invalid_format' is not one of 'TABLE', │ + │ 'JSON'. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + + ''' +# --- +# name: test_mutually_exclusive_options_error + ''' + ╭─ Error ──────────────────────────────────────────────────────────────────────╮ + │ Options '--option2' and '--option1' are incompatible. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + + ''' +# --- +# name: test_overrideable_option_callback_with_mutually_exclusive + ''' + ╭─ Error ──────────────────────────────────────────────────────────────────────╮ + │ Options '--option2' and '--option1' are incompatible. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + + ''' +# --- diff --git a/tests/api/commands/test_flags.py b/tests/api/commands/test_flags.py index f2e1de910f..5c7405fc29 100644 --- a/tests/api/commands/test_flags.py +++ b/tests/api/commands/test_flags.py @@ -1,22 +1,25 @@ -from snowflake.cli.api.commands.flags import PLAIN_PASSWORD_MSG, PasswordOption +from unittest import mock +from unittest.mock import Mock, create_autospec, patch + +import click.core +import pytest +import typer +from snowflake.cli.api.commands.flags import ( + PLAIN_PASSWORD_MSG, + OverrideableOption, + PasswordOption, +) from typer import Typer +from typer.core import TyperOption from typer.testing import CliRunner -def test_format(runner): +def test_format(runner, snapshot): result = runner.invoke( ["object", "stage", "list", "stage_name", "--format", "invalid_format"] ) - assert result.output == ( - """Usage: default object stage list [OPTIONS] STAGE_NAME -Try 'default object stage list --help' for help. -╭─ Error ──────────────────────────────────────────────────────────────────────╮ -│ Invalid value for '--format': 'invalid_format' is not one of 'TABLE', │ -│ 'JSON'. 
│ -╰──────────────────────────────────────────────────────────────────────────────╯ -""" - ) + assert result.output == snapshot def test_password_flag(): @@ -30,3 +33,170 @@ def _(password: str = PasswordOption): result = runner.invoke(app, ["--password", "dummy"], catch_exceptions=False) assert result.exit_code == 0 assert PLAIN_PASSWORD_MSG in result.output + + +@patch("snowflake.cli.api.commands.flags.typer.Option") +def test_overrideable_option_returns_typer_option(mock_option): + mock_option_info = Mock(spec=typer.models.OptionInfo) + mock_option.return_value = mock_option_info + default = 1 + param_decls = ["--option"] + help_message = "help message" + + option = OverrideableOption(default, *param_decls, help=help_message)() + mock_option.assert_called_once_with(default, *param_decls, help=help_message) + assert option == mock_option_info + + +def test_overrideable_option_is_overrideable(): + original_param_decls = ("--option",) + original = OverrideableOption(1, *original_param_decls, help="original help") + + new_default = 2 + new_help = "new help" + modified = original(default=new_default, help=new_help) + + assert modified.default == new_default + assert modified.help == new_help + assert modified.param_decls == original_param_decls + + +_MUTEX_OPTION_1 = OverrideableOption( + False, "--option1", mutually_exclusive=["option_1", "option_2"] +) +_MUTEX_OPTION_2 = OverrideableOption( + False, "--option2", mutually_exclusive=["option_1", "option_2"] +) + + +@pytest.mark.parametrize("set1, set2", [(False, False), (False, True), (True, False)]) +def test_mutually_exclusive_options_no_error(set1, set2): + app = Typer() + + @app.command() + def _(option_1: bool = _MUTEX_OPTION_1(), option_2: bool = _MUTEX_OPTION_2()): + pass + + command = [] + if set1: + command.append("--option1") + if set2: + command.append("--option2") + runner = CliRunner() + result = runner.invoke(app, command) + assert result.exit_code == 0 + + +def 
test_mutually_exclusive_options_error(snapshot): + app = Typer() + + @app.command() + def _(option_1: bool = _MUTEX_OPTION_1(), option_2: bool = _MUTEX_OPTION_2()): + pass + + command = ["--option1", "--option2"] + runner = CliRunner() + result = runner.invoke(app, command) + assert result.exit_code == 1 + assert result.output == snapshot + + +def test_overrideable_option_callback_passthrough(): + def callback(value): + return value + 1 + + app = Typer() + + @app.command() + def _(option: int = OverrideableOption(..., "--option", callback=callback)()): + print(option) + + runner = CliRunner() + result = runner.invoke(app, ["--option", "0"]) + assert result.exit_code == 0 + assert result.output.strip() == "1" + + +def test_overrideable_option_callback_with_context(): + # tests that generated_callback will correctly map ctx and param arguments to the original callback + def callback(value, param: typer.CallbackParam, ctx: typer.Context): + assert isinstance(value, int) + assert isinstance(param, TyperOption) + assert isinstance(ctx, click.core.Context) + return value + + app = Typer() + + @app.command() + def _(option: int = OverrideableOption(..., "--option", callback=callback)()): + pass + + runner = CliRunner() + result = runner.invoke(app, ["--option", "0"]) + assert result.exit_code == 0 + + +class _InvalidCallbackSignatureNamespace: + # dummy functions for test_overrideable_option_invalid_callback_signature + + # too many parameters + @staticmethod + def callback1( + ctx: typer.Context, param: typer.CallbackParam, value1: int, value2: float + ): + pass + + # untyped Context and CallbackParam + @staticmethod + def callback2(ctx, param, value): + pass + + # multiple untyped values + @staticmethod + def callback3(ctx: typer.Context, value1, value2): + pass + + +@pytest.mark.parametrize( + "callback", + [ + _InvalidCallbackSignatureNamespace.callback1, + _InvalidCallbackSignatureNamespace.callback2, + _InvalidCallbackSignatureNamespace.callback3, + ], +) +def 
test_overrideable_option_invalid_callback_signature(callback): + invalid_callback_option = OverrideableOption(None, "--option", callback=callback) + with pytest.raises(OverrideableOption.InvalidCallbackSignature): + invalid_callback_option() + + +def test_overrideable_option_callback_with_mutually_exclusive(snapshot): + """ + Tests that is both 'callback' and 'mutually_exclusive' are passed to OverrideableOption, both are respected. This + is mainly for the rare use case where you are using 'mutually_exclusive' with non-flag options. + """ + + def passthrough(value): + return value + + mock_callback = create_autospec(passthrough) + app = Typer() + + @app.command() + def _( + option_1: int = _MUTEX_OPTION_1(default=None, callback=mock_callback), + option_2: int = _MUTEX_OPTION_2(default=None, callback=mock_callback), + ): + pass + + runner = CliRunner() + + # test that callback is called on the option values + runner.invoke(app, ["--option1", "1"]) + mock_callback.assert_has_calls([mock.call(value=1), mock.call(value=None)]) + + # test that we can't provide both options as non-falsey values without throwing error + result = runner.invoke(app, ["--option1", "1", "--option2", "2"]) + assert result.exit_code == 1 + assert result.output == snapshot diff --git a/tests/api/commands/test_snow_typer.py b/tests/api/commands/test_snow_typer.py index b0c7136762..08bb095dc3 100644 --- a/tests/api/commands/test_snow_typer.py +++ b/tests/api/commands/test_snow_typer.py @@ -147,21 +147,21 @@ def test_command_with_connection_options(cli, snapshot): assert result.output == snapshot -@mock.patch("snowflake.cli.api.commands.snow_typer.log_command_usage") +@mock.patch("snowflake.cli.app.telemetry.log_command_usage") def test_snow_typer_pre_execute_sends_telemetry(mock_log_command_usage, cli): result = cli(app_factory(SnowTyper))(["simple_cmd", "Norma"]) assert result.exit_code == 0 mock_log_command_usage.assert_called_once_with() 
-@mock.patch("snowflake.cli.api.commands.snow_typer.flush_telemetry") +@mock.patch("snowflake.cli.app.telemetry.flush_telemetry") def test_snow_typer_post_execute_sends_telemetry(mock_flush_telemetry, cli): result = cli(app_factory(SnowTyper))(["simple_cmd", "Norma"]) assert result.exit_code == 0 mock_flush_telemetry.assert_called_once_with() -@mock.patch("snowflake.cli.api.commands.snow_typer.print_result") +@mock.patch("snowflake.cli.app.printing.print_result") def test_snow_typer_result_callback_sends_telemetry(mock_print_result, cli): result = cli(app_factory(SnowTyper))(["simple_cmd", "Norma"]) assert result.exit_code == 0 diff --git a/tests/api/utils/test_naming_utils.py b/tests/api/utils/test_naming_utils.py new file mode 100644 index 0000000000..ea03f2cf78 --- /dev/null +++ b/tests/api/utils/test_naming_utils.py @@ -0,0 +1,15 @@ +import pytest +from snowflake.cli.api.utils.naming_utils import from_qualified_name + + +@pytest.mark.parametrize( + "qualified_name, expected", + [ + ("func(number, number)", ("func(number, number)", None, None)), + ("name", ("name", None, None)), + ("schema.name", ("name", "schema", None)), + ("db.schema.name", ("name", "schema", "db")), + ], +) +def test_from_fully_qualified_name(qualified_name, expected): + assert from_qualified_name(qualified_name) == expected diff --git a/tests/conftest.py b/tests/conftest.py index 1b07455cdd..c19e1610b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,6 @@ from snowflake.cli.api.console import cli_console from snowflake.cli.api.output.types import QueryResult from snowflake.cli.app import loggers -from snowflake.cli.app.cli_app import app pytest_plugins = ["tests.testing_utils.fixtures", "tests.project.fixtures"] @@ -72,7 +71,9 @@ def make_mock_cursor(mock_cursor): @pytest.fixture(name="faker_app") -def make_faker_app(_create_mock_cursor): +def make_faker_app(runner, _create_mock_cursor): + app = runner.app + @app.command("Faker") @with_output @global_options diff --git 
a/tests/nativeapp/test_artifacts.py b/tests/nativeapp/test_artifacts.py index 2f4bae98c8..51b311d8d6 100644 --- a/tests/nativeapp/test_artifacts.py +++ b/tests/nativeapp/test_artifacts.py @@ -39,10 +39,10 @@ def dir_structure(path: Path, prefix="") -> List[str]: @pytest.mark.parametrize("project_definition_files", ["napp_project_1"], indirect=True) def test_napp_project_1_artifacts(project_definition_files): project_root = project_definition_files[0].parent - native_app = load_project_definition(project_definition_files)["native_app"] + native_app = load_project_definition(project_definition_files).native_app - deploy_root = Path(project_root, native_app["deploy_root"]) - artifacts = [translate_artifact(item) for item in native_app["artifacts"]] + deploy_root = Path(project_root, native_app.deploy_root) + artifacts = [translate_artifact(item) for item in native_app.artifacts] build_bundle(project_root, deploy_root, artifacts) assert dir_structure(deploy_root) == [ diff --git a/tests/nativeapp/test_manager.py b/tests/nativeapp/test_manager.py index 2477f7a714..58ee2f8e37 100644 --- a/tests/nativeapp/test_manager.py +++ b/tests/nativeapp/test_manager.py @@ -48,7 +48,7 @@ def _get_na_manager(): dm = DefinitionManager() return NativeAppManager( - project_definition=dm.project_definition["native_app"], + project_definition=dm.project_definition.native_app, project_root=dm.project_root, ) diff --git a/tests/nativeapp/test_package_scripts.py b/tests/nativeapp/test_package_scripts.py index 24579b3610..77af1038a6 100644 --- a/tests/nativeapp/test_package_scripts.py +++ b/tests/nativeapp/test_package_scripts.py @@ -22,7 +22,7 @@ def _get_na_manager(working_dir): dm = DefinitionManager(working_dir) return NativeAppRunProcessor( - project_definition=dm.project_definition["native_app"], + project_definition=dm.project_definition.native_app, project_root=dm.project_root, ) diff --git a/tests/nativeapp/test_run_processor.py b/tests/nativeapp/test_run_processor.py index 
12f5b9c6b0..ee0fd55c9f 100644 --- a/tests/nativeapp/test_run_processor.py +++ b/tests/nativeapp/test_run_processor.py @@ -67,7 +67,7 @@ def _get_na_run_processor(): dm = DefinitionManager() return NativeAppRunProcessor( - project_definition=dm.project_definition["native_app"], + project_definition=dm.project_definition.native_app, project_root=dm.project_root, ) @@ -792,7 +792,7 @@ def test_create_dev_app_create_new_quoted( - setup.sql - app/README.md - src: app/streamlit/*.py - dest: ui/ + application: name: >- diff --git a/tests/nativeapp/test_teardown_processor.py b/tests/nativeapp/test_teardown_processor.py index c1b431c734..8d97e42443 100644 --- a/tests/nativeapp/test_teardown_processor.py +++ b/tests/nativeapp/test_teardown_processor.py @@ -38,7 +38,7 @@ def _get_na_teardown_processor(): dm = DefinitionManager() return NativeAppTeardownProcessor( - project_definition=dm.project_definition["native_app"], + project_definition=dm.project_definition.native_app, project_root=dm.project_root, ) diff --git a/tests/nativeapp/test_version_create_processor.py b/tests/nativeapp/test_version_create_processor.py index d5c7f79302..e980803621 100644 --- a/tests/nativeapp/test_version_create_processor.py +++ b/tests/nativeapp/test_version_create_processor.py @@ -40,7 +40,7 @@ def _get_version_create_processor(): dm = DefinitionManager() return NativeAppVersionCreateProcessor( - project_definition=dm.project_definition["native_app"], + project_definition=dm.project_definition.native_app, project_root=dm.project_root, ) diff --git a/tests/nativeapp/test_version_drop_processor.py b/tests/nativeapp/test_version_drop_processor.py index 846523a7e3..9d55191905 100644 --- a/tests/nativeapp/test_version_drop_processor.py +++ b/tests/nativeapp/test_version_drop_processor.py @@ -40,7 +40,7 @@ def _get_version_drop_processor(): dm = DefinitionManager() return NativeAppVersionDropProcessor( - project_definition=dm.project_definition["native_app"], + 
project_definition=dm.project_definition.native_app, project_root=dm.project_root, ) diff --git a/tests/project/test_config.py b/tests/project/test_config.py index 43684a7849..879a820557 100644 --- a/tests/project/test_config.py +++ b/tests/project/test_config.py @@ -4,29 +4,29 @@ from unittest.mock import PropertyMock import pytest +from pydantic import ValidationError from snowflake.cli.api.project.definition import ( generate_local_override_yml, load_project_definition, ) -from strictyaml import YAMLValidationError @pytest.mark.parametrize("project_definition_files", ["napp_project_1"], indirect=True) def test_napp_project_1(project_definition_files): project = load_project_definition(project_definition_files) - assert project["native_app"]["name"] == "myapp" - assert project["native_app"]["deploy_root"] == "output/deploy/" - assert project["native_app"]["package"]["role"] == "accountadmin" - assert project["native_app"]["application"]["name"] == "myapp_polly" - assert project["native_app"]["application"]["role"] == "myapp_consumer" - assert project["native_app"]["application"]["debug"] == True + assert project.native_app.name == "myapp" + assert project.native_app.deploy_root == "output/deploy/" + assert project.native_app.package.role == "accountadmin" + assert project.native_app.application.name == "myapp_polly" + assert project.native_app.application.role == "myapp_consumer" + assert project.native_app.application.debug == True @pytest.mark.parametrize("project_definition_files", ["minimal"], indirect=True) def test_na_minimal_project(project_definition_files: List[Path]): project = load_project_definition(project_definition_files) - assert project["native_app"]["name"] == "minimal" - assert project["native_app"]["artifacts"] == ["setup.sql", "README.md"] + assert project.native_app.name == "minimal" + assert project.native_app.artifacts == ["setup.sql", "README.md"] from os import getenv as original_getenv @@ -46,36 +46,38 @@ def mock_getenv(key: str, 
default: Optional[str] = None) -> Optional[str]: # a definition structure for these values but directly return defaults # in "getter" functions (higher-level data structures). local = generate_local_override_yml(project) - assert local["native_app"]["application"]["name"] == "minimal_jsmith" - assert local["native_app"]["application"]["role"] == "resolved_role" - assert ( - local["native_app"]["application"]["warehouse"] == "resolved_warehouse" - ) - assert local["native_app"]["application"]["debug"] == True - assert local["native_app"]["package"]["name"] == "minimal_pkg_jsmith" - assert local["native_app"]["package"]["role"] == "resolved_role" + assert local.native_app.application.name == "minimal_jsmith" + assert local.native_app.application.role == "resolved_role" + assert local.native_app.application.warehouse == "resolved_warehouse" + assert local.native_app.application.debug == True + assert local.native_app.package.name == "minimal_pkg_jsmith" + assert local.native_app.package.role == "resolved_role" @pytest.mark.parametrize("project_definition_files", ["underspecified"], indirect=True) def test_underspecified_project(project_definition_files): - with pytest.raises(YAMLValidationError) as exc_info: + with pytest.raises(ValidationError) as exc_info: load_project_definition(project_definition_files) - assert "required key(s) 'artifacts' not found" in str(exc_info.value) + assert ( + "Field required [type=missing, input_value={'name': 'underspecified'}, input_type=dict]" + in str(exc_info.value) + ) @pytest.mark.parametrize( "project_definition_files", ["no_definition_version"], indirect=True ) def test_fails_without_definition_version(project_definition_files): - with pytest.raises(YAMLValidationError) as exc_info: + with pytest.raises(ValidationError) as exc_info: load_project_definition(project_definition_files) - assert "required key(s) 'definition_version' not found" in str(exc_info.value) + assert "definition_version" in str(exc_info.value) 
@pytest.mark.parametrize("project_definition_files", ["unknown_fields"], indirect=True) -def test_accepts_unknown_fields(project_definition_files): - project = load_project_definition(project_definition_files) - assert project["native_app"]["name"] == "unknown_fields" - assert project["native_app"]["unknown_fields_accepted"] == True +def test_does_not_accept_unknown_fields(project_definition_files): + with pytest.raises(ValidationError) as e: + project = load_project_definition(project_definition_files) + + assert "Extra inputs are not permitted [type=extra_forbidden" in e.value.__str__() diff --git a/tests/project/test_pydantic_schemas.py b/tests/project/test_pydantic_schemas.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/snowpark/test_function.py b/tests/snowpark/test_function.py index 231d0f655f..75da0b0dd3 100644 --- a/tests/snowpark/test_function.py +++ b/tests/snowpark/test_function.py @@ -31,7 +31,7 @@ def test_deploy_function( assert result.exit_code == 0, result.output assert ctx.get_queries() == [ - "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by snowcli'", + "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by Snowflake CLI'", f"put file://{Path(project_dir).resolve()}/app.zip @MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT/my_snowpark_project" f" auto_compress=false parallel=4 overwrite=True", dedent( @@ -78,7 +78,7 @@ def test_deploy_function_with_external_access( assert result.exit_code == 0, result.output assert ctx.get_queries() == [ - "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by snowcli'", + "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by Snowflake CLI'", f"put file://{Path(project_dir).resolve()}/app.zip @MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT/my_snowpark_project" f" auto_compress=false parallel=4 overwrite=True", dedent( @@ -159,7 
+159,7 @@ def test_deploy_function_no_changes( } ] assert queries == [ - "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by snowcli'", + "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by Snowflake CLI'", f"put file://{Path(project_dir).resolve()}/app.zip @MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT/my_snowpark_project auto_compress=false parallel=4 overwrite=True", ] @@ -197,7 +197,7 @@ def test_deploy_function_needs_update_because_packages_changes( } ] assert queries == [ - "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by snowcli'", + "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by Snowflake CLI'", f"put file://{Path(project_dir).resolve()}/app.zip @MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT/my_snowpark_project auto_compress=false parallel=4 overwrite=True", dedent( """\ @@ -246,7 +246,7 @@ def test_deploy_function_needs_update_because_handler_changes( } ] assert queries == [ - "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by snowcli'", + "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by Snowflake CLI'", f"put file://{Path(project_dir).resolve()}/app.zip @MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT/my_snowpark_project" f" auto_compress=false parallel=4 overwrite=True", dedent( diff --git a/tests/snowpark/test_procedure.py b/tests/snowpark/test_procedure.py index efdd2b7e3c..b83f709f56 100644 --- a/tests/snowpark/test_procedure.py +++ b/tests/snowpark/test_procedure.py @@ -52,7 +52,7 @@ def test_deploy_procedure( ] ) assert ctx.get_queries() == [ - "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by snowcli'", + "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by Snowflake CLI'", f"put 
file://{Path(tmp).resolve()}/app.zip @MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT/my_snowpark_project auto_compress=false parallel=4 overwrite=True", dedent( """\ @@ -117,7 +117,7 @@ def test_deploy_procedure_with_external_access( ] ) assert ctx.get_queries() == [ - "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by snowcli'", + "create stage if not exists MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT comment='deployments managed by Snowflake CLI'", f"put file://{Path(project_dir).resolve()}/app.zip @MOCKDATABASE.MOCKSCHEMA.DEV_DEPLOYMENT/my_snowpark_project" f" auto_compress=false parallel=4 overwrite=True", dedent( diff --git a/tests/spcs/__snapshots__/test_image_repository.ambr b/tests/spcs/__snapshots__/test_image_repository.ambr new file mode 100644 index 0000000000..a7de666b3e --- /dev/null +++ b/tests/spcs/__snapshots__/test_image_repository.ambr @@ -0,0 +1,19 @@ +# serializer version: 1 +# name: test_create_cli + ''' + +-----------------------------------------------------------+ + | key | value | + |--------+--------------------------------------------------| + | status | Image Repository TEST_REPO successfully created. | + +-----------------------------------------------------------+ + + ''' +# --- +# name: test_create_cli_replace_and_if_not_exists_fails + ''' + ╭─ Error ──────────────────────────────────────────────────────────────────────╮ + │ Options '--if-not-exists' and '--replace' are incompatible. 
│ + ╰──────────────────────────────────────────────────────────────────────────────╯ + + ''' +# --- diff --git a/tests/spcs/test_compute_pool.py b/tests/spcs/test_compute_pool.py index 440a0278a7..63104becf3 100644 --- a/tests/spcs/test_compute_pool.py +++ b/tests/spcs/test_compute_pool.py @@ -41,6 +41,7 @@ def test_create(mock_execute_query): initially_suspended=initially_suspended, auto_suspend_secs=auto_suspend_secs, comment=comment, + if_not_exists=False, ) expected_query = " ".join( [ @@ -81,6 +82,7 @@ def test_create_pool_cli_defaults(mock_create, runner): initially_suspended=False, auto_suspend_secs=3600, comment=None, + if_not_exists=False, ) @@ -104,6 +106,7 @@ def test_create_pool_cli(mock_create, runner): "7200", "--comment", "this is a test", + "--if-not-exists", ] ) assert result.exit_code == 0, result.output @@ -116,6 +119,7 @@ def test_create_pool_cli(mock_create, runner): initially_suspended=True, auto_suspend_secs=7200, comment=to_string_literal("this is a test"), + if_not_exists=True, ) @@ -123,8 +127,8 @@ def test_create_pool_cli(mock_create, runner): "snowflake.cli.plugins.spcs.compute_pool.manager.ComputePoolManager._execute_query" ) @patch("snowflake.cli.plugins.spcs.compute_pool.manager.handle_object_already_exists") -def test_create_repository_already_exists(mock_handle, mock_execute): - pool_name = "test_object" +def test_create_compute_pool_already_exists(mock_handle, mock_execute): + pool_name = "test_pool" mock_execute.side_effect = SPCS_OBJECT_EXISTS_ERROR ComputePoolManager().create( pool_name=pool_name, @@ -135,12 +139,46 @@ def test_create_repository_already_exists(mock_handle, mock_execute): initially_suspended=True, auto_suspend_secs=7200, comment=to_string_literal("this is a test"), + if_not_exists=False, ) mock_handle.assert_called_once_with( SPCS_OBJECT_EXISTS_ERROR, ObjectType.COMPUTE_POOL, pool_name ) +@patch( + "snowflake.cli.plugins.spcs.compute_pool.manager.ComputePoolManager._execute_query" +) +def 
test_create_compute_pool_if_not_exists(mock_execute_query): + cursor = Mock(spec=SnowflakeCursor) + mock_execute_query.return_value = cursor + result = ComputePoolManager().create( + pool_name="test_pool", + min_nodes=1, + max_nodes=1, + instance_family="test_family", + auto_resume=True, + initially_suspended=False, + auto_suspend_secs=3600, + comment=None, + if_not_exists=True, + ) + expected_query = " ".join( + [ + "CREATE COMPUTE POOL IF NOT EXISTS test_pool", + "MIN_NODES = 1", + "MAX_NODES = 1", + "INSTANCE_FAMILY = test_family", + "AUTO_RESUME = True", + "INITIALLY_SUSPENDED = False", + "AUTO_SUSPEND_SECS = 3600", + ] + ) + actual_query = " ".join(mock_execute_query.mock_calls[0].args[0].split()) + assert expected_query == actual_query + assert result == cursor + + @patch( "snowflake.cli.plugins.spcs.compute_pool.manager.ComputePoolManager._execute_query" ) diff --git a/tests/spcs/test_image_repository.py b/tests/spcs/test_image_repository.py index 9332c8379a..21b140f297 100644 --- a/tests/spcs/test_image_repository.py +++ b/tests/spcs/test_image_repository.py @@ -4,7 +4,6 @@ from unittest.mock import Mock import pytest -from click import ClickException from snowflake.cli.api.constants import ObjectType from snowflake.cli.api.exceptions import ( DatabaseNotProvidedError, @@ -12,6 +11,7 @@ ) from snowflake.cli.plugins.spcs.image_repository.manager import ImageRepositoryManager from snowflake.connector.cursor import SnowflakeCursor +from snowflake.connector.errors import ProgrammingError from tests.spcs.test_common import SPCS_OBJECT_EXISTS_ERROR @@ -44,37 +44,68 @@ ] +@pytest.mark.parametrize( + "replace, if_not_exists, expected_query", + [ + (False, False, "create image repository test_repo"), + (False, True, "create image repository if not exists test_repo"), + (True, False, "create or replace image repository test_repo"), + # (True, True) is an invalid case as OR REPLACE and IF NOT EXISTS are mutually exclusive. 
+ ], +) @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager._execute_schema_query" ) -def test_create( - mock_execute, -): +def test_create(mock_execute, replace, if_not_exists, expected_query): repo_name = "test_repo" cursor = Mock(spec=SnowflakeCursor) mock_execute.return_value = cursor - result = ImageRepositoryManager().create(name=repo_name) - expected_query = "create image repository test_repo" - mock_execute.assert_called_once_with(expected_query) + result = ImageRepositoryManager().create( + name=repo_name, replace=replace, if_not_exists=if_not_exists + ) + mock_execute.assert_called_once_with(expected_query, name=repo_name) assert result == cursor +def test_create_replace_and_if_not_exist(): + with pytest.raises(ValueError) as e: + ImageRepositoryManager().create( + name="test_repo", replace=True, if_not_exists=True + ) + assert "mutually exclusive" in str(e.value) + + @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.create" ) -def test_create_cli(mock_create, mock_cursor, runner): +def test_create_cli(mock_create, mock_cursor, runner, snapshot): repo_name = "test_repo" cursor = mock_cursor( rows=[[f"Image Repository {repo_name.upper()} successfully created."]], columns=["status"], ) mock_create.return_value = cursor - result = runner.invoke(["spcs", "image-repository", "create", repo_name]) - mock_create.assert_called_once_with(name=repo_name) - assert result.exit_code == 0, result.output - assert ( - f"Image Repository {repo_name.upper()} successfully created." 
in result.output + command = ["spcs", "image-repository", "create", repo_name] + result = runner.invoke(command) + mock_create.assert_called_once_with( + name=repo_name, replace=False, if_not_exists=False ) + assert result.exit_code == 0, result.output + assert result.output == snapshot + + +def test_create_cli_replace_and_if_not_exists_fails(runner, snapshot): + command = [ + "spcs", + "image-repository", + "create", + "test_repo", + "--replace", + "--if-not-exists", + ] + result = runner.invoke(command) + assert result.exit_code == 1 + assert result.output == snapshot @mock.patch( @@ -86,9 +117,12 @@ def test_create_cli(mock_create, mock_cursor, runner): def test_create_repository_already_exists(mock_handle, mock_execute): repo_name = "test_object" mock_execute.side_effect = SPCS_OBJECT_EXISTS_ERROR - ImageRepositoryManager().create(repo_name) + ImageRepositoryManager().create(repo_name, replace=False, if_not_exists=False) mock_handle.assert_called_once_with( - SPCS_OBJECT_EXISTS_ERROR, ObjectType.IMAGE_REPOSITORY, repo_name + SPCS_OBJECT_EXISTS_ERROR, + ObjectType.IMAGE_REPOSITORY, + repo_name, + replace_available=True, ) @@ -191,13 +225,10 @@ def test_get_repository_url_cli(mock_url, runner): assert result.output.strip() == repo_url -@mock.patch( - "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.check_database_and_schema" -) @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.show_specific_object" ) -def test_get_repository_url(mock_get_row, mock_check_database_and_schema): +def test_get_repository_url(mock_get_row): expected_row = MOCK_ROWS_DICT[0] mock_get_row.return_value = expected_row result = ImageRepositoryManager().get_repository_url(repo_name="IMAGES") @@ -209,13 +240,10 @@ def test_get_repository_url(mock_get_row, mock_check_database_and_schema): assert result == f"https://{expected_row['repository_url']}" -@mock.patch( - 
"snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.check_database_and_schema" -) @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.show_specific_object" ) -def test_get_repository_url_no_scheme(mock_get_row, mock_check_database_and_schema): +def test_get_repository_url_no_scheme(mock_get_row): expected_row = MOCK_ROWS_DICT[0] mock_get_row.return_value = expected_row result = ImageRepositoryManager().get_repository_url( @@ -229,26 +257,21 @@ def test_get_repository_url_no_scheme(mock_get_row, mock_check_database_and_sche assert result == expected_row["repository_url"] -@mock.patch( - "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.check_database_and_schema" -) @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager._conn" ) @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.show_specific_object" ) -def test_get_repository_url_no_repo_found( - mock_get_row, mock_conn, mock_check_database_and_schema -): +def test_get_repository_url_no_repo_found(mock_get_row, mock_conn): mock_get_row.return_value = None mock_conn.database = "DB" mock_conn.schema = "SCHEMA" - with pytest.raises(ClickException) as e: + with pytest.raises(ProgrammingError) as e: ImageRepositoryManager().get_repository_url(repo_name="IMAGES") assert ( - e.value.message - == "Image repository 'IMAGES' does not exist in database 'DB' and schema 'SCHEMA' or not authorized." + e.value.msg + == "Image repository 'DB.SCHEMA.IMAGES' does not exist or not authorized." 
) mock_get_row.assert_called_once_with( "image repositories", "IMAGES", check_schema=True @@ -258,17 +281,17 @@ def test_get_repository_url_no_repo_found( @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager._conn" ) -def test_get_repository_url_no_database(mock_conn): +def test_get_repository_url_no_database_provided(mock_conn): mock_conn.database = None with pytest.raises(DatabaseNotProvidedError): - ImageRepositoryManager().get_repository_url("test_repo") + ImageRepositoryManager().get_repository_url("IMAGES") @mock.patch( "snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager._conn" ) -@mock.patch("snowflake.cli.api.sql_execution.SqlExecutionMixin.check_database_exists") -def test_get_repository_url_no_schema(mock_check_database_exists, mock_conn): +def test_get_repository_url_no_schema_provided(mock_conn): + mock_conn.database = "DB" mock_conn.schema = None with pytest.raises(SchemaNotProvidedError): - ImageRepositoryManager().get_repository_url("test_repo") + ImageRepositoryManager().get_repository_url("IMAGES") diff --git a/tests/spcs/test_jobs.py b/tests/spcs/test_jobs.py index b714fc2f8e..4658a03979 100644 --- a/tests/spcs/test_jobs.py +++ b/tests/spcs/test_jobs.py @@ -2,7 +2,10 @@ from tempfile import TemporaryDirectory from unittest import mock +import pytest + +@pytest.mark.skip("Snowpark Container Services Job not supported.") @mock.patch("snowflake.connector.connect") def test_create_job(mock_connector, runner, mock_ctx): ctx = mock_ctx() @@ -40,6 +43,7 @@ def test_create_job(mock_connector, runner, mock_ctx): ) +@pytest.mark.skip("Snowpark Container Services Job not supported.") @mock.patch("snowflake.connector.connect") def test_job_status(mock_connector, runner, mock_ctx): ctx = mock_ctx() @@ -51,6 +55,7 @@ def test_job_status(mock_connector, runner, mock_ctx): assert ctx.get_query() == "CALL SYSTEM$GET_JOB_STATUS('jobName')" +@pytest.mark.skip("Snowpark Container Services Job not 
supported.") @mock.patch("snowflake.connector.connect") def test_job_logs(mock_connector, runner, mock_ctx): ctx = mock_ctx() diff --git a/tests/spcs/test_services.py b/tests/spcs/test_services.py index 24919a3dff..e7ee202e75 100644 --- a/tests/spcs/test_services.py +++ b/tests/spcs/test_services.py @@ -72,6 +72,7 @@ def test_create_service(mock_execute_query, other_directory): query_warehouse=query_warehouse, tags=tags, comment=comment, + if_not_exists=False, ) expected_query = " ".join( [ @@ -120,6 +121,7 @@ def test_create_service_cli_defaults(mock_create, other_directory, runner): query_warehouse=None, tags=[], comment=None, + if_not_exists=False, ) @@ -155,6 +157,7 @@ def test_create_service_cli(mock_create, other_directory, runner): '"$trange name"=normal value', "--comment", "this is a test", + "--if-not-exists", ] ) assert result.exit_code == 0, result.output @@ -169,6 +172,7 @@ def test_create_service_cli(mock_create, other_directory, runner): query_warehouse="test_warehouse", tags=[Tag("name", "value"), Tag('"$trange name"', "normal value")], comment=to_string_literal("this is a test"), + if_not_exists=True, ) @@ -195,14 +199,15 @@ def test_create_service_with_invalid_spec(mock_read_yaml): query_warehouse=query_warehouse, tags=tags, comment=comment, + if_not_exists=False, ) @patch("snowflake.cli.plugins.spcs.services.manager.ServiceManager._read_yaml") @patch("snowflake.cli.plugins.spcs.services.manager.ServiceManager._execute_query") @patch("snowflake.cli.plugins.spcs.services.manager.handle_object_already_exists") -def test_create_repository_already_exists(mock_handle, mock_execute, mock_read_yaml): - service_name = "test_object" +def test_create_service_already_exists(mock_handle, mock_execute, mock_read_yaml): + service_name = "test_service" compute_pool = "test_pool" spec_path = "/path/to/spec.yaml" min_instances = 42 @@ -221,12 +226,47 @@ def test_create_repository_already_exists(mock_handle, mock_execute, mock_read_y 
query_warehouse=query_warehouse, tags=tags, comment=comment, + if_not_exists=False, ) mock_handle.assert_called_once_with( SPCS_OBJECT_EXISTS_ERROR, ObjectType.SERVICE, service_name ) +@patch("snowflake.cli.plugins.spcs.services.manager.ServiceManager._execute_query") +def test_create_service_if_not_exists(mock_execute_query, other_directory): + cursor = Mock(spec=SnowflakeCursor) + mock_execute_query.return_value = cursor + tmp_dir = Path(other_directory) + spec_path = tmp_dir / "spec.yml" + spec_path.write_text(SPEC_CONTENT) + result = ServiceManager().create( + service_name="test_service", + compute_pool="test_pool", + spec_path=spec_path, + min_instances=1, + max_instances=1, + auto_resume=True, + external_access_integrations=None, + query_warehouse=None, + tags=None, + comment=None, + if_not_exists=True, + ) + expected_query = " ".join( + [ + "CREATE SERVICE IF NOT EXISTS test_service", + "IN COMPUTE POOL test_pool", + f"FROM SPECIFICATION $$ {json.dumps(SPEC_DICT)} $$", + "MIN_INSTANCES = 1 MAX_INSTANCES = 1", + "AUTO_RESUME = True", + ] + ) + actual_query = " ".join(mock_execute_query.mock_calls[0].args[0].split()) + assert expected_query == actual_query + assert result == cursor + + @patch("snowflake.cli.plugins.spcs.services.manager.ServiceManager._execute_query") def test_status(mock_execute_query): service_name = "test_service" diff --git a/tests/streamlit/test_config.py b/tests/streamlit/test_config.py index a1253f173e..86d51c1b7a 100644 --- a/tests/streamlit/test_config.py +++ b/tests/streamlit/test_config.py @@ -30,4 +30,4 @@ def test_load_project_definition(test_files, expected): result = load_project_definition(test_files) - assert expected in result["streamlit"]["additional_source_files"] + assert expected in result.streamlit.additional_source_files diff --git a/tests/test_help_messages.py b/tests/test_help_messages.py index fcafaae5b3..477679b622 100644 --- a/tests/test_help_messages.py +++ b/tests/test_help_messages.py @@ -18,6 +18,7 @@ def 
_iter_through_commands(command, path): yield from _iter_through_commands(subcommand, path) path.pop() + yield [] # "snow" with no commands builtin_plugins = load_only_builtin_command_plugins() for plugin in builtin_plugins: spec = plugin.command_spec diff --git a/tests/test_sql.py b/tests/test_sql.py index 95a70b24e4..f433dc8bec 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -4,8 +4,10 @@ import pytest from snowflake.cli.api.exceptions import SnowflakeSQLExecutionError +from snowflake.cli.api.project.util import identifier_to_show_like_pattern from snowflake.cli.api.sql_execution import SqlExecutionMixin from snowflake.connector.cursor import DictCursor +from snowflake.connector.errors import ProgrammingError from tests.testing_utils.result_assertions import assert_that_result_is_usage_error @@ -169,3 +171,80 @@ def test_show_specific_object_sql_execution_error(mock_execute): mock_execute.assert_called_once_with( r"show objects like 'EXAMPLE\\_ID'", cursor_class=DictCursor ) + + +@pytest.mark.parametrize( + "name, name_split, expected_name, expected_in_clause", + [ + ( + "func(number, number)", + ("func(number, number)", None, None), + "func(number, number)", + None, + ), + ("name", ("name", None, None), "name", None), + ("schema.name", ("name", "schema", None), "name", "in schema schema"), + ("db.schema.name", ("name", "schema", "db"), "name", "in schema db.schema"), + ], +) +@mock.patch("snowflake.cli.api.sql_execution.from_qualified_name") +def test_qualified_name_to_in_clause( + mock_from_qualified_name, name, name_split, expected_name, expected_in_clause +): + mock_from_qualified_name.return_value = name_split + assert SqlExecutionMixin._qualified_name_to_in_clause(name) == ( # noqa: SLF001 + expected_name, + expected_in_clause, + ) + mock_from_qualified_name.assert_called_once_with(name) + + +@mock.patch("snowflake.cli.plugins.sql.manager.SqlExecutionMixin._execute_query") +@mock.patch( + 
"snowflake.cli.api.sql_execution.SqlExecutionMixin._qualified_name_to_in_clause" +) +def test_show_specific_object_qualified_name( + mock_qualified_name_to_in_clause, mock_execute_query, mock_cursor +): + name = "db.schema.obj" + unqualified_name = "obj" + name_in_clause = "in schema db.schema" + mock_columns = ["name", "created_on"] + mock_row_dict = {c: r for c, r in zip(mock_columns, [unqualified_name, "date"])} + cursor = mock_cursor(rows=[mock_row_dict], columns=mock_columns) + mock_execute_query.return_value = cursor + + mock_qualified_name_to_in_clause.return_value = (unqualified_name, name_in_clause) + SqlExecutionMixin().show_specific_object("objects", name) + mock_execute_query.assert_called_once_with( + f"show objects like {identifier_to_show_like_pattern(unqualified_name)} {name_in_clause}", + cursor_class=DictCursor, + ) + + +@mock.patch( + "snowflake.cli.api.sql_execution.SqlExecutionMixin._qualified_name_to_in_clause" +) +def test_show_specific_object_qualified_name_and_in_clause_error( + mock_qualified_name_to_in_clause, +): + object_name = "db.schema.name" + mock_qualified_name_to_in_clause.return_value = ("name", "in schema db.schema") + with pytest.raises(SqlExecutionMixin.InClauseWithQualifiedNameError): + SqlExecutionMixin().show_specific_object( + "objects", object_name, in_clause="in database db" + ) + mock_qualified_name_to_in_clause.assert_called_once_with(object_name) + + +@mock.patch("snowflake.cli.api.sql_execution.SqlExecutionMixin._execute_query") +def test_show_specific_object_multiple_rows(mock_execute_query): + cursor = mock.Mock(spec=DictCursor) + cursor.rowcount = 2 + mock_execute_query.return_value = cursor + with pytest.raises(ProgrammingError) as err: + SqlExecutionMixin().show_specific_object("objects", "name", name_col="id") + assert "Received multiple rows" in err.value.msg + mock_execute_query.assert_called_once_with( + r"show objects like 'NAME'", cursor_class=DictCursor + ) diff --git a/tests/testing_utils/fixtures.py 
b/tests/testing_utils/fixtures.py index fe643b5b59..5c30a0bb23 100644 --- a/tests/testing_utils/fixtures.py +++ b/tests/testing_utils/fixtures.py @@ -13,6 +13,7 @@ import pytest import strictyaml from snowflake.cli.api.project.definition import merge_left +from snowflake.cli.app.cli_app import app_factory from snowflake.connector.cursor import SnowflakeCursor from snowflake.connector.errors import ProgrammingError from strictyaml import as_document @@ -80,7 +81,7 @@ def dot_packages_directory(temp_dir): @pytest.fixture() def mock_ctx(mock_cursor): - return lambda cursor=mock_cursor(["row"], []): MockConnectionCtx(cursor) + yield lambda cursor=mock_cursor(["row"], []): MockConnectionCtx(cursor) class MockConnectionCtx(mock.MagicMock): @@ -200,9 +201,8 @@ def package_file(): @pytest.fixture(scope="function") def runner(test_snowcli_config): - from snowflake.cli.app.cli_app import app - - return SnowCLIRunner(app, test_snowcli_config) + app = app_factory() + yield SnowCLIRunner(app, test_snowcli_config) @pytest.fixture diff --git a/tests_integration/conftest.py b/tests_integration/conftest.py index 404b4eba87..633c7d4046 100644 --- a/tests_integration/conftest.py +++ b/tests_integration/conftest.py @@ -14,7 +14,7 @@ import strictyaml from snowflake.cli.api.cli_global_context import cli_context_manager from snowflake.cli.api.project.definition import merge_left -from snowflake.cli.app.cli_app import app +from snowflake.cli.app.cli_app import app_factory from strictyaml import as_document from typer import Typer from typer.testing import CliRunner @@ -113,7 +113,8 @@ def invoke_with_connection( @pytest.fixture def runner(test_snowcli_config_provider): - return SnowCLIRunner(app, test_snowcli_config_provider) + app = app_factory() + yield SnowCLIRunner(app, test_snowcli_config_provider) class QueryResultJsonEncoderError(RuntimeError): diff --git a/tests_integration/test_external_plugins.py b/tests_integration/test_external_plugins.py index c7112a36c0..cdbc88d83b 100644 
--- a/tests_integration/test_external_plugins.py +++ b/tests_integration/test_external_plugins.py @@ -93,7 +93,10 @@ def test_loading_of_installed_plugins_if_all_plugins_enabled( @pytest.mark.integration def test_loading_of_installed_plugins_if_only_one_plugin_is_enabled( - runner, install_plugins, caplog, reset_command_registration_state + runner, + install_plugins, + caplog, + reset_command_registration_state, ): runner.use_config("config_with_enabled_only_one_external_plugin.toml") @@ -111,8 +114,18 @@ def test_loading_of_installed_plugins_if_only_one_plugin_is_enabled( @pytest.mark.integration +@pytest.mark.parametrize( + "config_value", + ( + pytest.param("1", id="integer as value"), + pytest.param('"True"', id="string as value"), + ), +) def test_enabled_value_must_be_boolean( - runner, snowflake_home, reset_command_registration_state + config_value, + runner, + snowflake_home, + reset_command_registration_state, ): def _use_config_with_value(value): config = Path(snowflake_home) / "config.toml" @@ -123,19 +136,18 @@ def _use_config_with_value(value): ) runner.use_config(config) - for value in ["1", '"True"']: - _use_config_with_value(value) - result = runner.invoke_with_config(["--help"]) - output = result.output.splitlines() - assert all( - [ - "Error" in output[0], - 'Invalid plugin configuration. [multilingual-hello]: "enabled" must be a' - in output[1], - "boolean" in output[2], - ] - ) - reset_command_registration_state() + _use_config_with_value(config_value) + result = runner.invoke_with_config(("--help,")) + + first, second, third, *_ = result.output.splitlines() + assert "Error" in first, first + assert ( + 'Invalid plugin configuration. 
[multilingual-hello]: "enabled" must be a' + in second + ), second + assert "boolean" in third, third + + reset_command_registration_state() def _assert_that_no_error_logs(caplog): From 6f8930e8da95740ec13218c3c3b074a173312760 Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Wed, 6 Mar 2024 15:42:51 +0100 Subject: [PATCH 02/16] Fixes after review --- src/snowflake/cli/api/project/errors.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 src/snowflake/cli/api/project/errors.py diff --git a/src/snowflake/cli/api/project/errors.py b/src/snowflake/cli/api/project/errors.py new file mode 100644 index 0000000000..4639ef0f50 --- /dev/null +++ b/src/snowflake/cli/api/project/errors.py @@ -0,0 +1,3 @@ + +class SchemaValidationError(Exception): + pass \ No newline at end of file From 67dd5f5142f78ecd841d02b14a0f663840ca766b Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Wed, 6 Mar 2024 17:03:47 +0100 Subject: [PATCH 03/16] Implemented error class --- src/snowflake/cli/api/project/errors.py | 23 +++++++++++++++- .../api/project/schemas/updatable_model.py | 9 ++++++- tests/project/test_config.py | 26 ++++++++++++------- 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/src/snowflake/cli/api/project/errors.py b/src/snowflake/cli/api/project/errors.py index 4639ef0f50..ba6844b510 100644 --- a/src/snowflake/cli/api/project/errors.py +++ b/src/snowflake/cli/api/project/errors.py @@ -1,3 +1,24 @@ +from textwrap import dedent + +from pydantic import ValidationError + class SchemaValidationError(Exception): - pass \ No newline at end of file + generic_message = "For field {loc} you provided '{loc}'. This caused: {msg}" + message_templates = { + "string_type": "{msg} for field '{loc}', you provided '{input}'", + "extra_forbidden": "{msg}. 
You provided field '{loc}' with value '{input}' that is not present in the schema", + "missing": "Your project definition is missing following fields: {loc}", + } + + def __init__(self, error: ValidationError): + errors = error.errors() + message = f"During evaluation of {error.title} schema following errors were encoutered:\n" + message += "\n".join( + [ + self.message_templates.get(e["type"], self.generic_message).format(**e) + for e in errors + ] + ) + + super().__init__(dedent(message)) diff --git a/src/snowflake/cli/api/project/schemas/updatable_model.py b/src/snowflake/cli/api/project/schemas/updatable_model.py index 12d8551a9a..837e625d53 100644 --- a/src/snowflake/cli/api/project/schemas/updatable_model.py +++ b/src/snowflake/cli/api/project/schemas/updatable_model.py @@ -1,12 +1,19 @@ from typing import Any, Dict -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, ValidationError +from snowflake.cli.api.project.errors import SchemaValidationError from snowflake.cli.api.project.util import IDENTIFIER_NO_LENGTH class UpdatableModel(BaseModel): model_config = ConfigDict(validate_assignment=True, extra="forbid") + def __init__(self, *args, **kwargs): + try: + super().__init__(**kwargs) + except ValidationError as e: + raise SchemaValidationError(e) + def update_from_dict( self, update_values: Dict[str, Any] ): # this method works wrong for optional fields set to None diff --git a/tests/project/test_config.py b/tests/project/test_config.py index 879a820557..4d06803ed1 100644 --- a/tests/project/test_config.py +++ b/tests/project/test_config.py @@ -4,11 +4,11 @@ from unittest.mock import PropertyMock import pytest -from pydantic import ValidationError from snowflake.cli.api.project.definition import ( generate_local_override_yml, load_project_definition, ) +from snowflake.cli.api.project.errors import SchemaValidationError @pytest.mark.parametrize("project_definition_files", ["napp_project_1"], indirect=True) @@ 
-56,12 +56,12 @@ def mock_getenv(key: str, default: Optional[str] = None) -> Optional[str]: @pytest.mark.parametrize("project_definition_files", ["underspecified"], indirect=True) def test_underspecified_project(project_definition_files): - with pytest.raises(ValidationError) as exc_info: + with pytest.raises(SchemaValidationError) as exc_info: load_project_definition(project_definition_files) - assert ( - "Field required [type=missing, input_value={'name': 'underspecified'}, input_type=dict]" - in str(exc_info.value) + assert "NativeApp schema" in str(exc_info) + assert "Your project definition is missing following fields: ('artifacts',)" in str( + exc_info.value ) @@ -69,15 +69,23 @@ def test_underspecified_project(project_definition_files): "project_definition_files", ["no_definition_version"], indirect=True ) def test_fails_without_definition_version(project_definition_files): - with pytest.raises(ValidationError) as exc_info: + with pytest.raises(SchemaValidationError) as exc_info: load_project_definition(project_definition_files) - assert "definition_version" in str(exc_info.value) + assert "ProjectDefinition" in str(exc_info) + assert ( + "Your project definition is missing following fields: ('definition_version',)" + in str(exc_info.value) + ) @pytest.mark.parametrize("project_definition_files", ["unknown_fields"], indirect=True) def test_does_not_accept_unknown_fields(project_definition_files): - with pytest.raises(ValidationError) as e: + with pytest.raises(SchemaValidationError) as exc_info: project = load_project_definition(project_definition_files) - assert "Extra inputs are not permitted [type=extra_forbidden" in e.value.__str__() + assert "NativeApp schema" in str(exc_info) + assert ( + "You provided field '('unknown_fields_accepted',)' with value 'true' that is not present in the schema" + in str(exc_info) + ) From f46a54bca06d3119d7369be70c3df1036573448a Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Wed, 6 Mar 2024 17:12:54 +0100 Subject: 
[PATCH 04/16] Fixes --- src/snowflake/cli/api/project/schemas/snowpark/callable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/cli/api/project/schemas/snowpark/callable.py b/src/snowflake/cli/api/project/schemas/snowpark/callable.py index ccc36995bc..6da1e235de 100644 --- a/src/snowflake/cli/api/project/schemas/snowpark/callable.py +++ b/src/snowflake/cli/api/project/schemas/snowpark/callable.py @@ -31,7 +31,7 @@ class Callable(UpdatableModel): signature: Union[str, List[Argument]] = Field( title="The signature parameter describes consecutive arguments passed to the object" ) - runtime: Optional[str | float] = Field( + runtime: Optional[Union[str, float]] = Field( title="Python version to use when executing ", default=None ) external_access_integrations: Optional[List[str]] = Field( From 04459c85d43cd9b0e694c2acd04fcebfe3d4754a Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Wed, 6 Mar 2024 17:19:19 +0100 Subject: [PATCH 05/16] Fixes --- src/snowflake/cli/api/project/schemas/snowpark/callable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/cli/api/project/schemas/snowpark/callable.py b/src/snowflake/cli/api/project/schemas/snowpark/callable.py index 6da1e235de..f7a5ceae18 100644 --- a/src/snowflake/cli/api/project/schemas/snowpark/callable.py +++ b/src/snowflake/cli/api/project/schemas/snowpark/callable.py @@ -49,7 +49,7 @@ class Callable(UpdatableModel): @field_validator("runtime") @classmethod - def convert_runtime(cls, runtime_input: str | float) -> str: + def convert_runtime(cls, runtime_input: Union[str,float]) -> str: if isinstance(runtime_input, float): return str(runtime_input) return runtime_input From a2c84c1d4fe1b5dc1272ea5120b05a5cfea430eb Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Wed, 6 Mar 2024 17:24:10 +0100 Subject: [PATCH 06/16] Fixes --- RELEASE-NOTES.md | 1 + src/snowflake/cli/api/project/schemas/snowpark/callable.py | 2 +- 2 files changed, 2 insertions(+), 1 
deletion(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index dc9d351f31..966f5778fb 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -11,6 +11,7 @@ * Adding `--image-name` option for image name argument in `spcs image-repository list-tags` for consistency with other commands. * Fixed errors during `spcs image-registry login` not being formatted correctly. * Project definition no longer accept extra fields. Any extra field will cause an error. +* Project definition no longer accept extra fields. Any extra field will cause an error. # v2.1.0 diff --git a/src/snowflake/cli/api/project/schemas/snowpark/callable.py b/src/snowflake/cli/api/project/schemas/snowpark/callable.py index f7a5ceae18..29b3ee4dc1 100644 --- a/src/snowflake/cli/api/project/schemas/snowpark/callable.py +++ b/src/snowflake/cli/api/project/schemas/snowpark/callable.py @@ -49,7 +49,7 @@ class Callable(UpdatableModel): @field_validator("runtime") @classmethod - def convert_runtime(cls, runtime_input: Union[str,float]) -> str: + def convert_runtime(cls, runtime_input: Union[str, float]) -> str: if isinstance(runtime_input, float): return str(runtime_input) return runtime_input From fd7a90e7801a18a433725517e8e2e2d39ddba70f Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Wed, 6 Mar 2024 17:43:49 +0100 Subject: [PATCH 07/16] Fixes --- .../schemas/native_app/{path_maping.py => path_mapping.py} | 0 src/snowflake/cli/api/project/schemas/project_definition.py | 2 +- tests/nativeapp/test_run_processor.py | 5 +++-- 3 files changed, 4 insertions(+), 3 deletions(-) rename src/snowflake/cli/api/project/schemas/native_app/{path_maping.py => path_mapping.py} (100%) diff --git a/src/snowflake/cli/api/project/schemas/native_app/path_maping.py b/src/snowflake/cli/api/project/schemas/native_app/path_mapping.py similarity index 100% rename from src/snowflake/cli/api/project/schemas/native_app/path_maping.py rename to src/snowflake/cli/api/project/schemas/native_app/path_mapping.py diff --git 
a/src/snowflake/cli/api/project/schemas/project_definition.py b/src/snowflake/cli/api/project/schemas/project_definition.py index 4e2792e065..86bab8bb94 100644 --- a/src/snowflake/cli/api/project/schemas/project_definition.py +++ b/src/snowflake/cli/api/project/schemas/project_definition.py @@ -21,5 +21,5 @@ class ProjectDefinition(UpdatableModel): default=None, ) streamlit: Optional[Streamlit] = Field( - title="Native app definitions for the project", default=None + title="Streamlit definitions for the project", default=None ) diff --git a/tests/nativeapp/test_run_processor.py b/tests/nativeapp/test_run_processor.py index ee0fd55c9f..3591f4cfd6 100644 --- a/tests/nativeapp/test_run_processor.py +++ b/tests/nativeapp/test_run_processor.py @@ -784,14 +784,15 @@ def test_create_dev_app_create_new_quoted( definition_version: 1 native_app: name: '"My Native Application"' - + source_stage: app_src.stage - + artifacts: - setup.sql - app/README.md - src: app/streamlit/*.py + dest: ui/ application: From eefe2383d2a64b4ab56aa4a76c9bb69c50a3bb4f Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Wed, 6 Mar 2024 17:48:35 +0100 Subject: [PATCH 08/16] typo fix --- src/snowflake/cli/api/project/schemas/native_app/native_app.py | 2 +- src/snowflake/cli/plugins/nativeapp/artifacts.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/snowflake/cli/api/project/schemas/native_app/native_app.py b/src/snowflake/cli/api/project/schemas/native_app/native_app.py index eea465d3d0..877a341c13 100644 --- a/src/snowflake/cli/api/project/schemas/native_app/native_app.py +++ b/src/snowflake/cli/api/project/schemas/native_app/native_app.py @@ -6,7 +6,7 @@ from pydantic import Field, field_validator from snowflake.cli.api.project.schemas.native_app.application import Application from snowflake.cli.api.project.schemas.native_app.package import Package -from snowflake.cli.api.project.schemas.native_app.path_maping import PathMapping +from 
snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel from snowflake.cli.api.project.util import ( SCHEMA_AND_NAME, diff --git a/src/snowflake/cli/plugins/nativeapp/artifacts.py b/src/snowflake/cli/plugins/nativeapp/artifacts.py index 12bdfe6e52..e623502e75 100644 --- a/src/snowflake/cli/plugins/nativeapp/artifacts.py +++ b/src/snowflake/cli/plugins/nativeapp/artifacts.py @@ -6,7 +6,7 @@ import strictyaml from click import ClickException from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB -from snowflake.cli.api.project.schemas.native_app.path_maping import PathMapping +from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping from snowflake.cli.api.secure_path import SecurePath From 33b89ed539b5f66c5a4aa4965aafaff449737ad4 Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Thu, 7 Mar 2024 10:40:15 +0100 Subject: [PATCH 09/16] Added unit test --- tests/project/test_config.py | 26 ++++++++++++++++++++++++++ tests/project/test_pydantic_schemas.py | 0 2 files changed, 26 insertions(+) delete mode 100644 tests/project/test_pydantic_schemas.py diff --git a/tests/project/test_config.py b/tests/project/test_config.py index 4d06803ed1..4a76cd79c3 100644 --- a/tests/project/test_config.py +++ b/tests/project/test_config.py @@ -89,3 +89,29 @@ def test_does_not_accept_unknown_fields(project_definition_files): "You provided field '('unknown_fields_accepted',)' with value 'true' that is not present in the schema" in str(exc_info) ) + + +@pytest.mark.parametrize( + "project_definition_files", + [ + "integration", + "integration_external", + "minimal", + "napp_project_1", + "napp_project_with_pkg_warehouse", + "snowpark_function_external_access", + "snowpark_function_fully_qualified_name", + "snowpark_function_secrets_without_external_access", + "snowpark_functions", + "snowpark_procedure_external_access", + 
"snowpark_procedure_fully_qualified_name", + "snowpark_procedure_secrets_without_external_access", + "snowpark_procedures", + "snowpark_procedures_coverage", + "streamlit_full_definition", + ], + indirect=True, +) +def test_fields_are_parsed_correctly(project_definition_files, snapshot): + result = load_project_definition(project_definition_files).model_dump() + assert result == snapshot diff --git a/tests/project/test_pydantic_schemas.py b/tests/project/test_pydantic_schemas.py deleted file mode 100644 index e69de29bb2..0000000000 From 9216782c1369ea7bb2a0bb3af6578692ea16164d Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Thu, 7 Mar 2024 10:41:34 +0100 Subject: [PATCH 10/16] Added unit test --- tests/project/__snapshots__/test_config.ambr | 764 +++++++++++++++++++ 1 file changed, 764 insertions(+) create mode 100644 tests/project/__snapshots__/test_config.ambr diff --git a/tests/project/__snapshots__/test_config.ambr b/tests/project/__snapshots__/test_config.ambr new file mode 100644 index 0000000000..714bfe003e --- /dev/null +++ b/tests/project/__snapshots__/test_config.ambr @@ -0,0 +1,764 @@ +# serializer version: 1 +# name: test_fields_are_parsed_correctly[integration] + dict({ + 'definition_version': 1, + 'native_app': dict({ + 'application': None, + 'artifacts': list([ + dict({ + 'dest': './', + 'src': 'app/*', + }), + ]), + 'deploy_root': 'output/deploy/', + 'name': 'integration', + 'package': dict({ + 'distribution': 'internal', + 'name': None, + 'role': None, + 'scripts': list([ + 'package/001-shared.sql', + 'package/002-shared.sql', + ]), + 'warehouse': None, + }), + 'source_stage': 'app_src.stage', + }), + 'snowpark': None, + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[integration_external] + dict({ + 'definition_version': 1, + 'native_app': dict({ + 'application': None, + 'artifacts': list([ + dict({ + 'dest': './', + 'src': 'app/*', + }), + ]), + 'deploy_root': 'output/deploy/', + 'name': 'integration_external', + 
'package': dict({ + 'distribution': 'external', + 'name': None, + 'role': None, + 'scripts': list([ + 'package/001-shared.sql', + 'package/002-shared.sql', + ]), + 'warehouse': None, + }), + 'source_stage': 'app_src.stage', + }), + 'snowpark': None, + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[minimal] + dict({ + 'definition_version': 1, + 'native_app': dict({ + 'application': None, + 'artifacts': list([ + 'setup.sql', + 'README.md', + ]), + 'deploy_root': 'output/deploy/', + 'name': 'minimal', + 'package': None, + 'source_stage': 'app_src.stage', + }), + 'snowpark': None, + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[napp_project_1] + dict({ + 'definition_version': 1, + 'native_app': dict({ + 'application': dict({ + 'debug': True, + 'name': 'myapp_polly', + 'role': 'myapp_consumer', + 'warehouse': None, + }), + 'artifacts': list([ + 'setup.sql', + 'app/README.md', + dict({ + 'dest': 'ui/', + 'src': 'app/streamlit/*.py', + }), + ]), + 'deploy_root': 'output/deploy/', + 'name': 'myapp', + 'package': dict({ + 'distribution': 'internal', + 'name': 'myapp_pkg_polly', + 'role': 'accountadmin', + 'scripts': list([ + '001-shared.sql', + '002-shared.sql', + ]), + 'warehouse': None, + }), + 'source_stage': '"MySourceSchema"."SRC_Stage"', + }), + 'snowpark': None, + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[napp_project_with_pkg_warehouse] + dict({ + 'definition_version': 1, + 'native_app': dict({ + 'application': dict({ + 'debug': True, + 'name': 'myapp_polly', + 'role': 'myapp_consumer', + 'warehouse': None, + }), + 'artifacts': list([ + 'setup.sql', + 'app/README.md', + dict({ + 'dest': 'ui/', + 'src': 'app/streamlit/*.py', + }), + ]), + 'deploy_root': 'output/deploy/', + 'name': 'myapp', + 'package': dict({ + 'distribution': 'internal', + 'name': 'myapp_pkg_polly', + 'role': 'accountadmin', + 'scripts': list([ + '001-shared.sql', + '002-shared.sql', + ]), + 'warehouse': 
'myapp_pkg_warehouse', + }), + 'source_stage': '"MySourceSchema"."SRC_Stage"', + }), + 'snowpark': None, + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_function_external_access] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + dict({ + 'database': None, + 'external_access_integrations': list([ + 'external_1', + 'external_2', + ]), + 'handler': 'app.func1_handler', + 'imports': list([ + ]), + 'name': 'func1', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': dict({ + 'cred': 'cred_name', + 'other': 'other_name', + }), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'a', + }), + dict({ + 'arg_type': 'variant', + 'default': None, + 'name': 'b', + }), + ]), + }), + ]), + 'procedures': list([ + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_function_fully_qualified_name] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + dict({ + 'database': None, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_function', + 'imports': list([ + ]), + 'name': 'custom_db.custom_schema.fqn_function', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': None, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_function', + 'imports': list([ + ]), + 'name': 'custom_schema.fqn_function_only_schema', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': None, + 
'external_access_integrations': list([ + ]), + 'handler': 'app.hello_function', + 'imports': list([ + ]), + 'name': 'schema_function', + 'returns': 'string', + 'runtime': None, + 'schema_name': 'custom_schema', + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': 'custom_db', + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_function', + 'imports': list([ + ]), + 'name': 'database_function', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': 'custom_db', + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_function', + 'imports': list([ + ]), + 'name': 'custom_schema.database_function', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': 'custom_database', + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_function', + 'imports': list([ + ]), + 'name': 'custom_database.custom_schema.fqn_function_error', + 'returns': 'string', + 'runtime': None, + 'schema_name': 'custom_schema', + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + ]), + 'procedures': list([ + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_function_secrets_without_external_access] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + dict({ + 'database': None, + 'external_access_integrations': list([ + ]), + 'handler': 
'app.func1_handler', + 'imports': list([ + ]), + 'name': 'func1', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': dict({ + 'cred': 'cred_name', + 'other': 'other_name', + }), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'a', + }), + dict({ + 'arg_type': 'variant', + 'default': None, + 'name': 'b', + }), + ]), + }), + ]), + 'procedures': list([ + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_functions] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + dict({ + 'database': None, + 'external_access_integrations': list([ + ]), + 'handler': 'app.func1_handler', + 'imports': list([ + ]), + 'name': 'func1', + 'returns': 'string', + 'runtime': '3.10', + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': 'default value', + 'name': 'a', + }), + dict({ + 'arg_type': 'variant', + 'default': None, + 'name': 'b', + }), + ]), + }), + ]), + 'procedures': list([ + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_procedure_external_access] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + ]), + 'procedures': list([ + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + 'external_1', + 'external_2', + ]), + 'handler': 'app.hello', + 'imports': list([ + ]), + 'name': 'procedureName', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': dict({ + 'cred': 'cred_name', + 'other': 'other_name', + }), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + 
]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_procedure_fully_qualified_name] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + ]), + 'procedures': list([ + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_procedure', + 'imports': list([ + ]), + 'name': 'custom_db.custom_schema.fqn_procedure', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_procedure', + 'imports': list([ + ]), + 'name': 'custom_schema.fqn_procedure_only_schema', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_procedure', + 'imports': list([ + ]), + 'name': 'schema_procedure', + 'returns': 'string', + 'runtime': None, + 'schema_name': 'custom_schema', + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': 'custom_db', + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_procedure', + 'imports': list([ + ]), + 'name': 'database_procedure', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + 
]), + }), + dict({ + 'database': 'custom_db', + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_procedure', + 'imports': list([ + ]), + 'name': 'custom_schema.database_procedure', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': 'custom_database', + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello_procedure', + 'imports': list([ + ]), + 'name': 'custom_database.custom_schema.fqn_procedure_error', + 'returns': 'string', + 'runtime': None, + 'schema_name': 'custom_schema', + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_procedure_secrets_without_external_access] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + ]), + 'procedures': list([ + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'app.hello', + 'imports': list([ + ]), + 'name': 'procedureName', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': dict({ + 'cred': 'cred_name', + 'other': 'other_name', + }), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_procedures] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + ]), + 
'procedures': list([ + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'hello', + 'imports': list([ + ]), + 'name': 'procedureName', + 'returns': 'string', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'test', + 'imports': list([ + ]), + 'name': 'test', + 'returns': 'string', + 'runtime': '3.10', + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': '', + }), + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[snowpark_procedures_coverage] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': dict({ + 'functions': list([ + ]), + 'procedures': list([ + dict({ + 'database': None, + 'execute_as_caller': False, + 'external_access_integrations': list([ + ]), + 'handler': 'foo.func', + 'imports': list([ + ]), + 'name': 'foo', + 'returns': 'variant', + 'runtime': None, + 'schema_name': None, + 'secrets': list([ + ]), + 'signature': list([ + dict({ + 'arg_type': 'string', + 'default': None, + 'name': 'name', + }), + ]), + }), + ]), + 'project_name': 'my_snowpark_project', + 'src': 'app/', + 'stage_name': 'dev_deployment', + }), + 'streamlit': None, + }) +# --- +# name: test_fields_are_parsed_correctly[streamlit_full_definition] + dict({ + 'definition_version': 1, + 'native_app': None, + 'snowpark': None, + 'streamlit': dict({ + 'additional_source_files': list([ + 'utils/utils.py', + 'extra_file.py', + ]), + 'env_file': 'environment.yml', + 'main_file': 'streamlit_app.py', + 'name': 'test_streamlit', + 'pages_dir': 'pages', + 'query_warehouse': 'test_warehouse', + 'stage': 'streamlit', + }), + }) +# --- From 
e0db3fd281f25cebec223e1e9c493b3acb2d5c9e Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Thu, 7 Mar 2024 13:54:52 +0100 Subject: [PATCH 11/16] Fixes after review --- RELEASE-NOTES.md | 1 - pyproject.toml | 1 - src/snowflake/cli/api/project/schemas/native_app/package.py | 4 ++-- src/snowflake/cli/api/project/schemas/project_definition.py | 4 +++- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 966f5778fb..dc9d351f31 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -11,7 +11,6 @@ * Adding `--image-name` option for image name argument in `spcs image-repository list-tags` for consistency with other commands. * Fixed errors during `spcs image-registry login` not being formatted correctly. * Project definition no longer accept extra fields. Any extra field will cause an error. -* Project definition no longer accept extra fields. Any extra field will cause an error. # v2.1.0 diff --git a/pyproject.toml b/pyproject.toml index 00af57691e..cc874927ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,6 @@ dependencies = [ "requirements-parser==0.5.0", "setuptools==69.1.1", "snowflake-connector-python[secure-local-storage]==3.7.1", - "strictyaml==1.7.3", "tomlkit==0.12.3", "typer==0.9.0", "urllib3>=1.21.1,<2.3", diff --git a/src/snowflake/cli/api/project/schemas/native_app/package.py b/src/snowflake/cli/api/project/schemas/native_app/package.py index 3934209a55..3562ffe0e9 100644 --- a/src/snowflake/cli/api/project/schemas/native_app/package.py +++ b/src/snowflake/cli/api/project/schemas/native_app/package.py @@ -19,7 +19,7 @@ class Package(UpdatableModel): default=None, ) name: Optional[str] = IdentifierField( - title="Name of the application created when you run the snow app run command", # TODO: this description seems duplicated, is it ok? 
+ title="Name of the application package created when you run the snow app run command", default=None, ) warehouse: Optional[str] = IdentifierField( @@ -35,6 +35,6 @@ class Package(UpdatableModel): def validate_scripts(cls, input_list): if len(input_list) != len(set(input_list)): raise ValueError( - "Scripts field should contain unique values. Check the list for duplicates and try again" + "package.scripts field should contain unique values. Check the list for duplicates and try again" ) return input_list diff --git a/src/snowflake/cli/api/project/schemas/project_definition.py b/src/snowflake/cli/api/project/schemas/project_definition.py index 86bab8bb94..1e3bcecd67 100644 --- a/src/snowflake/cli/api/project/schemas/project_definition.py +++ b/src/snowflake/cli/api/project/schemas/project_definition.py @@ -11,7 +11,9 @@ class ProjectDefinition(UpdatableModel): definition_version: int = Field( - title="Version of the project definition schema, which is currently 1" + title="Version of the project definition schema, which is currently 1", + ge=1, + le=1, ) native_app: Optional[NativeApp] = Field( title="Native app definitions for the project", default=None From bea3edc822ecf40117be8579b4cd3cb1a7bd5bef Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Thu, 7 Mar 2024 15:35:41 +0100 Subject: [PATCH 12/16] Fixes after review --- src/snowflake/cli/plugins/nativeapp/artifacts.py | 4 ++-- tests_integration/conftest.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/snowflake/cli/plugins/nativeapp/artifacts.py b/src/snowflake/cli/plugins/nativeapp/artifacts.py index e623502e75..cc4aa67228 100644 --- a/src/snowflake/cli/plugins/nativeapp/artifacts.py +++ b/src/snowflake/cli/plugins/nativeapp/artifacts.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import List, Optional, Tuple, Union -import strictyaml +import yaml from click import ClickException from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB from 
snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping @@ -271,7 +271,7 @@ def find_version_info_in_manifest_file( with SecurePath(manifest_file).open( "r", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB ) as file: - manifest_content = strictyaml.load(file.read()) + manifest_content = yaml.load(file.read(), Loader=yaml.BaseLoader) version_name: Optional[str] = None patch_name: Optional[str] = None diff --git a/tests_integration/conftest.py b/tests_integration/conftest.py index 633c7d4046..7313a67d2e 100644 --- a/tests_integration/conftest.py +++ b/tests_integration/conftest.py @@ -11,11 +11,11 @@ from typing import Any, Dict, List, Optional import pytest -import strictyaml +import yaml + from snowflake.cli.api.cli_global_context import cli_context_manager from snowflake.cli.api.project.definition import merge_left from snowflake.cli.app.cli_app import app_factory -from strictyaml import as_document from typer import Typer from typer.testing import CliRunner @@ -131,10 +131,12 @@ def _temporary_project_directory( test_data_file = test_root_path / "test_data" / "projects" / project_name shutil.copytree(test_data_file, temporary_working_directory, dirs_exist_ok=True) if merge_project_definition: - project_definition = strictyaml.load(Path("snowflake.yml").read_text()).data + project_definition = yaml.load( + Path("snowflake.yml").read_text(), Loader=yaml.BaseLoader + ) merge_left(project_definition, merge_project_definition) with open(Path(temporary_working_directory) / "snowflake.yml", "w") as file: - file.write(as_document(project_definition).as_yaml()) + file.write(yaml.dump(project_definition)) yield temporary_working_directory From a2283608415e0914966f8ed026858e6695b162f0 Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Fri, 8 Mar 2024 10:16:18 +0100 Subject: [PATCH 13/16] Fixes --- tests/testing_utils/fixtures.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/testing_utils/fixtures.py 
b/tests/testing_utils/fixtures.py index 5c30a0bb23..cd52a0e4f9 100644 --- a/tests/testing_utils/fixtures.py +++ b/tests/testing_utils/fixtures.py @@ -11,12 +11,11 @@ from unittest import mock import pytest -import strictyaml +import yaml from snowflake.cli.api.project.definition import merge_left from snowflake.cli.app.cli_app import app_factory from snowflake.connector.cursor import SnowflakeCursor from snowflake.connector.errors import ProgrammingError -from strictyaml import as_document from typer import Typer from typer.testing import CliRunner @@ -250,10 +249,12 @@ def _temporary_project_directory( test_data_file = test_root_path / "test_data" / "projects" / project_name shutil.copytree(test_data_file, temp_dir, dirs_exist_ok=True) if merge_project_definition: - project_definition = strictyaml.load(Path("snowflake.yml").read_text()).data + project_definition = yaml.load( + Path("snowflake.yml").read_text(), Loader=yaml.BaseLoader + ) merge_left(project_definition, merge_project_definition) with open(Path(temp_dir) / "snowflake.yml", "w") as file: - file.write(as_document(project_definition).as_yaml()) + file.write(yaml.dump(project_definition)) yield Path(temp_dir) From 8cbd9f42d96148adb0c065ee313c1c5eef1cc74d Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Fri, 8 Mar 2024 12:05:58 +0100 Subject: [PATCH 14/16] Fixes --- pyproject.toml | 1 + src/snowflake/cli/plugins/nativeapp/artifacts.py | 4 ++-- tests/spcs/test_services.py | 6 +++--- tests_integration/conftest.py | 10 ++++------ 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cc874927ad..00af57691e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "requirements-parser==0.5.0", "setuptools==69.1.1", "snowflake-connector-python[secure-local-storage]==3.7.1", + "strictyaml==1.7.3", "tomlkit==0.12.3", "typer==0.9.0", "urllib3>=1.21.1,<2.3", diff --git a/src/snowflake/cli/plugins/nativeapp/artifacts.py 
b/src/snowflake/cli/plugins/nativeapp/artifacts.py index cc4aa67228..e623502e75 100644 --- a/src/snowflake/cli/plugins/nativeapp/artifacts.py +++ b/src/snowflake/cli/plugins/nativeapp/artifacts.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import List, Optional, Tuple, Union -import yaml +import strictyaml from click import ClickException from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB from snowflake.cli.api.project.schemas.native_app.path_mapping import PathMapping @@ -271,7 +271,7 @@ def find_version_info_in_manifest_file( with SecurePath(manifest_file).open( "r", read_file_limit_mb=DEFAULT_SIZE_LIMIT_MB ) as file: - manifest_content = yaml.load(file.read(), Loader=yaml.BaseLoader) + manifest_content = strictyaml.load(file.read()) version_name: Optional[str] = None patch_name: Optional[str] = None diff --git a/tests/spcs/test_services.py b/tests/spcs/test_services.py index e7ee202e75..0f6e73750e 100644 --- a/tests/spcs/test_services.py +++ b/tests/spcs/test_services.py @@ -4,7 +4,7 @@ from unittest.mock import Mock, patch import pytest -import strictyaml +import yaml from click import ClickException from snowflake.cli.api.constants import ObjectType from snowflake.cli.api.project.util import to_string_literal @@ -185,9 +185,9 @@ def test_create_service_with_invalid_spec(mock_read_yaml): max_instances = 42 external_access_integrations = query_warehouse = tags = comment = None auto_resume = False - mock_read_yaml.side_effect = strictyaml.YAMLError("Invalid YAML") + mock_read_yaml.side_effect = yaml.YAMLError("Invalid yaml") - with pytest.raises(strictyaml.YAMLError): + with pytest.raises(yaml.YAMLError): ServiceManager().create( service_name=service_name, compute_pool=compute_pool, diff --git a/tests_integration/conftest.py b/tests_integration/conftest.py index 7313a67d2e..633c7d4046 100644 --- a/tests_integration/conftest.py +++ b/tests_integration/conftest.py @@ -11,11 +11,11 @@ from typing import Any, Dict, List, Optional import pytest 
-import yaml - +import strictyaml from snowflake.cli.api.cli_global_context import cli_context_manager from snowflake.cli.api.project.definition import merge_left from snowflake.cli.app.cli_app import app_factory +from strictyaml import as_document from typer import Typer from typer.testing import CliRunner @@ -131,12 +131,10 @@ def _temporary_project_directory( test_data_file = test_root_path / "test_data" / "projects" / project_name shutil.copytree(test_data_file, temporary_working_directory, dirs_exist_ok=True) if merge_project_definition: - project_definition = yaml.load( - Path("snowflake.yml").read_text(), Loader=yaml.BaseLoader - ) + project_definition = strictyaml.load(Path("snowflake.yml").read_text()).data merge_left(project_definition, merge_project_definition) with open(Path(temporary_working_directory) / "snowflake.yml", "w") as file: - file.write(yaml.dump(project_definition)) + file.write(as_document(project_definition).as_yaml()) yield temporary_working_directory From 54f641c3aed2474ddb7a5506cf2866f10576e9c6 Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Fri, 8 Mar 2024 12:25:27 +0100 Subject: [PATCH 15/16] Fixes --- tests/spcs/test_services.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/spcs/test_services.py b/tests/spcs/test_services.py index 0f6e73750e..e7ee202e75 100644 --- a/tests/spcs/test_services.py +++ b/tests/spcs/test_services.py @@ -4,7 +4,7 @@ from unittest.mock import Mock, patch import pytest -import yaml +import strictyaml from click import ClickException from snowflake.cli.api.constants import ObjectType from snowflake.cli.api.project.util import to_string_literal @@ -185,9 +185,9 @@ def test_create_service_with_invalid_spec(mock_read_yaml): max_instances = 42 external_access_integrations = query_warehouse = tags = comment = None auto_resume = False - mock_read_yaml.side_effect = yaml.YAMLError("Invalid yaml") + mock_read_yaml.side_effect = strictyaml.YAMLError("Invalid YAML") - with 
pytest.raises(yaml.YAMLError): + with pytest.raises(strictyaml.YAMLError): ServiceManager().create( service_name=service_name, compute_pool=compute_pool, From 75d74673a5ba9842b4f17957c16e7546d81235a8 Mon Sep 17 00:00:00 2001 From: Jan Sikorski Date: Fri, 8 Mar 2024 16:38:23 +0100 Subject: [PATCH 16/16] Fixes --- .../cli/api/project/schemas/native_app/application.py | 8 ++++---- .../cli/api/project/schemas/native_app/native_app.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/snowflake/cli/api/project/schemas/native_app/application.py b/src/snowflake/cli/api/project/schemas/native_app/application.py index 623815dc61..e7383bcc39 100644 --- a/src/snowflake/cli/api/project/schemas/native_app/application.py +++ b/src/snowflake/cli/api/project/schemas/native_app/application.py @@ -11,19 +11,19 @@ class Application(UpdatableModel): role: Optional[str] = Field( - title="Role to use when creating the application instance and consumer-side objects", + title="Role to use when creating the application object and consumer-side objects", default=None, ) name: Optional[str] = Field( - title="Name of the application created when you run the snow app run command", + title="Name of the application object created when you run the snow app run command", default=None, ) warehouse: Optional[str] = IdentifierField( - title="Name of the application created when you run the snow app run command", + title="Name of the application object created when you run the snow app run command", default=None, ) debug: Optional[bool] = Field( - title="Whether to enable debug mode when using a named stage to create an application", + title="Whether to enable debug mode when using a named stage to create an application object", default=True, ) diff --git a/src/snowflake/cli/api/project/schemas/native_app/native_app.py b/src/snowflake/cli/api/project/schemas/native_app/native_app.py index 877a341c13..97714cc188 100644 --- 
a/src/snowflake/cli/api/project/schemas/native_app/native_app.py +++ b/src/snowflake/cli/api/project/schemas/native_app/native_app.py @@ -35,5 +35,5 @@ class NativeApp(UpdatableModel): @classmethod def validate_source_stage(cls, input_value: str): if not re.match(SCHEMA_AND_NAME, input_value): - raise ValueError("Incorrect value for Native Apps source stage value") + raise ValueError("Incorrect value for source_stage value of native_app") return input_value