diff --git a/src/snowflake/cli/api/utils/path_utils.py b/src/snowflake/cli/api/utils/path_utils.py index bd05a003c4..9a090cf3e0 100644 --- a/src/snowflake/cli/api/utils/path_utils.py +++ b/src/snowflake/cli/api/utils/path_utils.py @@ -16,3 +16,7 @@ def path_resolver(path_to_file: str) -> str: if 0 < return_value <= BUFFER_SIZE: return buffer.value return path_to_file + + +def is_stage_path(path: str) -> bool: + return path.startswith("@") or path.startswith("snow://") diff --git a/src/snowflake/cli/plugins/git/commands.py b/src/snowflake/cli/plugins/git/commands.py index 240f2aea37..13be906546 100644 --- a/src/snowflake/cli/plugins/git/commands.py +++ b/src/snowflake/cli/plugins/git/commands.py @@ -1,9 +1,12 @@ import logging +from pathlib import Path import typer +from click import ClickException from snowflake.cli.api.commands.flags import identifier_argument from snowflake.cli.api.commands.snow_typer import SnowTyper from snowflake.cli.api.output.types import CommandResult, QueryResult +from snowflake.cli.api.utils.path_utils import is_stage_path from snowflake.cli.plugins.git.manager import GitManager app = SnowTyper( @@ -13,54 +16,114 @@ ) log = logging.getLogger(__name__) + +def _repo_path_argument_callback(path): + if not is_stage_path(path): + raise ClickException( + "REPOSITORY_PATH should be a path to git repository stage with scope provided." + " For example: @my_repo/branches/main/" + ) + return path + + RepoNameArgument = identifier_argument(sf_object="git repository", example="my_repo") +RepoPathArgument = typer.Argument( + metavar="REPOSITORY_PATH", + help=( + "Path to git repository stage with scope provided." + " Path to the repository root must end with '/'." + " For example: @my_repo/branches/main/" + ), + callback=_repo_path_argument_callback, +) @app.command( "list-branches", - help="List all branches in the repository.", requires_connection=True, ) def list_branches( repository_name: str = RepoNameArgument, **options, ) -> CommandResult: + """ + List all branches in the repository. + """ return QueryResult(GitManager().show_branches(repo_name=repository_name)) @app.command( "list-tags", - help="List all tags in the repository.", requires_connection=True, ) def list_tags( repository_name: str = RepoNameArgument, **options, ) -> CommandResult: + """ + List all tags in the repository. + """ return QueryResult(GitManager().show_tags(repo_name=repository_name)) @app.command( "list-files", - help="List files from given state of git repository.", requires_connection=True, ) def list_files( - repository_path: str = typer.Argument( - help="Path to git repository stage with scope provided. For example: @my_repo/branches/main" - ), + repository_path: str = RepoPathArgument, **options, ) -> CommandResult: + """ + List files from given state of git repository. + """ return QueryResult(GitManager().show_files(repo_path=repository_path)) @app.command( "fetch", - help="Fetch changes from origin to snowflake repository.", requires_connection=True, ) def fetch( repository_name: str = RepoNameArgument, **options, ) -> CommandResult: + """ + Fetch changes from origin to snowflake repository. + """ return QueryResult(GitManager().fetch(repo_name=repository_name)) + + +@app.command( + "copy", + requires_connection=True, +) +def copy( + repository_path: str = RepoPathArgument, + destination_path: str = typer.Argument( + help="Target path for copy operation. Should be a path to a directory on remote stage or local file system.", + ), + parallel: int = typer.Option( + 4, + help="Number of parallel threads to use when downloading files.", + ), + **options, +): + """ + Copies all files from given state of repository to local directory or stage. + + If the source path ends with '/', the command copies contents of specified directory. + Otherwise, it creates a new directory or file in the destination directory. + """ + is_copy = is_stage_path(destination_path) + if is_copy: + cursor = GitManager().copy_files( + source_path=repository_path, destination_path=destination_path + ) + else: + cursor = GitManager().get( + stage_path=repository_path, + dest_path=Path(destination_path).resolve(), + parallel=parallel, + ) + return QueryResult(cursor) diff --git a/src/snowflake/cli/plugins/git/manager.py b/src/snowflake/cli/plugins/git/manager.py index 4cf357aab5..04e4816f4b 100644 --- a/src/snowflake/cli/plugins/git/manager.py +++ b/src/snowflake/cli/plugins/git/manager.py @@ -1,19 +1,20 @@ -from snowflake.cli.api.sql_execution import SqlExecutionMixin +from snowflake.cli.plugins.object.stage.manager import StageManager +from snowflake.connector.cursor import SnowflakeCursor -class GitManager(SqlExecutionMixin): - def show_branches(self, repo_name: str): +class GitManager(StageManager): + def show_branches(self, repo_name: str) -> SnowflakeCursor: query = f"show git branches in {repo_name}" return self._execute_query(query) - def show_tags(self, repo_name: str): + def show_tags(self, repo_name: str) -> SnowflakeCursor: query = f"show git tags in {repo_name}" return self._execute_query(query) - def show_files(self, repo_path: str): + def show_files(self, repo_path: str) -> SnowflakeCursor: query = f"ls {repo_path}" return self._execute_query(query) - def fetch(self, repo_name: str): + def fetch(self, repo_name: str) -> SnowflakeCursor: query = f"alter git repository {repo_name} fetch" return self._execute_query(query) diff --git a/src/snowflake/cli/plugins/object/stage/commands.py b/src/snowflake/cli/plugins/object/stage/commands.py index f5282f9ee6..592009672e 100644 --- a/src/snowflake/cli/plugins/object/stage/commands.py +++ b/src/snowflake/cli/plugins/object/stage/commands.py @@ -11,6 +11,7 @@ QueryResult, SingleQueryResult, ) +from snowflake.cli.api.utils.path_utils import is_stage_path from snowflake.cli.plugins.object.stage.diff import DiffResult from snowflake.cli.plugins.object.stage.manager import StageManager @@ -31,10 +32,6 @@ def stage_list(stage_name: str = StageNameArgument, **options) -> CommandResult: return QueryResult(cursor) -def _is_stage_path(path: str): - return path.startswith("@") or path.startswith("snow://") - - @app.command("copy", requires_connection=True) def copy( source_path: str = typer.Argument( @@ -57,8 +54,8 @@ def copy( Copies all files from target path to target directory. This works for both uploading to and downloading files from the stage. """ - is_get = _is_stage_path(source_path) - is_put = _is_stage_path(destination_path) + is_get = is_stage_path(source_path) + is_put = is_stage_path(destination_path) if is_get and is_put: raise click.ClickException( @@ -72,7 +69,7 @@ def copy( if is_get: target = Path(destination_path).resolve() cursor = StageManager().get( - stage_name=source_path, dest_path=target, parallel=parallel + stage_path=source_path, dest_path=target, parallel=parallel ) else: source = Path(source_path).resolve() diff --git a/src/snowflake/cli/plugins/object/stage/manager.py b/src/snowflake/cli/plugins/object/stage/manager.py index 8baa291818..d369824493 100644 --- a/src/snowflake/cli/plugins/object/stage/manager.py +++ b/src/snowflake/cli/plugins/object/stage/manager.py @@ -7,6 +7,7 @@ from typing import Optional, Union from snowflake.cli.api.project.util import to_string_literal +from snowflake.cli.api.secure_path import SecurePath from snowflake.cli.api.sql_execution import SqlExecutionMixin from snowflake.cli.api.utils.path_utils import path_resolver from snowflake.connector.cursor import SnowflakeCursor @@ -19,13 +20,19 @@ class StageManager(SqlExecutionMixin): @staticmethod - def get_standard_stage_name(name: str) -> str: + def get_standard_stage_prefix(name: str) -> str: # Handle embedded stages if name.startswith("snow://") or name.startswith("@"): return name return f"@{name}" + @staticmethod + def get_standard_stage_directory_path(path): + if not path.endswith("/"): + path += "/" + return StageManager.get_standard_stage_prefix(path) + @staticmethod def get_stage_name_from_path(path: str): """ @@ -39,7 +46,7 @@ def quote_stage_name(name: str) -> str: if name.startswith("'") and name.endswith("'"): return name # already quoted - standard_name = StageManager.get_standard_stage_name(name) + standard_name = StageManager.get_standard_stage_prefix(name) if standard_name.startswith("@") and not re.fullmatch( r"@([\w./$])+", standard_name ): @@ -54,16 +61,24 @@ def _to_uri(self, local_path: str): return to_string_literal(uri) def list_files(self, stage_name: str) -> SnowflakeCursor: - stage_name = self.get_standard_stage_name(stage_name) + stage_name = self.get_standard_stage_prefix(stage_name) return self._execute_query(f"ls {self.quote_stage_name(stage_name)}") + @staticmethod + def _assure_is_existing_directory(path: Path) -> None: + spath = SecurePath(path) + if not spath.exists(): + spath.mkdir() + spath.assert_is_directory() + def get( - self, stage_name: str, dest_path: Path, parallel: int = 4 + self, stage_path: str, dest_path: Path, parallel: int = 4 ) -> SnowflakeCursor: - stage_name = self.get_standard_stage_name(stage_name) + stage_path = self.get_standard_stage_prefix(stage_path) + self._assure_is_existing_directory(dest_path) dest_directory = f"{dest_path}/" return self._execute_query( - f"get {self.quote_stage_name(stage_name)} {self._to_uri(dest_directory)} parallel={parallel}" + f"get {self.quote_stage_name(stage_path)} {self._to_uri(dest_directory)} parallel={parallel}" ) def put( @@ -81,7 +96,7 @@ def put( and switch back to the original role for the next commands to run. """ with self.use_role(role) if role else nullcontext(): - stage_path = self.get_standard_stage_name(stage_path) + stage_path = self.get_standard_stage_prefix(stage_path) local_resolved_path = path_resolver(str(local_path)) log.info("Uploading %s to @%s", local_resolved_path, stage_path) cursor = self._execute_query( @@ -90,6 +105,13 @@ def put( ) return cursor + def copy_files(self, source_path: str, destination_path: str) -> SnowflakeCursor: + source = self.get_standard_stage_prefix(source_path) + destination = self.get_standard_stage_directory_path(destination_path) + log.info("Copying files from %s to %s", source, destination) + query = f"copy files into {destination} from {source}" + return self._execute_query(query) + def remove( self, stage_name: str, path: str, role: Optional[str] = None ) -> SnowflakeCursor: @@ -100,7 +122,7 @@ def remove( and switch back to the original role for the next commands to run. """ with self.use_role(role) if role else nullcontext(): - stage_name = self.get_standard_stage_name(stage_name) + stage_name = self.get_standard_stage_prefix(stage_name) path = path if path.startswith("/") else "/" + path quoted_stage_name = self.quote_stage_name(f"{stage_name}{path}") return self._execute_query(f"remove {quoted_stage_name}") diff --git a/src/snowflake/cli/plugins/streamlit/manager.py b/src/snowflake/cli/plugins/streamlit/manager.py index 3beb317a39..8ae63e0385 100644 --- a/src/snowflake/cli/plugins/streamlit/manager.py +++ b/src/snowflake/cli/plugins/streamlit/manager.py @@ -153,7 +153,7 @@ def deploy( stage_manager.create(stage_name=stage_name) - root_location = stage_manager.get_standard_stage_name( + root_location = stage_manager.get_standard_stage_prefix( f"{stage_name}/{streamlit_name_for_root_location}" ) diff --git a/tests/__snapshots__/test_help_messages.ambr b/tests/__snapshots__/test_help_messages.ambr index 472c373ac7..4b5bcedb84 100644 --- a/tests/__snapshots__/test_help_messages.ambr +++ b/tests/__snapshots__/test_help_messages.ambr @@ -782,6 +782,85 @@ ╰──────────────────────────────────────────────────────────────────────────────╯ + ''' +# --- +# name: test_help_messages[git.copy] + ''' + + Usage: default git copy [OPTIONS] REPOSITORY_PATH DESTINATION_PATH + + Copies all files from given state of repository to local directory or stage. + If the source path ends with '/', the command copies contents of specified + directory. Otherwise, it creates a new directory or file in the destination + directory. + + ╭─ Arguments ──────────────────────────────────────────────────────────────────╮ + │ * repository_path TEXT Path to git repository stage with scope │ + │ provided. Path to the repository root must │ + │ end with '/'. For example: │ + │ @my_repo/branches/main/ │ + │ [default: None] │ + │ [required] │ + │ * destination_path TEXT Target path for copy operation. Should be a │ + │ path to a directory on remote stage or │ + │ local file system. │ + │ [default: None] │ + │ [required] │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + ╭─ Options ────────────────────────────────────────────────────────────────────╮ + │ --parallel INTEGER Number of parallel threads to use when │ + │ downloading files. │ + │ [default: 4] │ + │ --help -h Show this message and exit. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + ╭─ Connection configuration ───────────────────────────────────────────────────╮ + │ --connection,--environment -c TEXT Name of the connection, as defined │ + │ in your `config.toml`. Default: │ + │ `default`. │ + │ --account,--accountname TEXT Name assigned to your Snowflake │ + │ account. Overrides the value │ + │ specified for the connection. │ + │ --user,--username TEXT Username to connect to Snowflake. │ + │ Overrides the value specified for │ + │ the connection. │ + │ --password TEXT Snowflake password. Overrides the │ + │ value specified for the │ + │ connection. │ + │ --authenticator TEXT Snowflake authenticator. Overrides │ + │ the value specified for the │ + │ connection. │ + │ --private-key-path TEXT Snowflake private key path. │ + │ Overrides the value specified for │ + │ the connection. │ + │ --database,--dbname TEXT Database to use. Overrides the │ + │ value specified for the │ + │ connection. │ + │ --schema,--schemaname TEXT Database schema to use. Overrides │ + │ the value specified for the │ + │ connection. │ + │ --role,--rolename TEXT Role to use. Overrides the value │ + │ specified for the connection. │ + │ --warehouse TEXT Warehouse to use. Overrides the │ + │ value specified for the │ + │ connection. │ + │ --temporary-connection -x Uses connection defined with │ + │ command line parameters, instead │ + │ of one defined in config │ + │ --mfa-passcode TEXT Token to use for multi-factor │ + │ authentication (MFA) │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + ╭─ Global configuration ───────────────────────────────────────────────────────╮ + │ --format [TABLE|JSON] Specifies the output format. │ + │ [default: TABLE] │ + │ --verbose -v Displays log entries for log levels `info` │ + │ and higher. │ + │ --debug Displays log entries for log levels `debug` │ + │ and higher; debug logs contains additional │ + │ information. │ + │ --silent Turns off intermediate output to console. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + + ''' # --- # name: test_help_messages[git.fetch] @@ -925,8 +1004,9 @@ ╭─ Arguments ──────────────────────────────────────────────────────────────────╮ │ * repository_path TEXT Path to git repository stage with scope │ - │ provided. For example: │ - │ @my_repo/branches/main │ + │ provided. Path to the repository root must │ + │ end with '/'. For example: │ + │ @my_repo/branches/main/ │ │ [default: None] │ │ [required] │ ╰──────────────────────────────────────────────────────────────────────────────╯ @@ -1060,10 +1140,12 @@ │ --help -h Show this message and exit. │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Commands ───────────────────────────────────────────────────────────────────╮ - │ fetch Fetch changes from origin to snowflake repository. │ - │ list-branches List all branches in the repository. │ - │ list-files List files from given state of git repository. │ - │ list-tags List all tags in the repository. │ + │ copy Copies all files from given state of repository to local │ + │ directory or stage. │ + │ fetch Fetch changes from origin to snowflake repository. │ + │ list-branches List all branches in the repository. │ + │ list-files List files from given state of git repository. │ + │ list-tags List all tags in the repository. │ ╰──────────────────────────────────────────────────────────────────────────────╯ diff --git a/tests/git/test_git_commands.py b/tests/git/test_git_commands.py index 0a7b0c42f4..fab949042b 100644 --- a/tests/git/test_git_commands.py +++ b/tests/git/test_git_commands.py @@ -1,3 +1,4 @@ +from pathlib import Path from unittest import mock import pytest @@ -45,6 +46,12 @@ def test_list_files(mock_connector, runner, mock_ctx): assert ctx.get_query() == "ls @repo_name/branches/main" +def test_list_files_not_a_stage_error(runner): + result = runner.invoke(["git", "list-files", "repo_name/branches/main"]) + assert result.exit_code == 1 + _assert_invalid_repo_path_error_message(result.output) + + @mock.patch("snowflake.connector.connect") def test_fetch(mock_connector, runner, mock_ctx): ctx = mock_ctx() @@ -53,3 +60,48 @@ def test_fetch(mock_connector, runner, mock_ctx): assert result.exit_code == 0, result.output assert ctx.get_query() == "alter git repository repo_name fetch" + + +@mock.patch("snowflake.connector.connect") +def test_copy_to_local_file_system(mock_connector, runner, mock_ctx, temp_dir): + ctx = mock_ctx() + mock_connector.return_value = ctx + local_path = Path(temp_dir) / "local_dir" + assert not local_path.exists() + result = runner.invoke(["git", "copy", "@repo_name/branches/main", str(local_path)]) + + assert result.exit_code == 0, result.output + assert local_path.exists() + assert ( + ctx.get_query() + == f"get @repo_name/branches/main file://{local_path.resolve()}/ parallel=4" + ) + + +@mock.patch("snowflake.connector.connect") +def test_copy_to_remote_dir(mock_connector, runner, mock_ctx): + ctx = mock_ctx() + mock_connector.return_value = ctx + result = runner.invoke( + ["git", "copy", "@repo_name/branches/main", "@stage_path/dir_in_stage"] + ) + + assert result.exit_code == 0, result.output + assert ( + ctx.get_query() + == "copy files into @stage_path/dir_in_stage/ from @repo_name/branches/main" + ) + + +def test_copy_not_a_stage_error(runner): + result = runner.invoke(["git", "copy", "repo_name", "@stage_path/dir_in_stage"]) + assert result.exit_code == 1 + _assert_invalid_repo_path_error_message(result.output) + + +def _assert_invalid_repo_path_error_message(output): + assert "Error" in output + assert ( + "REPOSITORY_PATH should be a path to git repository stage with scope" in output + ) + assert "provided. For example: @my_repo/branches/main/" in output