Skip to content

Commit

Permalink
Feature/decouple adapters from core (#865)
Browse files Browse the repository at this point in the history
* update RELEASE_BRANCH env

* create draft pr to track work

* initial migration work

* get tests_connections unit test passing

* update model_node refs to relation_config, update _macro_manifest_lazy ref in unit test

* update new_config ping in dynamic_table configuration_changes macro

* minor changes to imports

* minor changes, revert dev-requirements pointer

* revert runtime_config call

* fix list_relations_without_caching

* change up test input for schema_relation

* update snowflake__get_paginated_relations_array macro schema_relation call

* add changelog

---------

Co-authored-by: Colin <[email protected]>
  • Loading branch information
McKnight-42 and colin-rogers-dbt authored Jan 17, 2024
1 parent 0374b4e commit a43f86b
Show file tree
Hide file tree
Showing 15 changed files with 86 additions and 78 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240109-165520.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Update base adapter references as part of decoupling migration
time: 2024-01-09T16:55:20.859657-06:00
custom:
Author: McKnight-42
Issue: "882"
2 changes: 1 addition & 1 deletion dbt/adapters/snowflake/column.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass

from dbt.adapters.base.column import Column
from dbt.exceptions import DbtRuntimeError
from dbt.common.exceptions import DbtRuntimeError


@dataclass
Expand Down
24 changes: 12 additions & 12 deletions dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Optional, Tuple, Union, Any, List

import agate
import dbt.clients.agate_helper
import dbt.common.clients.agate_helper

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
Expand All @@ -32,20 +32,20 @@
BindUploadError,
)

