Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: snowflake hints #2143

Draft
wants to merge 1 commit into
base: devel
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions dlt/destinations/impl/snowflake/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,25 @@ class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration)
query_tag: Optional[str] = None
"""A tag with placeholders to tag sessions executing jobs"""

# TODO: decide name - create_indexes vs create_constraints (create_indexes used in other destinations)
create_indexes: bool = False
"""Whether UNIQUE or PRIMARY KEY constrains should be created"""

def __init__(
    self,
    *,
    credentials: Optional[SnowflakeCredentials] = None,
    create_indexes: bool = False,
    destination_name: Optional[str] = None,
    environment: Optional[str] = None,
) -> None:
    """Initialize the Snowflake client configuration.

    Args:
        credentials: Snowflake connection credentials.
        create_indexes: When True, UNIQUE / PRIMARY KEY column hints are
            rendered as column constraints in generated DDL.
        destination_name: Name identifying this destination instance.
        environment: Environment this configuration belongs to.
    """
    super().__init__(
        credentials=credentials,
        destination_name=destination_name,
        environment=environment,
    )
    # create_indexes is not accepted by the base initializer, so assign it here
    self.create_indexes = create_indexes

def fingerprint(self) -> str:
"""Returns a fingerprint of host part of a connection string"""
if self.credentials and self.credentials.host:
Expand Down
16 changes: 12 additions & 4 deletions dlt/destinations/impl/snowflake/snowflake.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Sequence, List
from typing import Optional, Sequence, List, Dict
from urllib.parse import urlparse, urlunparse

from dlt.common.data_writers.configuration import CsvFormatConfiguration
Expand All @@ -17,7 +17,7 @@
)
from dlt.common.storages.configuration import FilesystemConfiguration, ensure_canonical_az_url
from dlt.common.storages.file_storage import FileStorage
from dlt.common.schema import TColumnSchema, Schema
from dlt.common.schema import TColumnSchema, Schema, TColumnHint
from dlt.common.schema.typing import TColumnType

from dlt.common.storages.fsspec_filesystem import AZURE_BLOB_STORAGE_PROTOCOLS, S3_PROTOCOLS
Expand All @@ -29,6 +29,8 @@
from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient
from dlt.destinations.job_impl import ReferenceFollowupJobRequest

# maps dlt column hints to the Snowflake constraint keyword emitted in column DDL;
# only used when the configuration enables create_indexes (see SnowflakeClient.__init__)
SUPPORTED_HINTS: Dict[TColumnHint, str] = {"unique": "UNIQUE", "primary_key": "PRIMARY KEY"}


class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs):
def __init__(
Expand Down Expand Up @@ -238,6 +240,7 @@ def __init__(
self.config: SnowflakeClientConfiguration = config
self.sql_client: SnowflakeSqlClient = sql_client # type: ignore
self.type_mapper = self.capabilities.get_type_mapper()
self.active_hints = SUPPORTED_HINTS if self.config.create_indexes else {}

def create_load_job(
self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False
Expand Down Expand Up @@ -288,9 +291,14 @@ def _from_db_type(
return self.type_mapper.from_destination_type(bq_t, precision, scale)

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
name = self.sql_client.escape_column_name(c["name"])
hints_str = " ".join(
self.active_hints.get(h, "")
for h in self.active_hints.keys()
if c.get(h, False) is True
)
column_name = self.sql_client.escape_column_name(c["name"])
return (
f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}"
f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
)

def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool:
Expand Down
36 changes: 36 additions & 0 deletions tests/load/snowflake/test_snowflake_table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,42 @@ def test_create_table(snowflake_client: SnowflakeClient) -> None:
assert '"COL10" DATE NOT NULL' in sql


def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None:
    """Column hints render as constraints only when create_indexes is enabled."""
    mod_update = deepcopy(TABLE_UPDATE)

    # supported hints on the first two columns; "sort" and "parent_key" have no
    # Snowflake constraint mapping and must be ignored by the DDL generator
    mod_update[0]["primary_key"] = True
    mod_update[0]["sort"] = True
    mod_update[1]["unique"] = True
    mod_update[4]["parent_key"] = True

    # default configuration: create_indexes is off, so no constraint keywords appear
    sql = ";".join(snowflake_client._get_table_update_sql("event_test_table", mod_update, False))

    assert sql.strip().startswith("CREATE TABLE")
    assert "EVENT_TEST_TABLE" in sql
    for expected_column in (
        '"COL1" NUMBER(19,0) NOT NULL',
        '"COL2" FLOAT NOT NULL',
        '"COL3" BOOLEAN NOT NULL',
        '"COL4" TIMESTAMP_TZ NOT NULL',
        '"COL5" VARCHAR',
        '"COL6" NUMBER(38,9) NOT NULL',
        '"COL7" BINARY',
        '"COL8" NUMBER(38,0)',
        '"COL9" VARIANT NOT NULL',
        '"COL10" DATE NOT NULL',
    ):
        assert expected_column in sql

    # same thing with indexes
    snowflake_client = snowflake().client(
        snowflake_client.schema,
        SnowflakeClientConfiguration(create_indexes=True)._bind_dataset_name(
            dataset_name="test_" + uniq_id()
        ),
    )
    sql = snowflake_client._get_table_update_sql("event_test_table", mod_update, False)[0]
    sqlfluff.parse(sql)
    assert '"COL1" NUMBER(19,0) PRIMARY KEY NOT NULL' in sql
    assert '"COL2" FLOAT UNIQUE NOT NULL' in sql


def test_alter_table(snowflake_client: SnowflakeClient) -> None:
statements = snowflake_client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)
assert len(statements) == 1
Expand Down
Loading