diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index a085c9a8..5dfc55f7 100644 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -3,13 +3,14 @@ import re import boto3.session from botocore.exceptions import ClientError -from typing import Optional +from typing import Optional, List -from dbt.adapters.base import available +from dbt.adapters.base import available, Column from dbt.adapters.sql import SQLAdapter from dbt.adapters.athena import AthenaConnectionManager from dbt.adapters.athena.relation import AthenaRelation from dbt.events import AdapterLogger +from dbt.contracts.relation import RelationType logger = AdapterLogger("Athena") class AthenaAdapter(SQLAdapter): @@ -170,3 +171,64 @@ def quote_seed_column( self, column: str, quote_config: Optional[bool] ) -> str: return super().quote_seed_column(column, False) + + def get_columns_in_relation(self, relation: AthenaRelation) -> List[Column]: + conn = self.connections.get_thread_connection() + creds = conn.credentials + session = boto3.session.Session(region_name=creds.region_name, profile_name=creds.aws_profile_name) + glue_client = session.client('glue') + + table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.identifier) + return [Column(c["Name"], c["Type"]) for c in table["Table"]["StorageDescriptor"]["Columns"] + table["Table"]["PartitionKeys"]] + + def list_schemas(self, database: str) -> List[str]: + conn = self.connections.get_thread_connection() + creds = conn.credentials + session = boto3.session.Session(region_name=creds.region_name, profile_name=creds.aws_profile_name) + glue_client = session.client('glue') + paginator = glue_client.get_paginator("get_databases") + + result = [] + logger.debug("CALL glue.get_databases()") + for page in paginator.paginate(): + for db in page["DatabaseList"]: + result.append(db["Name"]) + return result + + def list_relations_without_caching(self, schema_relation: AthenaRelation) -> List[AthenaRelation]: + conn = self.connections.get_thread_connection() + creds = conn.credentials + session = boto3.session.Session(region_name=creds.region_name, profile_name=creds.aws_profile_name) + glue_client = session.client('glue') + paginator = glue_client.get_paginator("get_tables") + + result = [] + logger.debug("CALL glue.get_tables('{}')", schema_relation.schema) + for page in paginator.paginate(DatabaseName=schema_relation.schema): + for table in page["TableList"]: + if table["TableType"] == "EXTERNAL_TABLE": + table_type = RelationType.Table + elif table["TableType"] == "VIRTUAL_VIEW": + table_type = RelationType.View + else: + raise ValueError(f"Unknown TableType for {table['Name']}: {table['TableType']}") + rel = AthenaRelation.create(schema=table["DatabaseName"], identifier=table["Name"], database=schema_relation.database, type=table_type) + result.append(rel) + + return result + + @available + def delete_table(self, relation: AthenaRelation): + conn = self.connections.get_thread_connection() + creds = conn.credentials + session = boto3.session.Session(region_name=creds.region_name, profile_name=creds.aws_profile_name) + glue_client = session.client('glue') + + logger.debug("CALL glue.delete_table({}, {})", relation.schema, relation.identifier) + try: + glue_client.delete_table(DatabaseName=relation.schema, Name=relation.identifier) + except ClientError as e: + if e.response['Error']['Code'] == 'EntityNotFoundException': + logger.debug("Table '{}' does not exists - Ignoring", relation) + else: + raise diff --git a/dbt/include/athena/macros/adapters/columns.sql b/dbt/include/athena/macros/adapters/columns.sql index 2c62f4d7..1d7d742d 100644 --- a/dbt/include/athena/macros/adapters/columns.sql +++ b/dbt/include/athena/macros/adapters/columns.sql @@ -1,22 +1,3 @@ {% macro athena__get_columns_in_relation(relation) -%} - {% call statement('get_columns_in_relation', fetch_result=True) %} - - select - column_name, - data_type, - null as character_maximum_length, - null as numeric_precision, - null as numeric_scale - - from {{ relation.information_schema('columns') }} - where LOWER(table_name) = LOWER('{{ relation.identifier }}') - {% if relation.schema %} - and LOWER(table_schema) = LOWER('{{ relation.schema }}') - {% endif %} - order by ordinal_position - - {% endcall %} - - {% set table = load_result('get_columns_in_relation').table %} - {% do return(sql_convert_columns_in_relation(table)) %} + {{ return(adapter.get_columns_in_relation(relation)) }} {% endmacro %} diff --git a/dbt/include/athena/macros/adapters/metadata.sql b/dbt/include/athena/macros/adapters/metadata.sql index a6e9f1c9..22aeb828 100644 --- a/dbt/include/athena/macros/adapters/metadata.sql +++ b/dbt/include/athena/macros/adapters/metadata.sql @@ -79,42 +79,10 @@ {% macro athena__list_schemas(database) -%} - {% call statement('list_schemas', fetch_result=True) %} - select - distinct schema_name - - from {{ information_schema_name(database) }}.schemata - {% endcall %} - {{ return(load_result('list_schemas').table) }} + {{ return(adapter.list_schemas()) }} {% endmacro %} {% macro athena__list_relations_without_caching(schema_relation) %} - {% call statement('list_relations_without_caching', fetch_result=True) -%} - WITH views AS ( - select - table_catalog as database, - table_name as name, - table_schema as schema - from {{ schema_relation.information_schema() }}.views - where LOWER(table_schema) = LOWER('{{ schema_relation.schema }}') - ), tables AS ( - select - table_catalog as database, - table_name as name, - table_schema as schema - - from {{ schema_relation.information_schema() }}.tables - where LOWER(table_schema) = LOWER('{{ schema_relation.schema }}') - - -- Views appear in both `tables` and `views`, so excluding them from tables - EXCEPT - - select * from views - ) - select views.*, 'view' AS table_type FROM views - UNION ALL - select tables.*, 'table' AS table_type FROM tables - {% endcall %} - {% do return(load_result('list_relations_without_caching').table) %} + {{ return(adapter.list_relations_without_caching(schema_relation)) }} {% endmacro %} diff --git a/dbt/include/athena/macros/adapters/relation.sql b/dbt/include/athena/macros/adapters/relation.sql index 7bb26be9..7e49aee8 100644 --- a/dbt/include/athena/macros/adapters/relation.sql +++ b/dbt/include/athena/macros/adapters/relation.sql @@ -1,6 +1,4 @@ {% macro athena__drop_relation(relation) -%} {%- do adapter.clean_up_table(relation.schema, relation.table) -%} - {% call statement('drop_relation', auto_begin=False) -%} - drop {{ relation.type }} if exists {{ relation }} - {%- endcall %} + {%- do adapter.delete_table(relation) -%} {% endmacro %}