from dbt.exceptions import (
from dbt.common.exceptions import (
DbtInternalError,
DbtRuntimeError,
FailedToConnectError,
DbtDatabaseError,
DbtProfileError,
DbtConfigError,
)
from dbt.common.exceptions import DbtDatabaseError
from dbt.adapters.base import Credentials # type: ignore
from dbt.contracts.connection import AdapterResponse, Connection
from dbt.adapters.exceptions.connection import FailedToConnectError
from dbt.adapters.contracts.connection import AdapterResponse, Connection
from dbt.adapters.sql import SQLConnectionManager # type: ignore
from dbt.events import AdapterLogger # type: ignore
from dbt.events.functions import warn_or_error
from dbt.events.types import AdapterEventWarning
from dbt.ui import line_wrap_message, warning_tag
from dbt.adapters.events.logging import AdapterLogger # type: ignore
from dbt.common.events.functions import warn_or_error
from dbt.adapters.events.types import AdapterEventWarning
from dbt.common.ui import line_wrap_message, warning_tag


logger = AdapterLogger("Snowflake")
Expand Down Expand Up @@ -247,7 +247,7 @@ def _get_access_token(self) -> str:
def _get_private_key(self):
"""Get Snowflake private key by path, from a Base64 encoded DER bytestring or None."""
if self.private_key and self.private_key_path:
raise DbtProfileError("Cannot specify both `private_key` and `private_key_path`")
raise DbtConfigError("Cannot specify both `private_key` and `private_key_path`")

if self.private_key_passphrase:
encoded_passphrase = self.private_key_passphrase.encode()
Expand Down Expand Up @@ -476,7 +476,7 @@ def execute(
if fetch:
table = self.get_result_from_cursor(cursor, limit)
else:
table = dbt.clients.agate_helper.empty_table()
table = dbt.common.clients.agate_helper.empty_table()
return response, table

def add_standard_query(self, sql: str, **kwargs) -> Tuple[Connection, Any]:
Expand Down
15 changes: 8 additions & 7 deletions dbt/adapters/snowflake/impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Mapping, Any, Optional, List, Union, Dict
from typing import Mapping, Any, Optional, List, Union, Dict, FrozenSet, Tuple

import agate

Expand All @@ -15,10 +15,9 @@
from dbt.adapters.snowflake import SnowflakeConnectionManager
from dbt.adapters.snowflake import SnowflakeRelation
from dbt.adapters.snowflake import SnowflakeColumn
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import ConstraintType
from dbt.exceptions import CompilationError, DbtDatabaseError, DbtRuntimeError
from dbt.utils import filter_null_values
from dbt.common.contracts.constraints import ConstraintType
from dbt.common.exceptions import CompilationError, DbtDatabaseError, DbtRuntimeError
from dbt.common.utils import filter_null_values


@dataclass
Expand Down Expand Up @@ -62,11 +61,13 @@ def date_function(cls):
return "CURRENT_TIMESTAMP()"

@classmethod
def _catalog_filter_table(cls, table: agate.Table, manifest: Manifest) -> agate.Table:
def _catalog_filter_table(
cls, table: agate.Table, used_schemas: FrozenSet[Tuple[str, str]]
) -> agate.Table:
# On snowflake, users can set QUOTED_IDENTIFIERS_IGNORE_CASE, so force
# the column names to their lowercased forms.
lowered = table.rename(column_names=[c.lower() for c in table.column_names])
return super()._catalog_filter_table(lowered, manifest)
return super()._catalog_filter_table(lowered, used_schemas)

def _make_match_kwargs(self, database, schema, identifier):
quoting = self.config.quoting
Expand Down
8 changes: 4 additions & 4 deletions dbt/adapters/snowflake/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from dbt.adapters.base.relation import BaseRelation
from dbt.adapters.relation_configs import RelationConfigChangeAction, RelationResults
from dbt.context.providers import RuntimeConfigObject
from dbt.utils import classproperty
from dbt.adapters.contracts.relation import RelationConfig
from dbt.adapters.utils import classproperty

from dbt.adapters.snowflake.relation_configs import (
SnowflakeDynamicTableConfig,
Expand Down Expand Up @@ -43,12 +43,12 @@ def get_relation_type(cls) -> Type[SnowflakeRelationType]:

@classmethod
def dynamic_table_config_changeset(
cls, relation_results: RelationResults, runtime_config: RuntimeConfigObject
cls, relation_results: RelationResults, relation_config: RelationConfig
) -> Optional[SnowflakeDynamicTableConfigChangeset]:
existing_dynamic_table = SnowflakeDynamicTableConfig.from_relation_results(
relation_results
)
new_dynamic_table = SnowflakeDynamicTableConfig.from_model_node(runtime_config.model)
new_dynamic_table = SnowflakeDynamicTableConfig.from_relation_config(relation_config)

config_change_collection = SnowflakeDynamicTableConfigChangeset()

Expand Down
17 changes: 8 additions & 9 deletions dbt/adapters/snowflake/relation_configs/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional

import agate
from dbt.adapters.base.relation import Policy
from dbt.adapters.relation_configs import (
RelationConfigBase,
RelationResults,
)
from dbt.contracts.graph.nodes import ModelNode
from dbt.contracts.relation import ComponentName

from dbt.adapters.contracts.relation import ComponentName, RelationConfig

from dbt.adapters.snowflake.relation_configs.policies import (
SnowflakeIncludePolicy,
Expand All @@ -31,22 +30,22 @@ def quote_policy(cls) -> Policy:
return SnowflakeQuotePolicy()

@classmethod
def from_model_node(cls, model_node: ModelNode):
relation_config = cls.parse_model_node(model_node)
relation = cls.from_dict(relation_config)
def from_relation_config(cls, relation_config: RelationConfig):
relation_config_dict = cls.parse_relation_config(relation_config)
relation = cls.from_dict(relation_config_dict)
return relation

@classmethod
def parse_model_node(cls, model_node: ModelNode) -> Dict[str, Any]:
def parse_relation_config(cls, relation_config: RelationConfig) -> Dict:
raise NotImplementedError(
"`parse_model_node()` needs to be implemented on this RelationConfigBase instance"
"`parse_relation_config()` needs to be implemented on this RelationConfigBase instance"
)

@classmethod
def from_relation_results(cls, relation_results: RelationResults):
relation_config = cls.parse_relation_results(relation_results)
relation = cls.from_dict(relation_config)
return relation
return relation # type: ignore

@classmethod
def parse_relation_results(cls, relation_results: RelationResults) -> Dict[str, Any]:
Expand Down
22 changes: 11 additions & 11 deletions dbt/adapters/snowflake/relation_configs/dynamic_table.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Dict, Any

import agate
from dbt.adapters.relation_configs import RelationConfigChange, RelationResults
from dbt.contracts.graph.nodes import ModelNode
from dbt.contracts.relation import ComponentName
from dbt.adapters.contracts.relation import RelationConfig
from dbt.adapters.contracts.relation import ComponentName

from dbt.adapters.snowflake.relation_configs.base import SnowflakeRelationConfigBase

Expand Down Expand Up @@ -48,20 +48,20 @@ def from_dict(cls, config_dict) -> "SnowflakeDynamicTableConfig":
return dynamic_table

@classmethod
def parse_model_node(cls, model_node: ModelNode) -> dict:
def parse_relation_config(cls, relation_config: RelationConfig) -> Dict[str, Any]:
config_dict = {
"name": model_node.identifier,
"schema_name": model_node.schema,
"database_name": model_node.database,
"query": model_node.compiled_code,
"target_lag": model_node.config.extra.get("target_lag"),
"snowflake_warehouse": model_node.config.extra.get("snowflake_warehouse"),
"name": relation_config.identifier,
"schema_name": relation_config.schema,
"database_name": relation_config.database,
"query": relation_config.compiled_code, # type: ignore
"target_lag": relation_config.config.extra.get("target_lag"), # type: ignore
"snowflake_warehouse": relation_config.config.extra.get("snowflake_warehouse"), # type: ignore
}

return config_dict

@classmethod
def parse_relation_results(cls, relation_results: RelationResults) -> dict:
def parse_relation_results(cls, relation_results: RelationResults) -> Dict:
dynamic_table: agate.Row = relation_results["dynamic_table"].rows[0]

config_dict = {
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/snowflake/relation_configs/policies.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass

from dbt.adapters.base.relation import Policy
from dbt.dataclass_schema import StrEnum
from dbt.common.dataclass_schema import StrEnum


class SnowflakeRelationType(StrEnum):
Expand Down
6 changes: 3 additions & 3 deletions dbt/include/snowflake/macros/adapters.sql
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
{% for _ in range(0, max_iter) %}

{%- set paginated_sql -%}
show terse objects in {{ schema_relation }} limit {{ max_results_per_iter }} from '{{ watermark.table_name }}'
show terse objects in {{ schema_relation.database }}.{{ schema_relation.schema }} limit {{ max_results_per_iter }} from '{{ watermark.table_name }}'
{%- endset -%}

{%- set paginated_result = run_query(paginated_sql) %}
Expand All @@ -96,7 +96,7 @@

{%- if loop.index == max_iter -%}
{%- set msg -%}
dbt will list a maximum of {{ max_total_results }} objects in schema {{ schema_relation }}.
dbt will list a maximum of {{ max_total_results }} objects in schema {{ schema_relation.database }}.{{ schema_relation.schema }}.
Your schema exceeds this limit. Please contact support@getdbt.com for troubleshooting tips,
or review and reduce the number of objects contained.
{%- endset -%}
Expand Down Expand Up @@ -124,7 +124,7 @@
{%- set max_total_results = max_results_per_iter * max_iter -%}

{%- set sql -%}
show terse objects in {{ schema_relation }} limit {{ max_results_per_iter }}
show terse objects in {{ schema_relation.database }}.{{ schema_relation.schema }} limit {{ max_results_per_iter }}
{%- endset -%}

{%- set result = run_query(sql) -%}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,6 @@

{% macro snowflake__get_dynamic_table_configuration_changes(existing_relation, new_config) -%}
{% set _existing_dynamic_table = snowflake__describe_dynamic_table(existing_relation) %}
{% set _configuration_changes = existing_relation.dynamic_table_config_changeset(_existing_dynamic_table, new_config) %}
{% set _configuration_changes = existing_relation.dynamic_table_config_changeset(_existing_dynamic_table, new_config.model) %}
{% do return(_configuration_changes) %}
{%- endmacro %}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from dbt.contracts.graph.model_config import OnConfigurationChangeOption
from dbt.common.contracts.config.materialization import OnConfigurationChangeOption
from dbt.tests.util import (
assert_message_in_logs,
get_model_file,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test__snowflake__list_relations_without_caching_termination(self, project):
schemas = project.created_schemas

for schema in schemas:
schema_relation = f"{database}.{schema}"
schema_relation = {"database": database, "schema": schema}
kwargs = {"schema_relation": schema_relation}
_, log_output = run_dbt_and_capture(
[
Expand Down Expand Up @@ -149,7 +149,7 @@ def test__snowflake__list_relations_without_caching(self, project):
schemas = project.created_schemas

for schema in schemas:
schema_relation = f"{database}.{schema}"
schema_relation = {"database": database, "schema": schema}
kwargs = {"schema_relation": schema_relation}
_, log_output = run_dbt_and_capture(
[
Expand All @@ -161,7 +161,6 @@ def test__snowflake__list_relations_without_caching(self, project):
str(kwargs),
]
)

parsed_logs = parse_json_logs(log_output)
n_relations = find_result_in_parsed_logs(parsed_logs, "n_relations")

Expand All @@ -178,7 +177,8 @@ def test__snowflake__list_relations_without_caching_raise_error(self, project):
schemas = project.created_schemas

for schema in schemas:
schema_relation = f"{database}.{schema}"
schema_relation = {"database": database, "schema": schema}

kwargs = {"schema_relation": schema_relation}
_, log_output = run_dbt_and_capture(
[
Expand All @@ -191,7 +191,6 @@ def test__snowflake__list_relations_without_caching_raise_error(self, project):
],
expect_pass=False,
)

parsed_logs = parse_json_logs(log_output)
traceback = find_exc_info_in_parsed_logs(parsed_logs, "Traceback")
assert "dbt will list a maximum of 99 objects in schema " in traceback
6 changes: 3 additions & 3 deletions tests/unit/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from importlib import reload
from unittest.mock import Mock
import dbt.adapters.snowflake.connections as connections
import dbt.events
import dbt.adapters.events.logging


def test_connections_sets_logs_in_response_to_env_var(monkeypatch):
"""Test that setting the DBT_SNOWFLAKE_CONNECTOR_DEBUG_LOGGING environment variable happens on import"""
log_mock = Mock()
monkeypatch.setattr(dbt.events, "AdapterLogger", Mock(return_value=log_mock))
monkeypatch.setattr(dbt.adapters.events.logging, "AdapterLogger", Mock(return_value=log_mock))
monkeypatch.setattr(os, "environ", {"DBT_SNOWFLAKE_CONNECTOR_DEBUG_LOGGING": "true"})
reload(connections)

Expand All @@ -18,7 +18,7 @@ def test_connections_sets_logs_in_response_to_env_var(monkeypatch):

def test_connections_does_not_set_logs_in_response_to_env_var(monkeypatch):
log_mock = Mock()
monkeypatch.setattr(dbt.events, "AdapterLogger", Mock(return_value=log_mock))
monkeypatch.setattr(dbt.adapters.events.logging, "AdapterLogger", Mock(return_value=log_mock))
reload(connections)

assert log_mock.debug.call_count == 0
Expand Down
Loading

0 comments on commit a43f86b

Please sign in to comment.