From a43f86b7a7c60b84cf90fc93cf9f336cc0e38223 Mon Sep 17 00:00:00 2001 From: Matthew McKnight <91097623+McKnight-42@users.noreply.github.com> Date: Wed, 17 Jan 2024 11:05:29 -0600 Subject: [PATCH] Feature/decouple adapters from core (#865) * update RELEASE_BRANCH env * create draft pr to track work * initial migration work * get tests_connections unit test passing * update model_node refs to relation_config, update _macro_manifest_lazy ref in unit test * update new_config ping in dynamic_table configuration_changes macro * minor changes to imports * minor changes, revert dev-requirements pointer * revert runtime_config call * fix list_relations_without_caching * change up test input for schema_relation * update snowflake__get_paginated_relations_array macro schema_relation call * add changelog --------- Co-authored-by: Colin --- .../unreleased/Features-20240109-165520.yaml | 6 +++ dbt/adapters/snowflake/column.py | 2 +- dbt/adapters/snowflake/connections.py | 24 ++++++------ dbt/adapters/snowflake/impl.py | 15 +++---- dbt/adapters/snowflake/relation.py | 8 ++-- .../snowflake/relation_configs/base.py | 17 ++++---- .../relation_configs/dynamic_table.py | 22 +++++------ .../snowflake/relation_configs/policies.py | 2 +- dbt/include/snowflake/macros/adapters.sql | 6 +-- .../macros/materializations/dynamic_table.sql | 2 +- .../test_dynamic_tables_changes.py | 2 +- .../test_list_relations_without_caching.py | 9 ++--- tests/unit/test_connections.py | 6 +-- tests/unit/test_snowflake_adapter.py | 39 ++++++++++--------- tests/unit/utils.py | 4 +- 15 files changed, 86 insertions(+), 78 deletions(-) create mode 100644 .changes/unreleased/Features-20240109-165520.yaml diff --git a/.changes/unreleased/Features-20240109-165520.yaml b/.changes/unreleased/Features-20240109-165520.yaml new file mode 100644 index 000000000..b38770760 --- /dev/null +++ b/.changes/unreleased/Features-20240109-165520.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Update base adapter references as part of decoupling migration +time: 2024-01-09T16:55:20.859657-06:00 +custom: + Author: McKnight-42 + Issue: "882" diff --git a/dbt/adapters/snowflake/column.py b/dbt/adapters/snowflake/column.py index e5d07b82b..61e37f6cc 100644 --- a/dbt/adapters/snowflake/column.py +++ b/dbt/adapters/snowflake/column.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from dbt.adapters.base.column import Column -from dbt.exceptions import DbtRuntimeError +from dbt.common.exceptions import DbtRuntimeError @dataclass diff --git a/dbt/adapters/snowflake/connections.py b/dbt/adapters/snowflake/connections.py index b5fa30002..f2f0bddd6 100644 --- a/dbt/adapters/snowflake/connections.py +++ b/dbt/adapters/snowflake/connections.py @@ -11,7 +11,7 @@ from typing import Optional, Tuple, Union, Any, List import agate -import dbt.clients.agate_helper +import dbt.common.clients.agate_helper from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization @@ -32,20 +32,20 @@ BindUploadError, ) -from dbt.exceptions import ( +from dbt.common.exceptions import ( DbtInternalError, DbtRuntimeError, - FailedToConnectError, - DbtDatabaseError, - DbtProfileError, + DbtConfigError, ) +from dbt.common.exceptions import DbtDatabaseError from dbt.adapters.base import Credentials # type: ignore -from dbt.contracts.connection import AdapterResponse, Connection +from dbt.adapters.exceptions.connection import FailedToConnectError +from dbt.adapters.contracts.connection import AdapterResponse, Connection from dbt.adapters.sql import SQLConnectionManager # type: ignore -from dbt.events import AdapterLogger # type: ignore -from dbt.events.functions import warn_or_error -from dbt.events.types import AdapterEventWarning -from dbt.ui import line_wrap_message, warning_tag +from dbt.adapters.events.logging import AdapterLogger # type: ignore +from dbt.common.events.functions import warn_or_error +from dbt.adapters.events.types import AdapterEventWarning +from dbt.common.ui import line_wrap_message, warning_tag logger = AdapterLogger("Snowflake") @@ -247,7 +247,7 @@ def _get_access_token(self) -> str: def _get_private_key(self): """Get Snowflake private key by path, from a Base64 encoded DER bytestring or None.""" if self.private_key and self.private_key_path: - raise DbtProfileError("Cannot specify both `private_key` and `private_key_path`") + raise DbtConfigError("Cannot specify both `private_key` and `private_key_path`") if self.private_key_passphrase: encoded_passphrase = self.private_key_passphrase.encode() @@ -476,7 +476,7 @@ def execute( if fetch: table = self.get_result_from_cursor(cursor, limit) else: - table = dbt.clients.agate_helper.empty_table() + table = dbt.common.clients.agate_helper.empty_table() return response, table def add_standard_query(self, sql: str, **kwargs) -> Tuple[Connection, Any]: diff --git a/dbt/adapters/snowflake/impl.py b/dbt/adapters/snowflake/impl.py index 6f71fec1a..40b54b61b 100644 --- a/dbt/adapters/snowflake/impl.py +++ b/dbt/adapters/snowflake/impl.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Mapping, Any, Optional, List, Union, Dict +from typing import Mapping, Any, Optional, List, Union, Dict, FrozenSet, Tuple import agate @@ -15,10 +15,9 @@ from dbt.adapters.snowflake import SnowflakeConnectionManager from dbt.adapters.snowflake import SnowflakeRelation from dbt.adapters.snowflake import SnowflakeColumn -from dbt.contracts.graph.manifest import Manifest -from dbt.contracts.graph.nodes import ConstraintType -from dbt.exceptions import CompilationError, DbtDatabaseError, DbtRuntimeError -from dbt.utils import filter_null_values +from dbt.common.contracts.constraints import ConstraintType +from dbt.common.exceptions import CompilationError, DbtDatabaseError, DbtRuntimeError +from dbt.common.utils import filter_null_values @dataclass @@ -62,11 +61,13 @@ def date_function(cls): return "CURRENT_TIMESTAMP()" @classmethod - def _catalog_filter_table(cls, table: agate.Table, manifest: Manifest) -> agate.Table: + def _catalog_filter_table( + cls, table: agate.Table, used_schemas: FrozenSet[Tuple[str, str]] + ) -> agate.Table: # On snowflake, users can set QUOTED_IDENTIFIERS_IGNORE_CASE, so force # the column names to their lowercased forms. lowered = table.rename(column_names=[c.lower() for c in table.column_names]) - return super()._catalog_filter_table(lowered, manifest) + return super()._catalog_filter_table(lowered, used_schemas) def _make_match_kwargs(self, database, schema, identifier): quoting = self.config.quoting diff --git a/dbt/adapters/snowflake/relation.py b/dbt/adapters/snowflake/relation.py index 9d6182a71..325d23c9b 100644 --- a/dbt/adapters/snowflake/relation.py +++ b/dbt/adapters/snowflake/relation.py @@ -3,8 +3,8 @@ from dbt.adapters.base.relation import BaseRelation from dbt.adapters.relation_configs import RelationConfigChangeAction, RelationResults -from dbt.context.providers import RuntimeConfigObject -from dbt.utils import classproperty +from dbt.adapters.contracts.relation import RelationConfig +from dbt.adapters.utils import classproperty from dbt.adapters.snowflake.relation_configs import ( SnowflakeDynamicTableConfig, @@ -43,12 +43,12 @@ def get_relation_type(cls) -> Type[SnowflakeRelationType]: @classmethod def dynamic_table_config_changeset( - cls, relation_results: RelationResults, runtime_config: RuntimeConfigObject + cls, relation_results: RelationResults, relation_config: RelationConfig ) -> Optional[SnowflakeDynamicTableConfigChangeset]: existing_dynamic_table = SnowflakeDynamicTableConfig.from_relation_results( relation_results ) - new_dynamic_table = SnowflakeDynamicTableConfig.from_model_node(runtime_config.model) + new_dynamic_table = SnowflakeDynamicTableConfig.from_relation_config(relation_config) config_change_collection = SnowflakeDynamicTableConfigChangeset() diff --git a/dbt/adapters/snowflake/relation_configs/base.py b/dbt/adapters/snowflake/relation_configs/base.py index d7f9f121b..7b4367e2d 100644 --- a/dbt/adapters/snowflake/relation_configs/base.py +++ b/dbt/adapters/snowflake/relation_configs/base.py @@ -1,14 +1,13 @@ from dataclasses import dataclass from typing import Any, Dict, Optional - import agate from dbt.adapters.base.relation import Policy from dbt.adapters.relation_configs import ( RelationConfigBase, RelationResults, ) -from dbt.contracts.graph.nodes import ModelNode -from dbt.contracts.relation import ComponentName + +from dbt.adapters.contracts.relation import ComponentName, RelationConfig from dbt.adapters.snowflake.relation_configs.policies import ( SnowflakeIncludePolicy, @@ -31,22 +30,22 @@ def quote_policy(cls) -> Policy: return SnowflakeQuotePolicy() @classmethod - def from_model_node(cls, model_node: ModelNode): - relation_config = cls.parse_model_node(model_node) - relation = cls.from_dict(relation_config) + def from_relation_config(cls, relation_config: RelationConfig): + relation_config_dict = cls.parse_relation_config(relation_config) + relation = cls.from_dict(relation_config_dict) return relation @classmethod - def parse_model_node(cls, model_node: ModelNode) -> Dict[str, Any]: + def parse_relation_config(cls, relation_config: RelationConfig) -> Dict: raise NotImplementedError( - "`parse_model_node()` needs to be implemented on this RelationConfigBase instance" + "`parse_relation_config()` needs to be implemented on this RelationConfigBase instance" ) @classmethod def from_relation_results(cls, relation_results: RelationResults): relation_config = cls.parse_relation_results(relation_results) relation = cls.from_dict(relation_config) - return relation + return relation # type: ignore @classmethod def parse_relation_results(cls, relation_results: RelationResults) -> Dict[str, Any]: diff --git a/dbt/adapters/snowflake/relation_configs/dynamic_table.py b/dbt/adapters/snowflake/relation_configs/dynamic_table.py index 6caa7211e..cc1b9112d 100644 --- a/dbt/adapters/snowflake/relation_configs/dynamic_table.py +++ b/dbt/adapters/snowflake/relation_configs/dynamic_table.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict, Any import agate from dbt.adapters.relation_configs import RelationConfigChange, RelationResults -from dbt.contracts.graph.nodes import ModelNode -from dbt.contracts.relation import ComponentName +from dbt.adapters.contracts.relation import RelationConfig +from dbt.adapters.contracts.relation import ComponentName from dbt.adapters.snowflake.relation_configs.base import SnowflakeRelationConfigBase @@ -48,20 +48,20 @@ def from_dict(cls, config_dict) -> "SnowflakeDynamicTableConfig": return dynamic_table @classmethod - def parse_model_node(cls, model_node: ModelNode) -> dict: + def parse_relation_config(cls, relation_config: RelationConfig) -> Dict[str, Any]: config_dict = { - "name": model_node.identifier, - "schema_name": model_node.schema, - "database_name": model_node.database, - "query": model_node.compiled_code, - "target_lag": model_node.config.extra.get("target_lag"), - "snowflake_warehouse": model_node.config.extra.get("snowflake_warehouse"), + "name": relation_config.identifier, + "schema_name": relation_config.schema, + "database_name": relation_config.database, + "query": relation_config.compiled_code, # type: ignore + "target_lag": relation_config.config.extra.get("target_lag"), # type: ignore + "snowflake_warehouse": relation_config.config.extra.get("snowflake_warehouse"), # type: ignore } return config_dict @classmethod - def parse_relation_results(cls, relation_results: RelationResults) -> dict: + def parse_relation_results(cls, relation_results: RelationResults) -> Dict: dynamic_table: agate.Row = relation_results["dynamic_table"].rows[0] config_dict = { diff --git a/dbt/adapters/snowflake/relation_configs/policies.py b/dbt/adapters/snowflake/relation_configs/policies.py index 31f8e0bc8..f0872f992 100644 --- a/dbt/adapters/snowflake/relation_configs/policies.py +++ b/dbt/adapters/snowflake/relation_configs/policies.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from dbt.adapters.base.relation import Policy -from dbt.dataclass_schema import StrEnum +from dbt.common.dataclass_schema import StrEnum class SnowflakeRelationType(StrEnum): diff --git a/dbt/include/snowflake/macros/adapters.sql b/dbt/include/snowflake/macros/adapters.sql index b2b496356..157738187 100644 --- a/dbt/include/snowflake/macros/adapters.sql +++ b/dbt/include/snowflake/macros/adapters.sql @@ -73,7 +73,7 @@ {% for _ in range(0, max_iter) %} {%- set paginated_sql -%} - show terse objects in {{ schema_relation }} limit {{ max_results_per_iter }} from '{{ watermark.table_name }}' + show terse objects in {{ schema_relation.database }}.{{ schema_relation.schema }} limit {{ max_results_per_iter }} from '{{ watermark.table_name }}' {%- endset -%} {%- set paginated_result = run_query(paginated_sql) %} @@ -96,7 +96,7 @@ {%- if loop.index == max_iter -%} {%- set msg -%} - dbt will list a maximum of {{ max_total_results }} objects in schema {{ schema_relation }}. + dbt will list a maximum of {{ max_total_results }} objects in schema {{ schema_relation.database }}.{{ schema_relation.schema }}. Your schema exceeds this limit. Please contact support@getdbt.com for troubleshooting tips, or review and reduce the number of objects contained. {%- endset -%} @@ -124,7 +124,7 @@ {%- set max_total_results = max_results_per_iter * max_iter -%} {%- set sql -%} - show terse objects in {{ schema_relation }} limit {{ max_results_per_iter }} + show terse objects in {{ schema_relation.database }}.{{ schema_relation.schema }} limit {{ max_results_per_iter }} {%- endset -%} {%- set result = run_query(sql) -%} diff --git a/dbt/include/snowflake/macros/materializations/dynamic_table.sql b/dbt/include/snowflake/macros/materializations/dynamic_table.sql index 23dedb65e..f491ef3bd 100644 --- a/dbt/include/snowflake/macros/materializations/dynamic_table.sql +++ b/dbt/include/snowflake/macros/materializations/dynamic_table.sql @@ -92,6 +92,6 @@ {% macro snowflake__get_dynamic_table_configuration_changes(existing_relation, new_config) -%} {% set _existing_dynamic_table = snowflake__describe_dynamic_table(existing_relation) %} - {% set _configuration_changes = existing_relation.dynamic_table_config_changeset(_existing_dynamic_table, new_config) %} + {% set _configuration_changes = existing_relation.dynamic_table_config_changeset(_existing_dynamic_table, new_config.model) %} {% do return(_configuration_changes) %} {%- endmacro %} diff --git a/tests/functional/adapter/dynamic_table_tests/test_dynamic_tables_changes.py b/tests/functional/adapter/dynamic_table_tests/test_dynamic_tables_changes.py index a88adf398..17d14ebd8 100644 --- a/tests/functional/adapter/dynamic_table_tests/test_dynamic_tables_changes.py +++ b/tests/functional/adapter/dynamic_table_tests/test_dynamic_tables_changes.py @@ -2,7 +2,7 @@ import pytest -from dbt.contracts.graph.model_config import OnConfigurationChangeOption +from dbt.common.contracts.config.materialization import OnConfigurationChangeOption from dbt.tests.util import ( assert_message_in_logs, get_model_file, diff --git a/tests/functional/adapter/test_list_relations_without_caching.py b/tests/functional/adapter/test_list_relations_without_caching.py index f6dfc2144..b126984a3 100644 --- a/tests/functional/adapter/test_list_relations_without_caching.py +++ b/tests/functional/adapter/test_list_relations_without_caching.py @@ -101,7 +101,7 @@ def test__snowflake__list_relations_without_caching_termination(self, project): schemas = project.created_schemas for schema in schemas: - schema_relation = f"{database}.{schema}" + schema_relation = {"database": database, "schema": schema} kwargs = {"schema_relation": schema_relation} _, log_output = run_dbt_and_capture( [ @@ -149,7 +149,7 @@ def test__snowflake__list_relations_without_caching(self, project): schemas = project.created_schemas for schema in schemas: - schema_relation = f"{database}.{schema}" + schema_relation = {"database": database, "schema": schema} kwargs = {"schema_relation": schema_relation} _, log_output = run_dbt_and_capture( [ @@ -161,7 +161,6 @@ def test__snowflake__list_relations_without_caching(self, project): str(kwargs), ] ) - parsed_logs = parse_json_logs(log_output) n_relations = find_result_in_parsed_logs(parsed_logs, "n_relations") @@ -178,7 +177,8 @@ def test__snowflake__list_relations_without_caching_raise_error(self, project): schemas = project.created_schemas for schema in schemas: - schema_relation = f"{database}.{schema}" + schema_relation = {"database": database, "schema": schema} + kwargs = {"schema_relation": schema_relation} _, log_output = run_dbt_and_capture( [ @@ -191,7 +191,6 @@ def test__snowflake__list_relations_without_caching_raise_error(self, project): ], expect_pass=False, ) - parsed_logs = parse_json_logs(log_output) traceback = find_exc_info_in_parsed_logs(parsed_logs, "Traceback") assert "dbt will list a maximum of 99 objects in schema " in traceback diff --git a/tests/unit/test_connections.py b/tests/unit/test_connections.py index 87b0cf4c2..555091c57 100644 --- a/tests/unit/test_connections.py +++ b/tests/unit/test_connections.py @@ -2,13 +2,13 @@ from importlib import reload from unittest.mock import Mock import dbt.adapters.snowflake.connections as connections -import dbt.events +import dbt.adapters.events.logging def test_connections_sets_logs_in_response_to_env_var(monkeypatch): """Test that setting the DBT_SNOWFLAKE_CONNECTOR_DEBUG_LOGGING environment variable happens on import""" log_mock = Mock() - monkeypatch.setattr(dbt.events, "AdapterLogger", Mock(return_value=log_mock)) + monkeypatch.setattr(dbt.adapters.events.logging, "AdapterLogger", Mock(return_value=log_mock)) monkeypatch.setattr(os, "environ", {"DBT_SNOWFLAKE_CONNECTOR_DEBUG_LOGGING": "true"}) reload(connections) @@ -18,7 +18,7 @@ def test_connections_sets_logs_in_response_to_env_var(monkeypatch): def test_connections_does_not_set_logs_in_response_to_env_var(monkeypatch): log_mock = Mock() - monkeypatch.setattr(dbt.events, "AdapterLogger", Mock(return_value=log_mock)) + monkeypatch.setattr(dbt.adapters.events.logging, "AdapterLogger", Mock(return_value=log_mock)) reload(connections) assert log_mock.debug.call_count == 0 diff --git a/tests/unit/test_snowflake_adapter.py b/tests/unit/test_snowflake_adapter.py index 85bbb3859..19f2165d1 100644 --- a/tests/unit/test_snowflake_adapter.py +++ b/tests/unit/test_snowflake_adapter.py @@ -1,6 +1,7 @@ import agate import re import unittest +from multiprocessing import get_context from contextlib import contextmanager from unittest import mock @@ -8,10 +9,11 @@ from dbt.adapters.snowflake import Plugin as SnowflakePlugin from dbt.adapters.snowflake.column import SnowflakeColumn from dbt.adapters.snowflake.connections import SnowflakeCredentials -from dbt.adapters.base.query_headers import MacroQueryStringSetter from dbt.contracts.files import FileHash +from dbt.context.manifest import generate_query_header_context +from dbt.context.providers import generate_runtime_macro_context from dbt.contracts.graph.manifest import ManifestStateCheck -from dbt.clients import agate_helper +from dbt.common.clients import agate_helper from snowflake import connector as snowflake_connector from .utils import ( @@ -78,10 +80,11 @@ def _mock_state_check(self): self.mock_state_check.side_effect = _mock_state_check self.snowflake.return_value = self.handle - self.adapter = SnowflakeAdapter(self.config) - self.adapter._macro_manifest_lazy = load_internal_manifest_macros(self.config) - self.adapter.connections.query_header = MacroQueryStringSetter( - self.config, self.adapter._macro_manifest_lazy + self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) + self.adapter.set_macro_resolver(load_internal_manifest_macros(self.config)) + self.adapter.set_macro_context_generator(generate_runtime_macro_context) + self.adapter.connections.set_query_header( + generate_query_header_context(self.config, self.adapter.get_macro_resolver()) ) self.qh_patch = mock.patch.object(self.adapter.connections.query_header, "add") @@ -294,7 +297,7 @@ def test_client_session_keep_alive_false_by_default(self): def test_client_session_keep_alive_true(self): self.config.credentials = self.config.credentials.replace(client_session_keep_alive=True) - self.adapter = SnowflakeAdapter(self.config) + self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config") self.snowflake.assert_not_called() @@ -321,7 +324,7 @@ def test_client_session_keep_alive_true(self): def test_client_has_query_tag(self): self.config.credentials = self.config.credentials.replace(query_tag="test_query_tag") - self.adapter = SnowflakeAdapter(self.config) + self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config") self.snowflake.assert_not_called() @@ -355,7 +358,7 @@ def test_user_pass_authentication(self): self.config.credentials = self.config.credentials.replace( password="test_password", ) - self.adapter = SnowflakeAdapter(self.config) + self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config") self.snowflake.assert_not_called() @@ -386,7 +389,7 @@ def test_authenticator_user_pass_authentication(self): password="test_password", authenticator="test_sso_url", ) - self.adapter = SnowflakeAdapter(self.config) + self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config") self.snowflake.assert_not_called() @@ -417,7 +420,7 @@ def test_authenticator_user_pass_authentication(self): def test_authenticator_externalbrowser_authentication(self): self.config.credentials = self.config.credentials.replace(authenticator="externalbrowser") - self.adapter = SnowflakeAdapter(self.config) + self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config") self.snowflake.assert_not_called() @@ -450,7 +453,7 @@ def test_authenticator_oauth_authentication(self): authenticator="oauth", token="my-oauth-token", ) - self.adapter = SnowflakeAdapter(self.config) + self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config") self.snowflake.assert_not_called() @@ -488,7 +491,7 @@ def test_authenticator_private_key_authentication(self, mock_get_private_key): private_key_passphrase="p@ssphr@se", ) - self.adapter = SnowflakeAdapter(self.config) + self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config") self.snowflake.assert_not_called() @@ -522,7 +525,7 @@ def test_authenticator_private_key_authentication_no_passphrase(self, mock_get_p private_key_passphrase=None, ) - self.adapter = SnowflakeAdapter(self.config) + self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config") self.snowflake.assert_not_called() @@ -551,7 +554,7 @@ def test_query_tag(self): self.config.credentials = self.config.credentials.replace( password="test_password", query_tag="test_query_tag" ) - self.adapter = SnowflakeAdapter(self.config) + self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config") self.snowflake.assert_not_called() @@ -581,7 +584,7 @@ def test_reuse_connections_with_keep_alive(self): self.config.credentials = self.config.credentials.replace( reuse_connections=True, client_session_keep_alive=True ) - self.adapter = SnowflakeAdapter(self.config) + self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config") self.snowflake.assert_not_called() @@ -615,7 +618,7 @@ def test_authenticator_private_key_string_authentication(self, mock_get_private_ private_key_passphrase="p@ssphr@se", ) - self.adapter = SnowflakeAdapter(self.config) + self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config") self.snowflake.assert_not_called() @@ -651,7 +654,7 @@ def test_authenticator_private_key_string_authentication_no_passphrase( private_key_passphrase=None, ) - self.adapter = SnowflakeAdapter(self.config) + self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config") self.snowflake.assert_not_called() diff --git a/tests/unit/utils.py b/tests/unit/utils.py index 991f8d524..042e24bf2 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -9,7 +9,7 @@ import agate import pytest -from dbt.dataclass_schema import ValidationError +from dbt.common.dataclass_schema import ValidationError from dbt.config.project import PartialProject @@ -230,7 +230,7 @@ def assert_fails_validation(dct, cls): class TestAdapterConversions(TestCase): def _get_tester_for(self, column_type): - from dbt.clients import agate_helper + from dbt.common.clients import agate_helper if column_type is agate.TimeDelta: # dbt never makes this! return agate.TimeDelta()