Skip to content

Commit

Permalink
feat(ingest): handle mssql casing issues in lineage (datahub-project#…
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored Nov 22, 2024
1 parent b5d5db3 commit 1bfd4ee
Show file tree
Hide file tree
Showing 12 changed files with 1,947 additions and 795 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,19 @@
# See more below:
# https://documentation.sas.com/doc/en/pgmsascdc/9.4_3.5/acreldb/n0ejgx4895bofnn14rlguktfx5r3.htm
"teradata",
# For SQL server, the default collation rules mean that all identifiers (schema, table, column names)
# are case preserving but case insensitive.
"mssql",
}
DIALECTS_WITH_DEFAULT_UPPERCASE_COLS = {
# In some dialects, column identifiers are effectively case insensitive
# because they are automatically converted to uppercase. Most other systems
# automatically lowercase unquoted identifiers.
"snowflake",
}
assert DIALECTS_WITH_DEFAULT_UPPERCASE_COLS.issubset(
DIALECTS_WITH_CASE_INSENSITIVE_COLS
)


class QueryType(enum.Enum):
Expand Down
56 changes: 49 additions & 7 deletions metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import traceback
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union

import pydantic.dataclasses
import sqlglot
Expand Down Expand Up @@ -873,6 +873,49 @@ def _translate_internal_column_lineage(
)


_StrOrNone = TypeVar("_StrOrNone", str, Optional[str])


def _normalize_db_or_schema(
    db_or_schema: _StrOrNone,
    dialect: sqlglot.Dialect,
) -> _StrOrNone:
    """Case-fold a database/schema identifier to match sqlglot's dialect rules.

    Snowflake identifiers are normalized to uppercase and MSSQL identifiers
    to lowercase; all other dialects (and ``None``) pass through unchanged.
    """
    if db_or_schema is None:
        return None

    # (dialect name, normalizer) pairs, checked in order. Snowflake folds
    # unquoted identifiers to uppercase to match sqlglot's behavior; MSSQL's
    # default collation is case insensitive, so we fold to lowercase.
    for dialect_name, normalizer in (("snowflake", str.upper), ("mssql", str.lower)):
        if is_dialect_instance(dialect, dialect_name):
            return normalizer(db_or_schema)

    return db_or_schema


def _simplify_select_into(statement: sqlglot.exp.Expression) -> sqlglot.exp.Expression:
    """Rewrite a SELECT ... INTO statement as an equivalent CTAS.

    Expressions that are not SELECT INTO are returned unchanged.
    """
    if not isinstance(statement, sqlglot.exp.Select):
        return statement
    into_clause = statement.args.get("into")
    if not into_clause:
        return statement

    # Detach the INTO clause from the SELECT, then wrap what remains in
    # CREATE TABLE <target> AS <select> so downstream lineage extraction
    # sees a standard CTAS shape.
    target_table = into_clause.pop().this
    return sqlglot.exp.Create(
        this=target_table,
        kind="TABLE",
        expression=statement,
    )


def _sqlglot_lineage_inner(
sql: sqlglot.exp.ExpOrStr,
schema_resolver: SchemaResolverInterface,
Expand All @@ -885,12 +928,9 @@ def _sqlglot_lineage_inner(
else:
dialect = get_dialect(default_dialect)

if is_dialect_instance(dialect, "snowflake"):
# in snowflake, table identifiers must be uppercased to match sqlglot's behavior.
if default_db:
default_db = default_db.upper()
if default_schema:
default_schema = default_schema.upper()
default_db = _normalize_db_or_schema(default_db, dialect)
default_schema = _normalize_db_or_schema(default_schema, dialect)

if is_dialect_instance(dialect, "redshift") and not default_schema:
# On Redshift, there's no "USE SCHEMA <schema>" command. The default schema
# is public, and "current schema" is the one at the front of the search path.
Expand Down Expand Up @@ -918,6 +958,8 @@ def _sqlglot_lineage_inner(
# original_statement.sql(pretty=True, dialect=dialect),
# )

statement = _simplify_select_into(statement)

# Make sure the tables are resolved with the default db / schema.
# This only works for Unionable statements. For other types of statements,
# we have to do it manually afterwards, but that's slightly lower accuracy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def is_dialect_instance(
else:
platforms = list(platforms)

dialects = [sqlglot.Dialect.get_or_raise(platform) for platform in platforms]
dialects = [get_dialect(platform) for platform in platforms]

if any(isinstance(dialect, dialect_class.__class__) for dialect_class in dialects):
return True
Expand Down
26 changes: 26 additions & 0 deletions metadata-ingestion/tests/integration/powerbi/golden_test_cll.json
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,32 @@
"dataset": "urn:li:dataset:(urn:li:dataPlatform:mssql,commopsdb.dbo.v_ps_cd_retention,PROD)",
"type": "TRANSFORMED"
}
],
"fineGrainedLineages": [
{
"upstreamType": "FIELD_SET",
"upstreams": [
"urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:mssql,commopsdb.dbo.v_ps_cd_retention,PROD),client_director)",
"urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:mssql,commopsdb.dbo.v_ps_cd_retention,PROD),month_wid)"
],
"downstreamType": "FIELD",
"downstreams": [
"urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:powerbi,hr_pbi_test.ms_sql_native_table,DEV),cd_agent_key)"
],
"confidenceScore": 1.0
},
{
"upstreamType": "FIELD_SET",
"upstreams": [
"urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:mssql,commopsdb.dbo.v_ps_cd_retention,PROD),client_manager_closing_month)",
"urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:mssql,commopsdb.dbo.v_ps_cd_retention,PROD),month_wid)"
],
"downstreamType": "FIELD",
"downstreams": [
"urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:powerbi,hr_pbi_test.ms_sql_native_table,DEV),agent_key)"
],
"confidenceScore": 1.0
}
]
}
},
Expand Down
Loading

0 comments on commit 1bfd4ee

Please sign in to comment.