Skip to content

Commit

Permalink
Merge pull request #262 from preset-io/sc_78260
Browse files Browse the repository at this point in the history
feat(dbt): allow passing custom URL
  • Loading branch information
betodealmeida authored Feb 14, 2024
2 parents a895f98 + 6450771 commit ca81c12
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 6 deletions.
58 changes: 54 additions & 4 deletions src/preset_cli/api/clients/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# pylint: disable=invalid-name, too-few-public-methods

import logging
import re
from enum import Enum
from typing import Any, Dict, List, Optional, Type, TypedDict

Expand Down Expand Up @@ -642,18 +643,67 @@ class DataResponse(TypedDict):
data: Dict[str, Any]


def get_custom_urls(access_url: Optional[str] = None) -> Dict[str, URL]:
    """
    Return new custom URLs for dbt Cloud access.

    When ``access_url`` is ``None`` the module-level default endpoints are
    returned. Otherwise the host of ``access_url`` must have the form
    ``<code>.<region>.dbt.com`` (eg, ``ab123.us1.dbt.com``), and the admin,
    discovery and semantic layer URLs are derived from it by rewriting only
    the host; every other component of ``access_url`` is kept as passed.

    :param access_url: custom dbt Cloud access URL, or ``None`` for defaults.
    :returns: a dict with the keys ``admin``, ``discovery`` and
        ``semantic-layer``.
    :raises Exception: with the message ``Invalid host in custom URL`` when
        the host is missing or does not match the expected pattern.
    """
    if access_url is None:
        return {
            "admin": REST_ENDPOINT,
            "discovery": METADATA_GRAPHQL_ENDPOINT,
            "semantic-layer": SEMANTIC_LAYER_GRAPHQL_ENDPOINT,
        }

    # ``code`` and ``region`` are each a single host label: 1 to 63
    # alphanumeric/hyphen characters that neither start nor end with a hyphen.
    # Both dots of the ``.dbt.com`` suffix are escaped so that only a literal
    # ``dbt.com`` domain is accepted.
    regex_pattern = r"""
        (?P<code>
            [a-zA-Z0-9]
            (?:
                [a-zA-Z0-9-]{0,61}
                [a-zA-Z0-9]
            )?
        )
        \.
        (?P<region>
            [a-zA-Z0-9]
            (?:
                [a-zA-Z0-9-]{0,61}
                [a-zA-Z0-9]
            )?
        )
        \.dbt\.com
    """

    parsed = URL(access_url)
    # ``host`` is ``None`` when the URL has no authority (eg, the scheme was
    # omitted); fall back to an empty string so the invalid-host error below
    # is raised instead of a ``TypeError`` from ``re``.
    host = parsed.host or ""
    if match := re.fullmatch(regex_pattern, host, re.VERBOSE):
        code, region = match["code"], match["region"]
        # NOTE(review): only the host is rewritten, so the derived discovery
        # and semantic layer URLs carry whatever path ``access_url`` has,
        # unlike the default endpoints -- confirm this is intended.
        return {
            "admin": parsed.with_host(f"{code}.{region}.dbt.com"),
            "discovery": parsed.with_host(f"{code}.metadata.{region}.dbt.com"),
            "semantic-layer": parsed.with_host(
                f"{code}.semantic-layer.{region}.dbt.com",
            ),
        }

    raise Exception("Invalid host in custom URL")


class DBTClient: # pylint: disable=too-few-public-methods

"""
A client for the dbt API.
"""

def __init__(self, auth: Auth):
self.metadata_graphql_client = GraphqlClient(endpoint=METADATA_GRAPHQL_ENDPOINT)
def __init__(self, auth: Auth, access_url: Optional[str] = None):
urls = get_custom_urls(access_url)
self.metadata_graphql_client = GraphqlClient(endpoint=urls["discovery"])
self.semantic_layer_graphql_client = GraphqlClient(
endpoint=SEMANTIC_LAYER_GRAPHQL_ENDPOINT,
endpoint=urls["semantic-layer"],
)
self.baseurl = REST_ENDPOINT
self.baseurl = urls["admin"]

self.session = auth.session
self.session.headers.update(auth.get_headers())
Expand Down
7 changes: 6 additions & 1 deletion src/preset_cli/cli/superset/sync/dbt/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,10 @@ def process_sl_metrics(
default=False,
help="Update Preset configurations based on dbt metadata. Preset-only metrics are preserved",
)
@click.option(
"--access-url",
help="Custom API URL for dbt Cloud (eg, https://ab123.us1.dbt.com)",
)
@click.pass_context
def dbt_cloud( # pylint: disable=too-many-arguments, too-many-locals
ctx: click.core.Context,
Expand All @@ -436,6 +440,7 @@ def dbt_cloud( # pylint: disable=too-many-arguments, too-many-locals
preserve_columns: bool = False,
preserve_metadata: bool = False,
merge_metadata: bool = False,
access_url: Optional[str] = None,
) -> None:
"""
Sync models/metrics from dbt Cloud to Superset.
Expand All @@ -445,7 +450,7 @@ def dbt_cloud( # pylint: disable=too-many-arguments, too-many-locals
superset_client = SupersetClient(url, superset_auth)

dbt_auth = TokenAuth(token)
dbt_client = DBTClient(dbt_auth)
dbt_client = DBTClient(dbt_auth, access_url)

if (preserve_columns or preserve_metadata) and merge_metadata:
click.echo(
Expand Down
24 changes: 23 additions & 1 deletion tests/api/clients/dbt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from marshmallow import fields
from pytest_mock import MockerFixture
from requests_mock.mocker import Mocker
from yarl import URL

from preset_cli import __version__
from preset_cli.api.clients.dbt import DBTClient, PostelEnumField
from preset_cli.api.clients.dbt import DBTClient, PostelEnumField, get_custom_urls
from preset_cli.auth.main import Auth


Expand Down Expand Up @@ -1340,3 +1341,24 @@ def test_dbt_client_get_sl_dialect(mocker: MockerFixture) -> None:
client = DBTClient(auth)

assert client.get_sl_dialect(108380) == "BIGQUERY"


def test_get_custom_urls() -> None:
    """
    Check ``get_custom_urls`` with no URL, a valid custom URL, and a bad host.
    """
    # Without an access URL the default endpoints are returned.
    default_urls = get_custom_urls()
    expected_defaults = {
        "admin": URL("https://cloud.getdbt.com/"),
        "discovery": URL("https://metadata.cloud.getdbt.com/graphql"),
        "semantic-layer": URL("https://semantic-layer.cloud.getdbt.com/api/graphql"),
    }
    assert default_urls == expected_defaults

    # A ``<code>.<region>.dbt.com`` host is expanded into the three endpoints.
    custom_urls = get_custom_urls("https://ab123.us1.dbt.com")
    expected_custom = {
        "admin": URL("https://ab123.us1.dbt.com"),
        "discovery": URL("https://ab123.metadata.us1.dbt.com"),
        "semantic-layer": URL("https://ab123.semantic-layer.us1.dbt.com"),
    }
    assert custom_urls == expected_custom

    # Any other host is rejected.
    with pytest.raises(Exception) as excinfo:
        get_custom_urls("https://preset.io")
    assert str(excinfo.value) == "Invalid host in custom URL"

0 comments on commit ca81c12

Please sign in to comment.