From eac67c3fee4275c37cd86353056493ea1d5b1121 Mon Sep 17 00:00:00 2001 From: Anil Menon Date: Thu, 10 Oct 2024 11:35:36 +0200 Subject: [PATCH 1/3] Support for Azure storage for Unity Catalog read_deltalake --- daft/delta_lake/delta_lake_scan.py | 55 +++++++++++++++++------------ daft/unity_catalog/unity_catalog.py | 35 ++++++++++++------ 2 files changed, 56 insertions(+), 34 deletions(-) diff --git a/daft/delta_lake/delta_lake_scan.py b/daft/delta_lake/delta_lake_scan.py index eb6973f24d..90a3eb8f3b 100644 --- a/daft/delta_lake/delta_lake_scan.py +++ b/daft/delta_lake/delta_lake_scan.py @@ -3,6 +3,7 @@ import logging import os from typing import TYPE_CHECKING +from urllib.parse import urlparse from deltalake.table import DeltaTable @@ -37,31 +38,39 @@ def __init__(self, table_uri: str, storage_config: StorageConfig) -> None: # # See: https://github.com/delta-io/delta-rs/issues/2117 deltalake_sdk_io_config = storage_config.config.io_config - if any([deltalake_sdk_io_config.s3.key_id is None, deltalake_sdk_io_config.s3.region_name is None]): - try: - s3_config_from_env = S3Config.from_env() - # Sometimes S3Config.from_env throws an error, for example on CI machines with weird metadata servers. - except daft.exceptions.DaftCoreException: - pass - else: - if ( - deltalake_sdk_io_config.s3.key_id is None - and deltalake_sdk_io_config.s3.access_key is None - and deltalake_sdk_io_config.s3.session_token is None - ): - deltalake_sdk_io_config = deltalake_sdk_io_config.replace( - s3=deltalake_sdk_io_config.s3.replace( - key_id=s3_config_from_env.key_id, - access_key=s3_config_from_env.access_key, - session_token=s3_config_from_env.session_token, + scheme = urlparse(table_uri).scheme + if scheme == "s3" or scheme == "s3a": + if any([deltalake_sdk_io_config.s3.key_id is None, deltalake_sdk_io_config.s3.region_name is None]): + try: + s3_config_from_env = S3Config.from_env() + # Sometimes S3Config.from_env throws an error, for example on CI machines with weird metadata servers. + except daft.exceptions.DaftCoreException: + pass + else: + if ( + deltalake_sdk_io_config.s3.key_id is None + and deltalake_sdk_io_config.s3.access_key is None + and deltalake_sdk_io_config.s3.session_token is None + ): + deltalake_sdk_io_config = deltalake_sdk_io_config.replace( + s3=deltalake_sdk_io_config.s3.replace( + key_id=s3_config_from_env.key_id, + access_key=s3_config_from_env.access_key, + session_token=s3_config_from_env.session_token, + ) ) - ) - if deltalake_sdk_io_config.s3.region_name is None: - deltalake_sdk_io_config = deltalake_sdk_io_config.replace( - s3=deltalake_sdk_io_config.s3.replace( - region_name=s3_config_from_env.region_name, + if deltalake_sdk_io_config.s3.region_name is None: + deltalake_sdk_io_config = deltalake_sdk_io_config.replace( + s3=deltalake_sdk_io_config.s3.replace( + region_name=s3_config_from_env.region_name, + ) ) - ) + elif scheme == "gcs" or scheme == "gs": + # TO-DO: Handle any key-value replacements in `io_config` if there are missing elements + pass + elif scheme == "az" or scheme == "abfs" or scheme == "abfss": + # TO-DO: Handle any key-value replacements in `io_config` if there are missing elements + pass self._table = DeltaTable( table_uri, storage_options=io_config_to_storage_options(deltalake_sdk_io_config, table_uri) diff --git a/daft/unity_catalog/unity_catalog.py b/daft/unity_catalog/unity_catalog.py index 9eafaee8dd..daaa9b8ed7 100644 --- a/daft/unity_catalog/unity_catalog.py +++ b/daft/unity_catalog/unity_catalog.py @@ -2,10 +2,11 @@ import dataclasses from typing import Callable +from urllib.parse import urlparse import unitycatalog -from daft.io import IOConfig, S3Config +from daft.io import IOConfig, S3Config, AzureConfig @dataclasses.dataclass(frozen=True) @@ -96,18 +97,30 @@ def load_table(self, table_name: str) -> UnityCatalogTable: # Grab credentials from Unity catalog and place it into the Table temp_table_credentials = self._client.temporary_table_credentials.create(operation="READ", table_id=table_id) - aws_temp_credentials = temp_table_credentials.aws_temp_credentials - io_config = ( - IOConfig( - s3=S3Config( - key_id=aws_temp_credentials.access_key_id, - access_key=aws_temp_credentials.secret_access_key, - session_token=aws_temp_credentials.session_token, + + scheme = urlparse(storage_location).scheme + if scheme == "s3" or scheme == "s3a": + aws_temp_credentials = temp_table_credentials.aws_temp_credentials + io_config = ( + IOConfig( + s3=S3Config( + key_id=aws_temp_credentials.access_key_id, + access_key=aws_temp_credentials.secret_access_key, + session_token=aws_temp_credentials.session_token, + ) ) + if aws_temp_credentials is not None + else None ) - if aws_temp_credentials is not None - else None - ) + elif scheme == "gcs" or scheme == "gs": + # TO-DO: gather GCS credential vending assets from Unity and construct 'io_config`` + pass + elif scheme == "az" or scheme == "abfs" or scheme == "abfss": + io_config = IOConfig( + azure=AzureConfig( + sas_token = temp_table_credentials.azure_user_delegation_sas.get('sas_token') + ) + ) return UnityCatalogTable( table_uri=storage_location, From ffc977680ca054fc9c1cf2521ad591c78968a810 Mon Sep 17 00:00:00 2001 From: Anil Menon Date: Thu, 10 Oct 2024 13:21:18 +0200 Subject: [PATCH 2/3] Python linting of order of imports as specified by ruff --- daft/unity_catalog/unity_catalog.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daft/unity_catalog/unity_catalog.py b/daft/unity_catalog/unity_catalog.py index daaa9b8ed7..bb2df9e43b 100644 --- a/daft/unity_catalog/unity_catalog.py +++ b/daft/unity_catalog/unity_catalog.py @@ -6,7 +6,7 @@ import unitycatalog -from daft.io import IOConfig, S3Config, AzureConfig +from daft.io import AzureConfig, IOConfig, S3Config @dataclasses.dataclass(frozen=True) From 27bba268a2c95fd797581b53b6cef6a2473fa67d Mon Sep 17 00:00:00 2001 From: Anil Menon Date: Thu, 10 Oct 2024 13:49:31 +0200 Subject: [PATCH 3/3] Fixing final failures in linting reported on 'pre-commit' checks on local repo --- daft/unity_catalog/unity_catalog.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/daft/unity_catalog/unity_catalog.py b/daft/unity_catalog/unity_catalog.py index bb2df9e43b..640e557627 100644 --- a/daft/unity_catalog/unity_catalog.py +++ b/daft/unity_catalog/unity_catalog.py @@ -97,7 +97,7 @@ def load_table(self, table_name: str) -> UnityCatalogTable: # Grab credentials from Unity catalog and place it into the Table temp_table_credentials = self._client.temporary_table_credentials.create(operation="READ", table_id=table_id) - + scheme = urlparse(storage_location).scheme if scheme == "s3" or scheme == "s3a": aws_temp_credentials = temp_table_credentials.aws_temp_credentials @@ -117,10 +117,8 @@ def load_table(self, table_name: str) -> UnityCatalogTable: pass elif scheme == "az" or scheme == "abfs" or scheme == "abfss": io_config = IOConfig( - azure=AzureConfig( - sas_token = temp_table_credentials.azure_user_delegation_sas.get('sas_token') - ) - ) + azure=AzureConfig(sas_token=temp_table_credentials.azure_user_delegation_sas.get("sas_token")) + ) return UnityCatalogTable( table_uri=storage_location,