feat(ingest/athena): handle partition fetching errors (datahub-projec…
hsheth2 authored Nov 29, 2024
1 parent a92c6b2 commit a46de1e
Showing 2 changed files with 88 additions and 25 deletions.
68 changes: 46 additions & 22 deletions metadata-ingestion/src/datahub/ingestion/source/sql/athena.py
@@ -26,6 +26,7 @@
platform_name,
support_status,
)
from datahub.ingestion.api.source import StructuredLogLevel
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.s3_util import make_s3_urn
from datahub.ingestion.source.common.subtypes import DatasetContainerSubTypes
@@ -35,6 +36,7 @@
register_custom_type,
)
from datahub.ingestion.source.sql.sql_config import SQLCommonConfig, make_sqlalchemy_uri
from datahub.ingestion.source.sql.sql_report import SQLSourceReport
from datahub.ingestion.source.sql.sql_utils import (
add_table_to_schema_container,
gen_database_container,
@@ -48,6 +50,15 @@
get_schema_fields_for_sqlalchemy_column,
)

try:
from typing_extensions import override
except ImportError:
_F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any])

def override(f: _F, /) -> _F: # noqa: F811
return f
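
With this fallback in place, @override is a runtime no-op: it returns the decorated function unchanged, so behavior is identical with or without typing_extensions; only static type checkers consume the marker. A minimal illustration, using hypothetical classes that are not part of this commit:

# Hypothetical illustration: @override only documents intent for type
# checkers; the fallback above returns the function unchanged at runtime.
class _Base:
    def platform(self) -> str:
        return "base"

class _Child(_Base):
    @override
    def platform(self) -> str:  # a checker flags this if _Base drops platform()
        return "child"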


logger = logging.getLogger(__name__)

assert STRUCT, "required type modules are not available"
@@ -322,12 +333,15 @@ class AthenaSource(SQLAlchemySource):
- Profiling when enabled.
"""

table_partition_cache: Dict[str, Dict[str, Partitionitem]] = {}
config: AthenaConfig
report: SQLSourceReport

def __init__(self, config, ctx):
super().__init__(config, ctx, "athena")
self.cursor: Optional[BaseCursor] = None

self.table_partition_cache: Dict[str, Dict[str, Partitionitem]] = {}

@classmethod
def create(cls, config_dict, ctx):
config = AthenaConfig.parse_obj(config_dict)
@@ -452,41 +466,50 @@ def add_table_to_schema_container(
)

# It seems like the database/schema filter in the connection string does not work, so this works around that
@override
def get_schema_names(self, inspector: Inspector) -> List[str]:
athena_config = typing.cast(AthenaConfig, self.config)
schemas = inspector.get_schema_names()
if athena_config.database:
return [schema for schema in schemas if schema == athena_config.database]
return schemas

# Override to get partitions
@classmethod
def _casted_partition_key(cls, key: str) -> str:
# We need to cast the partition keys to a VARCHAR, since otherwise
# Athena may throw an error during concatenation / comparison.
return f"CAST({key} as VARCHAR)"

@override
def get_partitions(
self, inspector: Inspector, schema: str, table: str
) -> List[str]:
partitions = []

athena_config = typing.cast(AthenaConfig, self.config)

if not athena_config.extract_partitions:
return []
) -> Optional[List[str]]:
if not self.config.extract_partitions:
return None

if not self.cursor:
return []
return None

metadata: AthenaTableMetadata = self.cursor.get_table_metadata(
table_name=table, schema_name=schema
)

if metadata.partition_keys:
for key in metadata.partition_keys:
if key.name:
partitions.append(key.name)

if not partitions:
return []
partitions = []
for key in metadata.partition_keys:
if key.name:
partitions.append(key.name)
if not partitions:
return []

# We create an artificial concatenated partition key to make it easier to query the max partition
part_concat = "|| '-' ||".join(partitions)
with self.report.report_exc(
message="Failed to extract partition details",
context=f"{schema}.{table}",
level=StructuredLogLevel.WARN,
):
# We create an artificial concatenated partition key to make it easier to query the max partition
part_concat = " || '-' || ".join(
self._casted_partition_key(key) for key in partitions
)
max_partition_query = f'select {",".join(partitions)} from "{schema}"."{table}$partitions" where {part_concat} = (select max({part_concat}) from "{schema}"."{table}$partitions")'
ret = self.cursor.execute(max_partition_query)
max_partition: Dict[str, str] = {}
@@ -500,9 +523,8 @@ def get_partitions(
partitions=partitions,
max_partition=max_partition,
)
return partitions

return []
return partitions
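
The error handling added here hinges on report_exc: a context manager on the source report that records an exception as a structured log entry instead of letting it propagate, so a failed max-partition query still leaves the partition key names (collected before the with block) to be returned. A minimal behavioral sketch, not DataHub's actual implementation:

import contextlib
import logging

logger = logging.getLogger(__name__)

@contextlib.contextmanager
def report_exc_sketch(message: str, context: str, level: str = "WARN"):
    # Swallow the exception and record it as a warning, so one table's
    # partition failure does not abort the rest of the ingestion run.
    try:
        yield
    except Exception as exc:
        logger.warning("%s (%s) [%s]: %s", message, context, level, exc)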

# Override to modify the creation of schema fields
def get_schema_fields_for_column(
@@ -551,7 +573,9 @@ def generate_partition_profiler_query(
if partition and partition.max_partition:
max_partition_filters = []
for key, value in partition.max_partition.items():
max_partition_filters.append(f"CAST({key} as VARCHAR) = '{value}'")
max_partition_filters.append(
f"{self._casted_partition_key(key)} = '{value}'"
)
max_partition = str(partition.max_partition)
return (
max_partition,
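For a cached max partition such as {"year": "2023", "month": "12"} (hypothetical values), the filters built above come out as follows; the collapsed remainder of the method presumably joins them into the profiler's WHERE clause:

# Hypothetical illustration of the filters built above.
max_partition = {"year": "2023", "month": "12"}
filters = [f"CAST({key} as VARCHAR) = '{value}'" for key, value in max_partition.items()]
# filters == ["CAST(year as VARCHAR) = '2023'", "CAST(month as VARCHAR) = '12'"]
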
45 changes: 42 additions & 3 deletions metadata-ingestion/tests/unit/test_athena_source.py
@@ -93,7 +93,8 @@ def test_athena_get_table_properties():
"CreateTime": datetime.now(),
"LastAccessTime": datetime.now(),
"PartitionKeys": [
{"Name": "testKey", "Type": "string", "Comment": "testComment"}
{"Name": "year", "Type": "string", "Comment": "testComment"},
{"Name": "month", "Type": "string", "Comment": "testComment"},
],
"Parameters": {
"comment": "testComment",
@@ -112,8 +113,18 @@
response=table_metadata
)

# Mock partition query results
mock_cursor.execute.return_value.description = [
["year"],
["month"],
]
mock_cursor.execute.return_value.__iter__.return_value = [["2023", "12"]]

ctx = PipelineContext(run_id="test")
source = AthenaSource(config=config, ctx=ctx)
source.cursor = mock_cursor

# Test table properties
description, custom_properties, location = source.get_table_properties(
inspector=mock_inspector, table=table, schema=schema
)
@@ -124,13 +135,35 @@
"last_access_time": "2020-04-14 07:00:00",
"location": "s3://testLocation",
"outputformat": "testOutputFormat",
"partition_keys": '[{"name": "testKey", "type": "string", "comment": "testComment"}]',
"partition_keys": '[{"name": "year", "type": "string", "comment": "testComment"}, {"name": "month", "type": "string", "comment": "testComment"}]',
"serde.serialization.lib": "testSerde",
"table_type": "testType",
}

assert location == make_s3_urn("s3://testLocation", "PROD")

# Test partition functionality
partitions = source.get_partitions(
inspector=mock_inspector, schema=schema, table=table
)
assert partitions == ["year", "month"]

# Verify the correct SQL query was generated for partitions
expected_query = """\
select year,month from "test_schema"."test_table$partitions" \
where CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR) = \
(select max(CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR)) \
from "test_schema"."test_table$partitions")"""
mock_cursor.execute.assert_called_once()
actual_query = mock_cursor.execute.call_args[0][0]
assert actual_query == expected_query

# Verify partition cache was populated correctly
assert source.table_partition_cache[schema][table].partitions == partitions
assert source.table_partition_cache[schema][table].max_partition == {
"year": "2023",
"month": "12",
}
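
The mock shapes above (description column names plus a single result row) mirror how the collapsed hunk of get_partitions appears to assemble max_partition; a hedged, self-contained reconstruction of that pairing:

# Hedged reconstruction of the collapsed hunk: pair each cursor.description
# column name with the corresponding value in the returned row.
description = [["year"], ["month"]]  # column metadata, as mocked above
rows = [["2023", "12"]]              # single max-partition row, as mocked above
max_partition = {}
for row in rows:
    for idx, col in enumerate(description):
        max_partition[col[0]] = row[idx]
assert max_partition == {"year": "2023", "month": "12"}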


def test_get_column_type_simple_types():
assert isinstance(
@@ -214,3 +247,9 @@ def test_column_type_complex_combination():
assert isinstance(
result._STRUCT_fields[2][1].item_type._STRUCT_fields[1][1], types.String
)


def test_casted_partition_key():
from datahub.ingestion.source.sql.athena import AthenaSource

assert AthenaSource._casted_partition_key("test_col") == "CAST(test_col as VARCHAR)"
