Merge pull request #500 from splitgraph/bulk-repo-external-upload
[CU-15xpkpf] Sgr bulk repositories upload
gruuya authored Jul 26, 2021
2 parents 58e936e + db3b0fc commit 1a2529f
Showing 8 changed files with 283 additions and 117 deletions.
3 changes: 3 additions & 0 deletions examples/postgrest/docker-compose.yml
@@ -13,6 +13,9 @@ services:
       - 5432
   postgrest:
     image: postgrest/postgrest:latest
+    command:
+      - postgrest
+      - /etc/postgrest.conf
     ports:
       - '0.0.0.0:8080:8080'
     volumes:
160 changes: 149 additions & 11 deletions splitgraph/cloud/__init__.py
@@ -16,6 +16,7 @@

 from splitgraph.__version__ import __version__
 from splitgraph.cloud.models import (
+    Credential,
     Metadata,
     MetadataResponse,
     External,
@@ -27,6 +28,7 @@
     AddExternalCredentialRequest,
     UpdateExternalCredentialResponse,
     AddExternalRepositoryRequest,
+    AddExternalRepositoriesRequest,
 )
 from splitgraph.commandline.engine import patch_and_save_config
 from splitgraph.config import create_config_dict, get_singleton, CONFIG
@@ -57,6 +59,69 @@ def get_headers():
     }
 }

+_BULK_UPSERT_REPO_PROFILES_QUERY = """mutation BulkUpsertRepoProfilesMutation(
+  $namespaces: [String!]
+  $repositories: [String!]
+  $descriptions: [String]
+  $readmes: [String]
+  $licenses: [String]
+  $metadata: [JSON]
+) {
+  __typename
+  bulkUpsertRepoProfiles(
+    input: {
+      namespaces: $namespaces
+      repositories: $repositories
+      descriptions: $descriptions
+      readmes: $readmes
+      licenses: $licenses
+      metadata: $metadata
+    }
+  ) {
+    clientMutationId
+    __typename
+  }
+}
+"""
+
+_BULK_UPDATE_REPO_SOURCES_QUERY = """mutation BulkUpdateRepoSourcesMutation(
+  $namespaces: [String!]
+  $repositories: [String!]
+  $sources: [DatasetSourceInput]
+) {
+  __typename
+  bulkUpdateRepoSources(
+    input: {
+      namespaces: $namespaces
+      repositories: $repositories
+      sources: $sources
+    }
+  ) {
+    clientMutationId
+    __typename
+  }
+}
+"""
+
+_BULK_UPSERT_REPO_TOPICS_QUERY = """mutation BulkUpsertRepoTopicsMutation(
+  $namespaces: [String!]
+  $repositories: [String!]
+  $topics: [String]
+) {
+  __typename
+  bulkUpsertRepoTopics(
+    input: {
+      namespaces: $namespaces
+      repositories: $repositories
+      topics: $topics
+    }
+  ) {
+    clientMutationId
+    __typename
+  }
+}
+"""
+
 _PROFILE_UPSERT_QUERY = """mutation UpsertRepoProfile(
   $namespace: String!
   $repository: String!
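
All three bulk mutations above take parallel arrays: index i of every list describes the same repository. A sketch of the variables payload for BulkUpsertRepoProfilesMutation (namespaces, repository names and values are invented; None is passed for fields a repository leaves unset, mirroring bulk_upsert_metadata below):

    variables = {
        "namespaces": ["acme", "acme"],
        "repositories": ["weather", "prices"],
        "descriptions": ["Hourly weather readings", None],
        "readmes": ["# Weather", None],
        "licenses": ["CC-BY-4.0", None],
        "metadata": [{"upstream": "example.com"}, None],
    }
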
@@ -567,17 +632,11 @@ def ensure_external_credential(
         assert credential
         return credential.credential_id

-    def upsert_external(
-        self,
-        namespace: str,
-        repository: str,
-        external: External,
-        credentials_map: Optional[Dict[str, str]] = None,
-    ):
-        request = AddExternalRepositoryRequest.from_external(
-            namespace, repository, external, credentials_map
+    def bulk_upsert_external(self, repositories: List[AddExternalRepositoryRequest]):
+        request = AddExternalRepositoriesRequest(repositories=repositories)
+        self._perform_request(
+            "/bulk-add", self.access_token, request, endpoint=self.externals_endpoint
         )
-        self._perform_request("/add", self.access_token, request, endpoint=self.externals_endpoint)


 def AuthAPIClient(*args, **kwargs):
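
With the REST endpoint moving from /add to /bulk-add, callers build the request objects themselves and submit them in a single round trip. A minimal sketch, assuming `client` is an authenticated instance of the class above and that weather_external and prices_external are already-populated External models:

    repo_requests = [
        AddExternalRepositoryRequest.from_external("acme", "weather", weather_external, credentials_map),
        AddExternalRepositoryRequest.from_external("acme", "prices", prices_external, credentials_map),
    ]
    client.bulk_upsert_external(repositories=repo_requests)
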
@@ -633,7 +692,7 @@ def _gql(self, query: Dict, endpoint=None, handle_errors=False) -> requests.Response:
         return result

     @staticmethod
-    def _prepare_upsert_metadata_gql(namespace: str, repository: str, metadata: Metadata, v1=False):
+    def _validate_metadata(namespace: str, repository: str, metadata: Metadata):
         # Pre-flight validation
         if metadata.description and len(metadata.description) > 160:
             raise ValueError("The description should be 160 characters or shorter!")
@@ -669,6 +728,12 @@ def _prepare_upsert_metadata_gql(namespace: str, repository: str, metadata: Metadata, v1=False):
         if "readme" in variables and isinstance(variables["readme"], dict):
             variables["readme"] = variables["readme"]["text"]

+        return variables
+
+    @staticmethod
+    def _prepare_upsert_metadata_gql(namespace: str, repository: str, metadata: Metadata, v1=False):
+        variables = GQLAPIClient._validate_metadata(namespace, repository, metadata)
+
         gql_query = _PROFILE_UPSERT_QUERY
         if v1:
             gql_query = gql_query.replace("createRepoTopicsAgg", "createRepoTopic").replace(
@@ -706,6 +771,79 @@ def upsert_metadata(self, namespace: str, repository: str, metadata: Metadata):
         )
         return response

+    def bulk_upsert_metadata(
+        self, namespace_list: List[str], repository_list: List[str], metadata_list: List[Metadata]
+    ):
+        repo_profiles: Dict[str, List[Any]] = dict(
+            namespaces=namespace_list,
+            repositories=repository_list,
+            descriptions=[],
+            readmes=[],
+            licenses=[],
+            metadata=[],
+        )
+        repo_sources: Dict[str, List[Any]] = dict(namespaces=[], repositories=[], sources=[])
+        repo_topics: Dict[str, List[str]] = dict(namespaces=[], repositories=[], topics=[])
+
+        # populate mutation payloads
+        for ind, metadata in enumerate(metadata_list):
+            validated_metadata = GQLAPIClient._validate_metadata(
+                namespace_list[ind], repository_list[ind], metadata
+            )
+
+            repo_profiles["descriptions"].append(validated_metadata.get("description"))
+            repo_profiles["readmes"].append(validated_metadata.get("readme"))
+            repo_profiles["licenses"].append(validated_metadata.get("license"))
+            repo_profiles["metadata"].append(validated_metadata.get("metadata"))
+
+            # flatten sources, which will be aggregated on the server side
+            if len(validated_metadata.get("sources", [])) > 0:
+                for source in validated_metadata["sources"]:
+                    repo_sources["namespaces"].append(namespace_list[ind])
+                    repo_sources["repositories"].append(repository_list[ind])
+                    repo_sources["sources"].append(source)
+
+            # flatten topics, which will be aggregated on the server side
+            if len(validated_metadata.get("topics", [])) > 0:
+                for topic in validated_metadata["topics"]:
+                    repo_topics["namespaces"].append(namespace_list[ind])
+                    repo_topics["repositories"].append(repository_list[ind])
+                    repo_topics["topics"].append(topic)
+
+        self._bulk_upsert_repo_profiles(repo_profiles)
+        self._bulk_upsert_repo_sources(repo_sources)
+        self._bulk_upsert_repo_topics(repo_topics)
+
+    @handle_gql_errors
+    def _bulk_upsert_repo_profiles(self, repo_profiles: Dict[str, List[Any]]):
+        repo_profiles_query = {
+            "operationName": "BulkUpsertRepoProfilesMutation",
+            "variables": repo_profiles,
+            "query": _BULK_UPSERT_REPO_PROFILES_QUERY,
+        }
+        response = self._gql(repo_profiles_query)
+        return response
+
+    @handle_gql_errors
+    def _bulk_upsert_repo_sources(self, repo_sources: Dict[str, List[Any]]):
+        repo_sources_query = {
+            "operationName": "BulkUpdateRepoSourcesMutation",
+            "variables": repo_sources,
+            "query": _BULK_UPDATE_REPO_SOURCES_QUERY,
+        }
+        response = self._gql(repo_sources_query)
+        return response
+
+    @handle_gql_errors
+    def _bulk_upsert_repo_topics(self, repo_topics: Dict[str, List[str]]):
+        repo_topics_query = {
+            "operationName": "BulkUpsertRepoTopicsMutation",
+            "variables": repo_topics,
+            "query": _BULK_UPSERT_REPO_TOPICS_QUERY,
+        }
+        response = self._gql(repo_topics_query)
+        return response
+
     def upsert_readme(self, namespace: str, repository: str, readme: str):
         return self.upsert_metadata(namespace, repository, Metadata(readme=readme))
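
The flattening in bulk_upsert_metadata above is the subtle part: repo_profiles keeps one entry per repository, while repo_sources and repo_topics repeat the namespace/repository pair once per source or topic so the server can aggregate them back. A worked example with invented values, where the second repository carries two topics:

    # namespace_list = ["acme", "acme"], repository_list = ["weather", "prices"]
    # metadata_list = [Metadata(topics=["climate"]),
    #                  Metadata(topics=["finance", "commodities"])]
    # ...produces:
    repo_topics = {
        "namespaces": ["acme", "acme", "acme"],
        "repositories": ["weather", "prices", "prices"],
        "topics": ["climate", "finance", "commodities"],
    }
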
12 changes: 12 additions & 0 deletions splitgraph/cloud/models.py
@@ -27,12 +27,18 @@ class Credential(BaseModel):
     data: Dict[str, Any]


+class IngestionSchedule(BaseModel):
+    schedule: str
+    enabled = True
+
+
 class External(BaseModel):
     credential_id: Optional[str]
     credential: Optional[str]
     plugin: str
     params: Dict[str, Any]
     tables: Dict[str, Table]
+    schedule: Optional[IngestionSchedule]


 # Models for the catalog metadata (description, README, topics etc)
@@ -226,6 +232,7 @@ class AddExternalRepositoryRequest(BaseModel):
     is_live: bool
     tables: Optional[Dict[str, ExternalTableRequest]]
     credential_id: Optional[str]
+    schedule: Optional[IngestionSchedule]

     @classmethod
     def from_external(
Expand Down Expand Up @@ -259,4 +266,9 @@ def from_external(
},
credential_id=credential_id,
is_live=True,
schedule=external.schedule,
)


class AddExternalRepositoriesRequest(BaseModel):
repositories: List[AddExternalRepositoryRequest]
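
IngestionSchedule is a plain pydantic model, so enabled defaults to True when only the schedule string is supplied. A sketch of assembling the new bulk request by hand (all values invented; the schedule string is assumed to be cron-style, and no credentials map is passed):

    external = External(
        plugin="csv",
        params={"url": "https://example.com/data.csv"},
        tables={},
        schedule=IngestionSchedule(schedule="0 */6 * * *"),  # enabled=True by default
    )
    request = AddExternalRepositoriesRequest(
        repositories=[
            AddExternalRepositoryRequest.from_external("acme", "weather", external, None)
        ]
    )
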
37 changes: 26 additions & 11 deletions splitgraph/commandline/cloud.py
@@ -14,14 +14,15 @@
 from click import wrap_text
 from tqdm import tqdm

-from splitgraph.cloud.models import Metadata, RepositoriesYAML
+from splitgraph.cloud.models import Metadata, RepositoriesYAML, AddExternalRepositoryRequest
 from splitgraph.commandline.common import (
     ImageType,
     RepositoryType,
     emit_sql_results,
     Color,
 )
 from splitgraph.commandline.engine import patch_and_save_config, inject_config_into_engines
+from splitgraph.core.output import pluralise

 # Hardcoded database name for the Splitgraph DDN (ddn instead of sgregistry)
 from splitgraph.config.config import get_from_subsection
@@ -675,16 +676,30 @@ def load_c(remote, readme_dir, repositories_file, limit_repositories):
             r for r in repositories if f"{r.namespace}/{r.repository}" in limit_repositories
         ]

-    with tqdm(repositories) as t:
-        for repository in t:
-            t.set_description(f"{repository.namespace}/{repository.repository}")
-            if repository.external:
-                rest_client.upsert_external(
-                    repository.namespace, repository.repository, repository.external, credential_map
-                )
-            if repository.metadata:
-                metadata = _prepare_metadata(repository.metadata, readme_basedir=readme_dir)
-                gql_client.upsert_metadata(repository.namespace, repository.repository, metadata)
+    logging.info("Uploading images...")
+    external_repositories = []
+    for repository in repositories:
+        if repository.external:
+            external_repository = AddExternalRepositoryRequest.from_external(
+                repository.namespace, repository.repository, repository.external, credential_map
+            )
+            external_repositories.append(external_repository)
+    rest_client.bulk_upsert_external(repositories=external_repositories)
+    logging.info(f"Uploaded images for {pluralise('repository', len(external_repositories))}")
+
+    logging.info("Updating metadata...")
+    namespace_list = []
+    repository_list = []
+    metadata_list = []
+    for repository in repositories:
+        if repository.metadata:
+            namespace_list.append(repository.namespace)
+            repository_list.append(repository.repository)
+
+            metadata = _prepare_metadata(repository.metadata, readme_basedir=readme_dir)
+            metadata_list.append(metadata)
+    gql_client.bulk_upsert_metadata(namespace_list, repository_list, metadata_list)
+    logging.info(f"Updated metadata for {pluralise('repository', len(repository_list))}")


 def _build_credential_map(auth_client, credentials=None):
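
The payoff is the request count: the old loop issued up to two calls per repository, whereas the new flow is a fixed number of round trips however many repositories are loaded (one REST call for the externals, then the three GraphQL mutations behind bulk_upsert_metadata). A condensed sketch of the same sequence (assumes the surrounding variables from load_c: repositories, credential_map, readme_dir and the two authenticated clients):

    external = [
        AddExternalRepositoryRequest.from_external(
            r.namespace, r.repository, r.external, credential_map
        )
        for r in repositories
        if r.external
    ]
    rest_client.bulk_upsert_external(repositories=external)

    with_metadata = [r for r in repositories if r.metadata]
    gql_client.bulk_upsert_metadata(
        [r.namespace for r in with_metadata],
        [r.repository for r in with_metadata],
        [_prepare_metadata(r.metadata, readme_basedir=readme_dir) for r in with_metadata],
    )
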
2 changes: 2 additions & 0 deletions splitgraph/core/output.py
@@ -21,6 +21,8 @@ def pretty_size(size: Union[int, float]) -> str:

 def pluralise(word: str, number: int) -> str:
     """1 banana, 2 bananas"""
+    if word.endswith("y"):
+        return "%d %s" % (number, word if number == 1 else word[:-1] + "ies")
     return "%d %s%s" % (number, word, "" if number == 1 else "s")

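
The extra branch keeps the new log lines in `sgr cloud load` grammatical for "repository". For instance:

    >>> pluralise("repository", 1)
    '1 repository'
    >>> pluralise("repository", 3)
    '3 repositories'
    >>> pluralise("banana", 2)
    '2 bananas'
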
13 changes: 4 additions & 9 deletions splitgraph/engine/postgres/engine.py
@@ -59,6 +59,8 @@
 # the connection property otherwise
 from psycopg2._psycopg import connection as Connection

+psycopg2.extensions.register_adapter(dict, Json)
+
 _AUDIT_SCHEMA = "splitgraph_audit"
 _AUDIT_TRIGGER = "resources/static/audit_trigger.sql"
 _PUSH_PULL = "resources/static/splitgraph_api.sql"
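
Registering the adapter once at import time makes psycopg2 serialize any dict query parameter through Json automatically, which is what allows the per-call _convert_vals() wrapping (removed further down) to be deleted. A standalone sketch of the same mechanism; the DSN, table and column names are invented:

    import psycopg2
    from psycopg2.extras import Json

    # One-time registration: every dict passed as a query parameter is
    # wrapped in Json and sent to the server as json/jsonb.
    psycopg2.extensions.register_adapter(dict, Json)

    conn = psycopg2.connect("dbname=splitgraph")  # hypothetical DSN
    with conn.cursor() as cur:
        # No manual Json(...) wrapping around the dict argument is needed.
        cur.execute(
            "INSERT INTO objects (object_id, params) VALUES (%s, %s)",
            ("abc123", {"format": "FRAG", "size": 8192}),
        )
    conn.commit()
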
@@ -511,7 +513,7 @@ def run_sql(
         with connection.cursor(**cursor_kwargs) as cur:
             try:
                 self.notices = []
-                cur.execute(statement, _convert_vals(arguments) if arguments else None)
+                cur.execute(statement, arguments)
                 if connection.notices:
                     self.notices = connection.notices[:]
                     del connection.notices[:]
@@ -603,7 +605,7 @@ def run_sql_batch(
             batches = _paginate_by_size(
                 cur,
                 statement,
-                (_convert_vals(a) for a in arguments),
+                arguments,
                 max_size=API_MAX_QUERY_LENGTH,
             )
             for batch in batches:
@@ -1603,13 +1605,6 @@ def _convert_audit_change(
 _KIND = {"I": 0, "D": 1, "U": 2}


-def _convert_vals(vals: Any) -> Any:
-    """Psycopg returns jsonb objects as dicts/lists but doesn't actually accept them directly
-    as a query param (or in the case of lists coerces them into an array.
-    Hence, we have to wrap them in the Json datatype when doing a dump + load."""
-    return [Json(v) if isinstance(v, dict) else v for v in vals]
-
-
 def _generate_where_clause(table: str, cols: List[str], table_2: str) -> Composed:
     return SQL(" AND ").join(
         SQL("{}.{} = {}.{}").format(
[Diffs for the remaining 2 changed files did not render.]
