diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index cdd9d17dc..e163c4daf 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -17,7 +17,7 @@ import google.auth.exceptions import google.cloud.bigquery import google.cloud.exceptions -from google.api_core import retry, client_info +from google.api_core import retry, client_info, client_options from google.auth import impersonated_credentials from google.oauth2 import ( credentials as GoogleCredentials, @@ -126,6 +126,7 @@ class BigQueryCredentials(Credentials): schema: Optional[str] = None execution_project: Optional[str] = None location: Optional[str] = None + api_endpoint: Optional[str] = None priority: Optional[Priority] = None maximum_bytes_billed: Optional[int] = None impersonate_service_account: Optional[str] = None @@ -405,6 +406,11 @@ def get_bigquery_client(cls, profile_credentials): creds = cls.get_credentials(profile_credentials) execution_project = profile_credentials.execution_project location = getattr(profile_credentials, "location", None) + options: Optional[client_options.ClientOptions] = None + if profile_credentials.api_endpoint: + options = client_options.ClientOptions( + api_endpoint=profile_credentials.api_endpoint, + ) info = client_info.ClientInfo(user_agent=f"dbt-bigquery-{dbt_version.version}") return google.cloud.bigquery.Client( @@ -412,6 +418,7 @@ def get_bigquery_client(cls, profile_credentials): creds, location=location, client_info=info, + client_options=options, ) @classmethod diff --git a/tests/unit/test_bigquery_adapter.py b/tests/unit/test_bigquery_adapter.py index a922525fd..8d67a1a97 100644 --- a/tests/unit/test_bigquery_adapter.py +++ b/tests/unit/test_bigquery_adapter.py @@ -17,7 +17,7 @@ from dbt.adapters.bigquery.relation_configs import PartitionConfig from dbt.adapters.bigquery import BigQueryAdapter, BigQueryRelation from google.cloud.bigquery.table import Table -from dbt.adapters.bigquery.connections import _sanitize_label, _VALIDATE_LABEL_LENGTH_LIMIT +from dbt.adapters.bigquery.connections import _sanitize_label, _VALIDATE_LABEL_LENGTH_LIMIT, get_bigquery_defaults from dbt_common.clients import agate_helper import dbt_common.exceptions from dbt.context.query_header import generate_query_header_context @@ -71,6 +71,17 @@ def setUp(self): "priority": "batch", "maximum_bytes_billed": 0, }, + "api_endpoint": { + "type": "bigquery", + "method": "oauth", + "project": "dbt-unit-000000", + "schema": "dummy_schema", + "threads": 1, + "location": "Luna Station", + "priority": "batch", + "maximum_bytes_billed": 0, + "api_endpoint": "https://localhost:3001", + }, "impersonate": { "type": "bigquery", "method": "oauth", @@ -389,6 +400,7 @@ def test_cancel_open_connections_single(self): @patch("dbt.adapters.bigquery.impl.google.auth.default") @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery") def test_location_user_agent(self, mock_bq, mock_auth_default): + get_bigquery_defaults.cache_clear() creds = MagicMock() mock_auth_default.return_value = (creds, MagicMock()) adapter = self.get_adapter("loc") @@ -403,8 +415,42 @@ def test_location_user_agent(self, mock_bq, mock_auth_default): creds, location="Luna Station", client_info=HasUserAgent(), + client_options=None, ) + @patch("dbt.adapters.bigquery.impl.google.auth.default") + @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery") + def test_api_endpoint_settable(self, mock_bq, mock_auth_default): + """Ensure that a user can pass api_endpoint to the connector eg. for emulator.""" + + get_bigquery_defaults.cache_clear() + creds = MagicMock() + mock_auth_default.return_value = (creds, MagicMock()) + adapter = self.get_adapter("api_endpoint") + + connection = adapter.acquire_connection("dummy") + mock_client = mock_bq.Client + + mock_client.assert_not_called() + connection.handle + mock_client.assert_called_once_with( + "dbt-unit-000000", + creds, + location="Luna Station", + client_info=HasUserAgent(), + client_options=_CheckApiEndpointSet("https://localhost:3001"), + ) + + +class _CheckApiEndpointSet: + value: str + + def __init__(self, value: str) -> None: + self.value = value + + def __eq__(self, other) -> bool: + return getattr(other, "api_endpoint") == self.value + class HasUserAgent: PAT = re.compile(r"dbt-bigquery-\d+\.\d+\.\d+((a|b|rc)\d+)?")