Commit c1a86de

supports custom account host for azure blob storage

rudolfix committed Nov 20, 2024
1 parent 1204e83 commit c1a86de
Showing 11 changed files with 241 additions and 62 deletions.
6 changes: 6 additions & 0 deletions dlt/common/configuration/specs/azure_credentials.py
@@ -22,13 +22,16 @@ class AzureCredentialsWithoutDefaults(CredentialsConfiguration):
azure_storage_sas_token: TSecretStrValue = None
azure_sas_token_permissions: str = "racwdl"
"""Permissions to use when generating a SAS token. Ignored when sas token is provided directly"""
azure_account_host: Optional[str] = None
"""Alternative host when accessing blob storage endpoint ie. my_account.dfs.core.windows.net"""

def to_adlfs_credentials(self) -> Dict[str, Any]:
"""Return a dict that can be passed as kwargs to adlfs"""
return dict(
account_name=self.azure_storage_account_name,
account_key=self.azure_storage_account_key,
sas_token=self.azure_storage_sas_token,
account_host=self.azure_account_host,
)

def to_object_store_rs_credentials(self) -> Dict[str, str]:
@@ -68,10 +71,13 @@ class AzureServicePrincipalCredentialsWithoutDefaults(CredentialsConfiguration):
azure_tenant_id: str = None
azure_client_id: str = None
azure_client_secret: TSecretStrValue = None
azure_account_host: Optional[str] = None
"""Alternative host when accessing blob storage endpoint ie. my_account.dfs.core.windows.net"""

def to_adlfs_credentials(self) -> Dict[str, Any]:
return dict(
account_name=self.azure_storage_account_name,
account_host=self.azure_account_host,
tenant_id=self.azure_tenant_id,
client_id=self.azure_client_id,
client_secret=self.azure_client_secret,
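
As a quick illustration of the new field (a sketch with placeholder account, token, and host values, none of them taken from the commit), the custom host simply rides along in the kwargs returned by to_adlfs_credentials():

from dlt.common.configuration.specs.azure_credentials import AzureCredentialsWithoutDefaults

creds = AzureCredentialsWithoutDefaults()
creds.azure_storage_account_name = "my_account"  # placeholder account name
creds.azure_storage_sas_token = "sv=..."  # placeholder SAS token
creds.azure_account_host = "my_account.dfs.core.windows.net"  # the new optional host

# the host is forwarded under the `account_host` key together with the other kwargs
assert creds.to_adlfs_credentials()["account_host"] == "my_account.dfs.core.windows.net"
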
35 changes: 35 additions & 0 deletions dlt/common/storages/configuration.py
@@ -14,6 +14,7 @@
BaseConfiguration,
SFTPCredentials,
)
from dlt.common.exceptions import TerminalValueError
from dlt.common.typing import DictStrAny
from dlt.common.utils import digest128

@@ -57,6 +58,40 @@ class LoadStorageConfiguration(BaseConfiguration):
]


def ensure_canonical_az_url(
bucket_url: str, target_scheme: str, storage_account_name: str = None, account_host: str = None
) -> str:
"""Converts any of the forms of azure blob storage into canonical form of {target_scheme}://<container_name>@<storage_account_name>.{account_host}/<path>
`azure_storage_account_name` is optional only if not present in bucket_url, `account_host` assumes "dfs.core.windows.net" by default
"""
parsed_bucket_url = urlparse(bucket_url)
# Converts an az://<container_name>/<path> to abfss://<container_name>@<storage_account_name>.dfs.core.windows.net/<path>
if parsed_bucket_url.username:
# has the right form, ensure abfss schema
return urlunparse(parsed_bucket_url._replace(scheme=target_scheme))

if not storage_account_name and not account_host:
raise TerminalValueError(
f"Could not convert azure blob storage url {bucket_url} into canonical form "
f" ({target_scheme}://<container_name>@<storage_account_name>.dfs.core.windows.net/<path>)"
f" because storage account name is not known. Please use {target_scheme}:// canonical"
" url as bucket_url in filesystem credentials"
)

account_host = account_host or f"{storage_account_name}.dfs.core.windows.net"

# as required by databricks
_path = parsed_bucket_url.path
return urlunparse(
parsed_bucket_url._replace(
scheme=target_scheme,
netloc=f"{parsed_bucket_url.netloc}@{account_host}",
path=_path,
)
)


def _make_sftp_url(scheme: str, fs_path: str, bucket_url: str) -> str:
parsed_bucket_url = urlparse(bucket_url)
return f"{scheme}://{parsed_bucket_url.hostname}{fs_path}"
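
A short sketch of what the new helper does, using placeholder container and account names (values are illustrative only):

from dlt.common.storages.configuration import ensure_canonical_az_url

# already canonical (container@host in the netloc): only the scheme is swapped
ensure_canonical_az_url("az://container@my_account.dfs.core.windows.net/path", "abfss")
# -> 'abfss://container@my_account.dfs.core.windows.net/path'

# short az://<container>/<path> form: the account name fills in the default host
ensure_canonical_az_url("az://container/path", "abfss", "my_account")
# -> 'abfss://container@my_account.dfs.core.windows.net/path'

# neither account name nor account host is known -> TerminalValueError is raised
ensure_canonical_az_url("az://container/path", "abfss")
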
20 changes: 17 additions & 3 deletions dlt/common/storages/fsspec_filesystem.py
@@ -40,6 +40,7 @@
)
from dlt.common.time import ensure_pendulum_datetime
from dlt.common.typing import DictStrAny
from dlt.common.utils import without_none


class FileItem(TypedDict, total=False):
@@ -97,6 +98,10 @@ class FileItem(TypedDict, total=False):
DEFAULT_KWARGS["azure"] = DEFAULT_KWARGS["az"]
DEFAULT_KWARGS["abfss"] = DEFAULT_KWARGS["az"]

AZURE_BLOB_STORAGE_PROTOCOLS = ["az", "azure", "adl", "abfss", "abfs"]
S3_PROTOCOLS = ["s3", "s3a"]
GCS_PROTOCOLS = ["gs", "gcs"]


def fsspec_filesystem(
protocol: str,
@@ -130,7 +135,11 @@ def prepare_fsspec_args(config: FilesystemConfiguration) -> DictStrAny:
"""
protocol = config.protocol
# never use listing caches
fs_kwargs: DictStrAny = {"use_listings_cache": False, "listings_expiry_time": 60.0}
fs_kwargs: DictStrAny = {
"use_listings_cache": False,
"listings_expiry_time": 60.0,
"skip_instance_cache": True,
}
credentials = CREDENTIALS_DISPATCH.get(protocol, lambda _: {})(config)

if protocol == "gdrive":
@@ -151,7 +160,7 @@ def prepare_fsspec_args(config: FilesystemConfiguration) -> DictStrAny:
if "client_kwargs" in fs_kwargs and "client_kwargs" in credentials:
fs_kwargs["client_kwargs"].update(credentials.pop("client_kwargs"))

fs_kwargs.update(credentials)
fs_kwargs.update(without_none(credentials))
return fs_kwargs


@@ -174,8 +183,13 @@ def fsspec_from_config(config: FilesystemConfiguration) -> Tuple[AbstractFileSys
# first get the class to check the protocol
fs_cls = get_filesystem_class(config.protocol)
if fs_cls.protocol == "abfs":
url = urlparse(config.bucket_url)
# if storage account is present in bucket_url and in credentials, az fsspec will fail
if urlparse(config.bucket_url).username:
# account name is detected only for blob.core.windows.net host
if url.username and (
url.hostname.endswith("blob.core.windows.net")
or url.hostname.endswith("dfs.core.windows.net")
):
fs_kwargs.pop("account_name")
return url_to_fs(config.bucket_url, **fs_kwargs) # type: ignore
except ImportError as e:
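
The without_none() filter matters here because the credential dicts now carry optional keys such as account_host; a sketch with placeholder values of how unset entries are dropped before reaching fsspec:

from dlt.common.utils import without_none

# as returned by to_adlfs_credentials() when no custom host or account key is configured
credentials = {
    "account_name": "my_account",  # placeholder
    "account_key": None,
    "sas_token": "sv=...",  # placeholder
    "account_host": None,
}

fs_kwargs = {
    "use_listings_cache": False,
    "listings_expiry_time": 60.0,
    "skip_instance_cache": True,
}
fs_kwargs.update(without_none(credentials))
# only account_name and sas_token are merged; the None-valued keys never reach fsspec
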
10 changes: 7 additions & 3 deletions dlt/destinations/impl/clickhouse/clickhouse.py
@@ -29,7 +29,8 @@
)
from dlt.common.schema.utils import is_nullable_column
from dlt.common.storages import FileStorage
from dlt.common.storages.configuration import FilesystemConfiguration
from dlt.common.storages.configuration import FilesystemConfiguration, ensure_canonical_az_url
from dlt.common.storages.fsspec_filesystem import AZURE_BLOB_STORAGE_PROTOCOLS
from dlt.destinations.exceptions import LoadJobTerminalException
from dlt.destinations.impl.clickhouse.configuration import (
ClickHouseClientConfiguration,
@@ -140,7 +141,7 @@ def run(self) -> None:
f"s3('{bucket_http_url}',{auth},'{clickhouse_format}','auto','{compression}')"
)

elif bucket_scheme in ("az", "abfs"):
elif bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS:
if not isinstance(self._staging_credentials, AzureCredentialsWithoutDefaults):
raise LoadJobTerminalException(
self._file_path,
@@ -149,7 +150,10 @@ def run(self) -> None:

# Authenticated access.
account_name = self._staging_credentials.azure_storage_account_name
storage_account_url = f"https://{self._staging_credentials.azure_storage_account_name}.blob.core.windows.net"
account_host = self._staging_credentials.azure_account_host
storage_account_url = ensure_canonical_az_url(
bucket_path, "https", account_name, account_host
)
account_key = self._staging_credentials.azure_storage_account_key

# build table func
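
For the ClickHouse path, the previously hard-coded https://<account>.blob.core.windows.net URL is replaced by the shared helper, so a custom host is honored as well. A sketch with placeholder names (the sovereign-cloud host below is purely illustrative):

from dlt.common.storages.configuration import ensure_canonical_az_url

bucket_path = "az://container/jobs/file.parquet"  # placeholder
account_name = "my_account"  # placeholder
account_host = "my_account.blob.core.usgovcloudapi.net"  # hypothetical custom host

storage_account_url = ensure_canonical_az_url(bucket_path, "https", account_name, account_host)
# -> 'https://container@my_account.blob.core.usgovcloudapi.net/jobs/file.parquet'
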
45 changes: 17 additions & 28 deletions dlt/destinations/impl/databricks/databricks.py
@@ -18,12 +18,17 @@
AzureCredentialsWithoutDefaults,
)
from dlt.common.exceptions import TerminalValueError
from dlt.common.storages.configuration import ensure_canonical_az_url
from dlt.common.storages.file_storage import FileStorage
from dlt.common.storages.fsspec_filesystem import (
AZURE_BLOB_STORAGE_PROTOCOLS,
S3_PROTOCOLS,
GCS_PROTOCOLS,
)
from dlt.common.schema import TColumnSchema, Schema
from dlt.common.schema.typing import TColumnType
from dlt.common.storages import FilesystemConfiguration, fsspec_from_config


from dlt.destinations.insert_job_client import InsertValuesJobClient
from dlt.destinations.exceptions import LoadJobTerminalException
from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration
@@ -32,8 +37,8 @@
from dlt.destinations.job_impl import ReferenceFollowupJobRequest
from dlt.destinations.utils import is_compression_disabled

AZURE_BLOB_STORAGE_PROTOCOLS = ["az", "abfss", "abfs"]
SUPPORTED_BLOB_STORAGE_PROTOCOLS = AZURE_BLOB_STORAGE_PROTOCOLS + ["s3", "gs", "gcs"]

SUPPORTED_BLOB_STORAGE_PROTOCOLS = AZURE_BLOB_STORAGE_PROTOCOLS + S3_PROTOCOLS + GCS_PROTOCOLS


class DatabricksLoadJob(RunnableLoadJob, HasFollowupJobs):
@@ -106,7 +111,9 @@ def run(self) -> None:
# Explicit azure credentials are needed to load from bucket without a named stage
credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))"""
bucket_path = self.ensure_databricks_abfss_url(
bucket_path, staging_credentials.azure_storage_account_name
bucket_path,
staging_credentials.azure_storage_account_name,
staging_credentials.azure_account_host,
)
else:
raise LoadJobTerminalException(
@@ -124,7 +131,9 @@ def run(self) -> None:
),
)
bucket_path = self.ensure_databricks_abfss_url(
bucket_path, staging_credentials.azure_storage_account_name
bucket_path,
staging_credentials.azure_storage_account_name,
staging_credentials.azure_account_host,
)

# always add FROM clause
@@ -165,30 +174,10 @@ def run(self) -> None:

@staticmethod
def ensure_databricks_abfss_url(
bucket_path: str, azure_storage_account_name: str = None
bucket_path: str, azure_storage_account_name: str = None, account_host: str = None
) -> str:
bucket_url = urlparse(bucket_path)
# Converts an az://<container_name>/<path> to abfss://<container_name>@<storage_account_name>.dfs.core.windows.net/<path>
if bucket_url.username:
# has the right form, ensure abfss schema
return urlunparse(bucket_url._replace(scheme="abfss"))

if not azure_storage_account_name:
raise TerminalValueError(
f"Could not convert azure blob storage url {bucket_path} into form required by"
" Databricks"
" (abfss://<container_name>@<storage_account_name>.dfs.core.windows.net/<path>)"
" because storage account name is not known. Please use Databricks abfss://"
" canonical url as bucket_url in staging credentials"
)
# as required by databricks
_path = bucket_url.path
return urlunparse(
bucket_url._replace(
scheme="abfss",
netloc=f"{bucket_url.netloc}@{azure_storage_account_name}.dfs.core.windows.net",
path=_path,
)
return ensure_canonical_az_url(
bucket_path, "abfss", azure_storage_account_name, account_host
)


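
ensure_databricks_abfss_url is now a thin wrapper over the shared helper, so a custom account host flows into the abfss URL used in the load statement. A minimal sketch with placeholder names (the host below is hypothetical):

from dlt.common.storages.configuration import ensure_canonical_az_url

ensure_canonical_az_url(
    "az://container/path", "abfss", "my_account", "my_account.dfs.core.chinacloudapi.cn"
)
# -> 'abfss://container@my_account.dfs.core.chinacloudapi.cn/path'
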
23 changes: 10 additions & 13 deletions dlt/destinations/impl/snowflake/snowflake.py
@@ -15,12 +15,13 @@
AwsCredentialsWithoutDefaults,
AzureCredentialsWithoutDefaults,
)
from dlt.common.storages.configuration import FilesystemConfiguration
from dlt.common.storages.configuration import FilesystemConfiguration, ensure_canonical_az_url
from dlt.common.storages.file_storage import FileStorage
from dlt.common.schema import TColumnSchema, Schema
from dlt.common.schema.typing import TColumnType
from dlt.common.exceptions import TerminalValueError

from dlt.common.storages.fsspec_filesystem import AZURE_BLOB_STORAGE_PROTOCOLS, S3_PROTOCOLS
from dlt.common.typing import TLoaderFileFormat
from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset
from dlt.destinations.exceptions import LoadJobTerminalException
@@ -124,33 +125,29 @@ def gen_copy_sql(
if not is_local:
bucket_scheme = parsed_file_url.scheme
# referencing an external s3/azure stage does not require explicit AWS credentials
if bucket_scheme in ["s3", "az", "abfs"] and stage_name:
if bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS + S3_PROTOCOLS and stage_name:
from_clause = f"FROM '@{stage_name}'"
files_clause = f"FILES = ('{parsed_file_url.path.lstrip('/')}')"
# referencing staged files via a bucket URL requires explicit AWS credentials
elif (
bucket_scheme == "s3"
bucket_scheme in S3_PROTOCOLS
and staging_credentials
and isinstance(staging_credentials, AwsCredentialsWithoutDefaults)
):
credentials_clause = f"""CREDENTIALS=(AWS_KEY_ID='{staging_credentials.aws_access_key_id}' AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}')"""
from_clause = f"FROM '{file_url}'"
elif (
bucket_scheme in ["az", "abfs"]
bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS
and staging_credentials
and isinstance(staging_credentials, AzureCredentialsWithoutDefaults)
):
# Explicit azure credentials are needed to load from bucket without a named stage
credentials_clause = f"CREDENTIALS=(AZURE_SAS_TOKEN='?{staging_credentials.azure_storage_sas_token}')"
# Converts an az://<container_name>/<path> to azure://<storage_account_name>.blob.core.windows.net/<container_name>/<path>
# as required by snowflake
_path = "/" + parsed_file_url.netloc + parsed_file_url.path
file_url = urlunparse(
parsed_file_url._replace(
scheme="azure",
netloc=f"{staging_credentials.azure_storage_account_name}.blob.core.windows.net",
path=_path,
)
file_url = ensure_canonical_az_url(
file_url,
"azure",
staging_credentials.azure_storage_account_name,
staging_credentials.azure_account_host,
)
from_clause = f"FROM '{file_url}'"
else:
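
For Snowflake, the ad-hoc URL rewrite is likewise delegated to the shared helper, this time with the azure scheme. A sketch with placeholder names of the URL that ends up in the FROM clause:

from dlt.common.storages.configuration import ensure_canonical_az_url

file_url = "az://container/file.parquet"  # placeholder
file_url = ensure_canonical_az_url(file_url, "azure", "my_account", None)
from_clause = f"FROM '{file_url}'"
# -> FROM 'azure://container@my_account.dfs.core.windows.net/file.parquet'
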