Skip to content

Commit

Permalink
Snow 1105689 git copy (#843)
Browse files Browse the repository at this point in the history
* git copy: get

* git copy: copy

* unit tests

* target help message: a directory; create target if not exists

* Use SecurePath

* fix path type

* review fixes

* Refactor stage.get

* assert in callback

* SecurePath: make file type assertions public

* Review fixes
  • Loading branch information
sfc-gh-pczajka authored Mar 7, 2024
1 parent 4789196 commit 836653e
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 35 deletions.
4 changes: 4 additions & 0 deletions src/snowflake/cli/api/utils/path_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ def path_resolver(path_to_file: str) -> str:
if 0 < return_value <= BUFFER_SIZE:
return buffer.value
return path_to_file


def is_stage_path(path: str) -> bool:
return path.startswith("@") or path.startswith("snow://")
77 changes: 70 additions & 7 deletions src/snowflake/cli/plugins/git/commands.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import logging
from pathlib import Path

import typer
from click import ClickException
from snowflake.cli.api.commands.flags import identifier_argument
from snowflake.cli.api.commands.snow_typer import SnowTyper
from snowflake.cli.api.output.types import CommandResult, QueryResult
from snowflake.cli.api.utils.path_utils import is_stage_path
from snowflake.cli.plugins.git.manager import GitManager

app = SnowTyper(
Expand All @@ -13,54 +16,114 @@
)
log = logging.getLogger(__name__)


def _repo_path_argument_callback(path):
if not is_stage_path(path):
raise ClickException(
"REPOSITORY_PATH should be a path to git repository stage with scope provided."
" For example: @my_repo/branches/main/"
)
return path


RepoNameArgument = identifier_argument(sf_object="git repository", example="my_repo")
RepoPathArgument = typer.Argument(
metavar="REPOSITORY_PATH",
help=(
"Path to git repository stage with scope provided."
" Path to the repository root must end with '/'."
" For example: @my_repo/branches/main/"
),
callback=_repo_path_argument_callback,
)


@app.command(
"list-branches",
help="List all branches in the repository.",
requires_connection=True,
)
def list_branches(
repository_name: str = RepoNameArgument,
**options,
) -> CommandResult:
"""
List all branches in the repository.
"""
return QueryResult(GitManager().show_branches(repo_name=repository_name))


@app.command(
"list-tags",
help="List all tags in the repository.",
requires_connection=True,
)
def list_tags(
repository_name: str = RepoNameArgument,
**options,
) -> CommandResult:
"""
List all tags in the repository.
"""
return QueryResult(GitManager().show_tags(repo_name=repository_name))


@app.command(
"list-files",
help="List files from given state of git repository.",
requires_connection=True,
)
def list_files(
repository_path: str = typer.Argument(
help="Path to git repository stage with scope provided. For example: @my_repo/branches/main"
),
repository_path: str = RepoPathArgument,
**options,
) -> CommandResult:
"""
List files from given state of git repository.
"""
return QueryResult(GitManager().show_files(repo_path=repository_path))


@app.command(
"fetch",
help="Fetch changes from origin to snowflake repository.",
requires_connection=True,
)
def fetch(
repository_name: str = RepoNameArgument,
**options,
) -> CommandResult:
"""
Fetch changes from origin to snowflake repository.
"""
return QueryResult(GitManager().fetch(repo_name=repository_name))


@app.command(
"copy",
requires_connection=True,
)
def copy(
repository_path: str = RepoPathArgument,
destination_path: str = typer.Argument(
help="Target path for copy operation. Should be a path to a directory on remote stage or local file system.",
),
parallel: int = typer.Option(
4,
help="Number of parallel threads to use when downloading files.",
),
**options,
):
"""
Copies all files from given state of repository to local directory or stage.
If the source path ends with '/', the command copies contents of specified directory.
Otherwise, it creates a new directory or file in the destination directory.
"""
is_copy = is_stage_path(destination_path)
if is_copy:
cursor = GitManager().copy_files(
source_path=repository_path, destination_path=destination_path
)
else:
cursor = GitManager().get(
stage_path=repository_path,
dest_path=Path(destination_path).resolve(),
parallel=parallel,
)
return QueryResult(cursor)
13 changes: 7 additions & 6 deletions src/snowflake/cli/plugins/git/manager.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from snowflake.cli.api.sql_execution import SqlExecutionMixin
from snowflake.cli.plugins.object.stage.manager import StageManager
from snowflake.connector.cursor import SnowflakeCursor


class GitManager(SqlExecutionMixin):
def show_branches(self, repo_name: str):
class GitManager(StageManager):
def show_branches(self, repo_name: str) -> SnowflakeCursor:
query = f"show git branches in {repo_name}"
return self._execute_query(query)

def show_tags(self, repo_name: str):
def show_tags(self, repo_name: str) -> SnowflakeCursor:
query = f"show git tags in {repo_name}"
return self._execute_query(query)

def show_files(self, repo_path: str):
def show_files(self, repo_path: str) -> SnowflakeCursor:
query = f"ls {repo_path}"
return self._execute_query(query)

def fetch(self, repo_name: str):
def fetch(self, repo_name: str) -> SnowflakeCursor:
query = f"alter git repository {repo_name} fetch"
return self._execute_query(query)
11 changes: 4 additions & 7 deletions src/snowflake/cli/plugins/object/stage/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
QueryResult,
SingleQueryResult,
)
from snowflake.cli.api.utils.path_utils import is_stage_path
from snowflake.cli.plugins.object.stage.diff import DiffResult
from snowflake.cli.plugins.object.stage.manager import StageManager

Expand All @@ -31,10 +32,6 @@ def stage_list(stage_name: str = StageNameArgument, **options) -> CommandResult:
return QueryResult(cursor)


def _is_stage_path(path: str):
return path.startswith("@") or path.startswith("snow://")


@app.command("copy", requires_connection=True)
def copy(
source_path: str = typer.Argument(
Expand All @@ -57,8 +54,8 @@ def copy(
Copies all files from target path to target directory. This works for both uploading
to and downloading files from the stage.
"""
is_get = _is_stage_path(source_path)
is_put = _is_stage_path(destination_path)
is_get = is_stage_path(source_path)
is_put = is_stage_path(destination_path)

if is_get and is_put:
raise click.ClickException(
Expand All @@ -72,7 +69,7 @@ def copy(
if is_get:
target = Path(destination_path).resolve()
cursor = StageManager().get(
stage_name=source_path, dest_path=target, parallel=parallel
stage_path=source_path, dest_path=target, parallel=parallel
)
else:
source = Path(source_path).resolve()
Expand Down
38 changes: 30 additions & 8 deletions src/snowflake/cli/plugins/object/stage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional, Union

from snowflake.cli.api.project.util import to_string_literal
from snowflake.cli.api.secure_path import SecurePath
from snowflake.cli.api.sql_execution import SqlExecutionMixin
from snowflake.cli.api.utils.path_utils import path_resolver
from snowflake.connector.cursor import SnowflakeCursor
Expand All @@ -19,13 +20,19 @@

class StageManager(SqlExecutionMixin):
@staticmethod
def get_standard_stage_name(name: str) -> str:
def get_standard_stage_prefix(name: str) -> str:
# Handle embedded stages
if name.startswith("snow://") or name.startswith("@"):
return name

return f"@{name}"

@staticmethod
def get_standard_stage_directory_path(path):
if not path.endswith("/"):
path += "/"
return StageManager.get_standard_stage_prefix(path)

@staticmethod
def get_stage_name_from_path(path: str):
"""
Expand All @@ -39,7 +46,7 @@ def quote_stage_name(name: str) -> str:
if name.startswith("'") and name.endswith("'"):
return name # already quoted

standard_name = StageManager.get_standard_stage_name(name)
standard_name = StageManager.get_standard_stage_prefix(name)
if standard_name.startswith("@") and not re.fullmatch(
r"@([\w./$])+", standard_name
):
Expand All @@ -54,16 +61,24 @@ def _to_uri(self, local_path: str):
return to_string_literal(uri)

def list_files(self, stage_name: str) -> SnowflakeCursor:
stage_name = self.get_standard_stage_name(stage_name)
stage_name = self.get_standard_stage_prefix(stage_name)
return self._execute_query(f"ls {self.quote_stage_name(stage_name)}")

@staticmethod
def _assure_is_existing_directory(path: Path) -> None:
spath = SecurePath(path)
if not spath.exists():
spath.mkdir()
spath.assert_is_directory()

def get(
self, stage_name: str, dest_path: Path, parallel: int = 4
self, stage_path: str, dest_path: Path, parallel: int = 4
) -> SnowflakeCursor:
stage_name = self.get_standard_stage_name(stage_name)
stage_path = self.get_standard_stage_prefix(stage_path)
self._assure_is_existing_directory(dest_path)
dest_directory = f"{dest_path}/"
return self._execute_query(
f"get {self.quote_stage_name(stage_name)} {self._to_uri(dest_directory)} parallel={parallel}"
f"get {self.quote_stage_name(stage_path)} {self._to_uri(dest_directory)} parallel={parallel}"
)

def put(
Expand All @@ -81,7 +96,7 @@ def put(
and switch back to the original role for the next commands to run.
"""
with self.use_role(role) if role else nullcontext():
stage_path = self.get_standard_stage_name(stage_path)
stage_path = self.get_standard_stage_prefix(stage_path)
local_resolved_path = path_resolver(str(local_path))
log.info("Uploading %s to @%s", local_resolved_path, stage_path)
cursor = self._execute_query(
Expand All @@ -90,6 +105,13 @@ def put(
)
return cursor

def copy_files(self, source_path: str, destination_path: str) -> SnowflakeCursor:
source = self.get_standard_stage_prefix(source_path)
destination = self.get_standard_stage_directory_path(destination_path)
log.info("Copying files from %s to %s", source, destination)
query = f"copy files into {destination} from {source}"
return self._execute_query(query)

def remove(
self, stage_name: str, path: str, role: Optional[str] = None
) -> SnowflakeCursor:
Expand All @@ -100,7 +122,7 @@ def remove(
and switch back to the original role for the next commands to run.
"""
with self.use_role(role) if role else nullcontext():
stage_name = self.get_standard_stage_name(stage_name)
stage_name = self.get_standard_stage_prefix(stage_name)
path = path if path.startswith("/") else "/" + path
quoted_stage_name = self.quote_stage_name(f"{stage_name}{path}")
return self._execute_query(f"remove {quoted_stage_name}")
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/cli/plugins/streamlit/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def deploy(

stage_manager.create(stage_name=stage_name)

root_location = stage_manager.get_standard_stage_name(
root_location = stage_manager.get_standard_stage_prefix(
f"{stage_name}/{streamlit_name_for_root_location}"
)

Expand Down
Loading

0 comments on commit 836653e

Please sign in to comment.