
Commit 2f910c2

Merge branch 'devel' of https://github.com/dlt-hub/dlt into feat/1996-iceberg-filesystem

jorritsandbrink committed Nov 24, 2024
2 parents d39d58d + bfd0b52 commit 2f910c2
Showing 47 changed files with 1,567 additions and 273 deletions.
41 changes: 22 additions & 19 deletions dlt/common/configuration/specs/azure_credentials.py
@@ -10,15 +10,32 @@
)
from dlt.common.configuration.specs.mixins import WithPyicebergConfig
from dlt import version
from dlt.common.utils import without_none

_AZURE_STORAGE_EXTRA = f"{version.DLT_PKG_NAME}[az]"


@configspec
class AzureCredentialsWithoutDefaults(CredentialsConfiguration, WithPyicebergConfig):
class AzureCredentialsBase(CredentialsConfiguration):
azure_storage_account_name: str = None
azure_account_host: Optional[str] = None
"""Alternative host when accessing blob storage endpoint ie. my_account.dfs.core.windows.net"""

def to_adlfs_credentials(self) -> Dict[str, Any]:
pass

def to_object_store_rs_credentials(self) -> Dict[str, str]:
# https://docs.rs/object_store/latest/object_store/azure
creds: Dict[str, Any] = without_none(self.to_adlfs_credentials()) # type: ignore[assignment]
# only string options accepted
creds.pop("anon", None)
return creds


@configspec
class AzureCredentialsWithoutDefaults(AzureCredentialsBase, WithPyicebergConfig):
"""Credentials for Azure Blob Storage, compatible with adlfs"""

azure_storage_account_name: str = None
azure_storage_account_key: Optional[TSecretStrValue] = None
azure_storage_sas_token: TSecretStrValue = None
azure_sas_token_permissions: str = "racwdl"
@@ -30,17 +47,9 @@ def to_adlfs_credentials(self) -> Dict[str, Any]:
account_name=self.azure_storage_account_name,
account_key=self.azure_storage_account_key,
sas_token=self.azure_storage_sas_token,
account_host=self.azure_account_host,
)

def to_object_store_rs_credentials(self) -> Dict[str, str]:
# https://docs.rs/object_store/latest/object_store/azure
creds = self.to_adlfs_credentials()
if creds["sas_token"] is None:
creds.pop("sas_token")
if creds["account_key"] is None:
creds.pop("account_key")
return creds

def to_pyiceberg_fileio_config(self) -> Dict[str, Any]:
return {
"adlfs.account-name": self.azure_storage_account_name,
@@ -71,26 +80,20 @@ def on_partial(self) -> None:


@configspec
class AzureServicePrincipalCredentialsWithoutDefaults(
CredentialsConfiguration, WithPyicebergConfig
):
azure_storage_account_name: str = None
class AzureServicePrincipalCredentialsWithoutDefaults(AzureCredentialsBase, WithPyicebergConfig):
azure_tenant_id: str = None
azure_client_id: str = None
azure_client_secret: TSecretStrValue = None

def to_adlfs_credentials(self) -> Dict[str, Any]:
return dict(
account_name=self.azure_storage_account_name,
account_host=self.azure_account_host,
tenant_id=self.azure_tenant_id,
client_id=self.azure_client_id,
client_secret=self.azure_client_secret,
)

def to_object_store_rs_credentials(self) -> Dict[str, str]:
# https://docs.rs/object_store/latest/object_store/azure
return self.to_adlfs_credentials()

def to_pyiceberg_fileio_config(self) -> Dict[str, Any]:
return {
"adlfs.account-name": self.azure_storage_account_name,
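Not part of the diff, just a hedged usage sketch of the consolidated conversion: the new AzureCredentialsBase builds object_store credentials by dropping None values and the boolean-only "anon" option from the adlfs dict. The constructor keyword arguments and field values below are illustrative assumptions.

from dlt.common.configuration.specs.azure_credentials import AzureCredentialsWithoutDefaults

# hypothetical values, for illustration only
creds = AzureCredentialsWithoutDefaults(
    azure_storage_account_name="my_account",
    azure_storage_sas_token="sv=2024&sig=placeholder",
)

# adlfs form keeps None entries (account_key and account_host stay None here)
print(creds.to_adlfs_credentials())
# object_store form drops None values and non-string options such as "anon",
# since the Rust object_store crate accepts string configuration values only
print(creds.to_object_store_rs_credentials())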
15 changes: 12 additions & 3 deletions dlt/common/destination/reference.py
@@ -76,12 +76,16 @@
try:
from dlt.common.libs.pandas import DataFrame
from dlt.common.libs.pyarrow import Table as ArrowTable
from dlt.common.libs.ibis import BaseBackend as IbisBackend
except MissingDependencyException:
DataFrame = Any
ArrowTable = Any
IbisBackend = Any

else:
DataFrame = Any
ArrowTable = Any
IbisBackend = Any


class StorageSchemaInfo(NamedTuple):
@@ -291,7 +295,6 @@ def _make_dataset_name(self, schema_name: str) -> str:
# if default schema is None then suffix is not added
if self.default_schema_name is not None and schema_name != self.default_schema_name:
return (self.dataset_name or "") + "_" + schema_name

return self.dataset_name


@@ -443,8 +446,9 @@ def run_managed(
self._finished_at = pendulum.now()
# sanity check
assert self._state in ("completed", "retry", "failed")
# wake up waiting threads
signals.wake_all()
if self._state != "retry":
# wake up waiting threads
signals.wake_all()

@abstractmethod
def run(self) -> None:
@@ -574,12 +578,17 @@ def close(self) -> None: ...
class SupportsReadableDataset(Protocol):
"""A readable dataset retrieved from a destination, has support for creating readable relations for a query or table"""

@property
def schema(self) -> Schema: ...

def __call__(self, query: Any) -> SupportsReadableRelation: ...

def __getitem__(self, table: str) -> SupportsReadableRelation: ...

def __getattr__(self, table: str) -> SupportsReadableRelation: ...

def ibis(self) -> IbisBackend: ...


class JobClientBase(ABC):
def __init__(
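A minimal sketch (not part of the diff) of the extended SupportsReadableDataset surface as it might be used from a pipeline dataset; the pipeline and table names are hypothetical and the .df() call assumes the readable-relation helpers available in current dlt.

import dlt

pipeline = dlt.pipeline("readable_dataset_demo", destination="duckdb")  # hypothetical pipeline
dataset = pipeline.dataset()

print(dataset.schema.name)      # new: the dataset exposes its dlt Schema
rel = dataset["items"]          # readable relation for a (hypothetical) table
print(rel.df())                 # materialize as a pandas DataFrame
con = dataset.ibis()            # new: hand back an ibis backend for the destination
print(con.list_tables())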
47 changes: 46 additions & 1 deletion dlt/common/jsonpath.py
@@ -1,4 +1,4 @@
from typing import Iterable, Union, List, Any
from typing import Iterable, Union, List, Any, Optional, cast
from itertools import chain

from dlt.common.typing import DictStrAny
@@ -46,3 +46,48 @@ def resolve_paths(paths: TAnyJsonPath, data: DictStrAny) -> List[str]:
paths = compile_paths(paths)
p: JSONPath
return list(chain.from_iterable((str(r.full_path) for r in p.find(data)) for p in paths))


def is_simple_field_path(path: JSONPath) -> bool:
"""Checks if the given path represents a simple single field name.
Example:
>>> is_simple_field_path(compile_path('id'))
True
>>> is_simple_field_path(compile_path('$.id'))
False
"""
return isinstance(path, JSONPathFields) and len(path.fields) == 1 and path.fields[0] != "*"


def extract_simple_field_name(path: Union[str, JSONPath]) -> Optional[str]:
"""
Extracts a simple field name from a JSONPath if it represents a single field access.
Returns None if the path is complex (contains wildcards, array indices, or multiple fields).
Args:
path: A JSONPath object or string
Returns:
Optional[str]: The field name if path represents a simple field access, None otherwise
Example:
>>> extract_simple_field_name('name')
'name'
>>> extract_simple_field_name('"name"')
'name'
>>> extract_simple_field_name('"na$me"') # Escaped characters are preserved
'na$me'
>>> extract_simple_field_name('"na.me"') # Escaped characters are preserved
'na.me'
>>> extract_simple_field_name('$.name') # Returns None
>>> extract_simple_field_name('$.items[*].name') # Returns None
>>> extract_simple_field_name('*') # Returns None
"""
if isinstance(path, str):
path = compile_path(path)

if is_simple_field_path(path):
return cast(str, path.fields[0])

return None
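A short sketch (not part of the diff) of why the None return is useful: a caller can take a plain dict lookup when the path is a simple field and fall back to full JSONPath resolution otherwise. resolve_paths is the existing helper above; the sample document is made up.

from dlt.common.jsonpath import extract_simple_field_name, resolve_paths

doc = {"name": "root", "items": [{"name": "a"}, {"name": "b"}]}

for path in ("name", "$.items[*].name"):
    field = extract_simple_field_name(path)
    if field is not None:
        print(path, "->", doc[field])                # fast path: plain field access
    else:
        print(path, "->", resolve_paths(path, doc))  # complex path: resolve concrete paths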
121 changes: 121 additions & 0 deletions dlt/common/libs/ibis.py
@@ -0,0 +1,121 @@
from typing import cast

from dlt.common.exceptions import MissingDependencyException

from dlt.common.destination.reference import TDestinationReferenceArg, Destination, JobClientBase

try:
import ibis # type: ignore
from ibis import BaseBackend
except ModuleNotFoundError:
raise MissingDependencyException("dlt ibis Helpers", ["ibis"])


SUPPORTED_DESTINATIONS = [
"dlt.destinations.postgres",
"dlt.destinations.duckdb",
"dlt.destinations.motherduck",
"dlt.destinations.filesystem",
"dlt.destinations.bigquery",
"dlt.destinations.snowflake",
"dlt.destinations.redshift",
"dlt.destinations.mssql",
"dlt.destinations.synapse",
"dlt.destinations.clickhouse",
# NOTE: Athena could theoretically work with trino backend, but according to
# https://github.com/ibis-project/ibis/issues/7682 connecting with aws credentials
# does not work yet.
# "dlt.destinations.athena",
]


def create_ibis_backend(
destination: TDestinationReferenceArg, client: JobClientBase
) -> BaseBackend:
"""Create a given ibis backend for a destination client and dataset"""

# check if destination is supported
destination_type = Destination.from_reference(destination).destination_type
if destination_type not in SUPPORTED_DESTINATIONS:
raise NotImplementedError(f"Destination of type {destination_type} not supported by ibis.")

if destination_type in ["dlt.destinations.motherduck", "dlt.destinations.duckdb"]:
import duckdb
from dlt.destinations.impl.duckdb.duck import DuckDbClient

duck_client = cast(DuckDbClient, client)
duck = duckdb.connect(
database=duck_client.config.credentials._conn_str(),
read_only=duck_client.config.credentials.read_only,
config=duck_client.config.credentials._get_conn_config(),
)
con = ibis.duckdb.from_connection(duck)
elif destination_type in [
"dlt.destinations.postgres",
"dlt.destinations.redshift",
]:
credentials = client.config.credentials.to_native_representation()
con = ibis.connect(credentials)
elif destination_type == "dlt.destinations.snowflake":
from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient

sf_client = cast(SnowflakeClient, client)
credentials = sf_client.config.credentials.to_connector_params()
con = ibis.snowflake.connect(**credentials)
elif destination_type in ["dlt.destinations.mssql", "dlt.destinations.synapse"]:
from dlt.destinations.impl.mssql.mssql import MsSqlJobClient

mssql_client = cast(MsSqlJobClient, client)
con = ibis.mssql.connect(
host=mssql_client.config.credentials.host,
port=mssql_client.config.credentials.port,
database=mssql_client.config.credentials.database,
user=mssql_client.config.credentials.username,
password=mssql_client.config.credentials.password,
driver=mssql_client.config.credentials.driver,
)
elif destination_type == "dlt.destinations.bigquery":
from dlt.destinations.impl.bigquery.bigquery import BigQueryClient

bq_client = cast(BigQueryClient, client)
credentials = bq_client.config.credentials.to_native_credentials()
con = ibis.bigquery.connect(
credentials=credentials,
project_id=bq_client.sql_client.project_id,
location=bq_client.sql_client.location,
)
elif destination_type == "dlt.destinations.clickhouse":
from dlt.destinations.impl.clickhouse.clickhouse import ClickHouseClient

ch_client = cast(ClickHouseClient, client)
con = ibis.clickhouse.connect(
host=ch_client.config.credentials.host,
port=ch_client.config.credentials.http_port,
database=ch_client.config.credentials.database,
user=ch_client.config.credentials.username,
password=ch_client.config.credentials.password,
secure=bool(ch_client.config.credentials.secure),
# compression=True,
)
elif destination_type == "dlt.destinations.filesystem":
import duckdb
from dlt.destinations.impl.filesystem.sql_client import (
FilesystemClient,
FilesystemSqlClient,
)
from dlt.destinations.impl.duckdb.factory import DuckDbCredentials

# we create an in memory duckdb and create all tables on there
duck = duckdb.connect(":memory:")
fs_client = cast(FilesystemClient, client)
creds = DuckDbCredentials(duck)
sql_client = FilesystemSqlClient(
fs_client, dataset_name=fs_client.dataset_name, credentials=creds
)

# NOTE: we should probably have the option for the user to only select a subset of tables here
with sql_client as _:
sql_client.create_views_for_all_tables()
con = ibis.duckdb.from_connection(duck)

return con
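A hedged sketch (not part of the diff) of calling create_ibis_backend directly, assuming a duckdb pipeline and obtaining the destination client from the pipeline; the pipeline name is hypothetical and dataset.ibis() above is the intended high-level entry point.

import dlt
from dlt.common.libs.ibis import create_ibis_backend

pipeline = dlt.pipeline("ibis_backend_demo", destination="duckdb")  # hypothetical
with pipeline.destination_client() as client:
    # hands back an ibis BaseBackend bound to the destination
    con = create_ibis_backend(pipeline.destination, client)
print(con.list_tables())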
4 changes: 2 additions & 2 deletions dlt/common/libs/pyarrow.py
@@ -593,7 +593,7 @@ def row_tuples_to_arrow(
pivoted_rows = np.asarray(rows, dtype="object", order="k").T # type: ignore[call-overload]

columnar = {
col: dat.ravel() for col, dat in zip(columns, np.vsplit(pivoted_rows, len(columns)))
col: dat.ravel() for col, dat in zip(columns, np.vsplit(pivoted_rows, len(pivoted_rows)))
}
columnar_known_types = {
col["name"]: columnar[col["name"]]
@@ -669,7 +669,7 @@ def row_tuples_to_arrow(
pa.field(
key,
arrow_col.type,
nullable=columns[key]["nullable"],
nullable=columns[key].get("nullable", True),
)
)

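For intuition on the vsplit change above, a standalone numpy sketch (not dlt code): splitting the transposed row block by its own row count yields exactly one slice per data column, which stays correct even if the declared schema column count differs from the width of the incoming rows.

import numpy as np

rows = [(1, "a", True), (2, "b", False)]            # 2 rows, 3 data columns
pivoted_rows = np.asarray(rows, dtype="object").T   # shape (3, 2): one row per data column

# one 1-row slice per data column, regardless of how many columns the schema declares
parts = np.vsplit(pivoted_rows, len(pivoted_rows))
columnar = {name: part.ravel() for name, part in zip(["id", "name", "flag"], parts)}
print(columnar["name"])  # ['a' 'b']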
22 changes: 21 additions & 1 deletion dlt/common/libs/sql_alchemy.py
@@ -4,7 +4,8 @@
try:
from sqlalchemy import MetaData, Table, Column, create_engine
from sqlalchemy.engine import Engine, URL, make_url, Row
from sqlalchemy.sql import sqltypes, Select
from sqlalchemy.sql import sqltypes, Select, Executable
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.sqltypes import TypeEngine
from sqlalchemy.exc import CompileError
import sqlalchemy as sa
@@ -18,3 +19,22 @@

# TODO: maybe use sa.__version__?
IS_SQL_ALCHEMY_20 = hasattr(sa, "Double")

__all__ = [
"IS_SQL_ALCHEMY_20",
"MetaData",
"Table",
"Column",
"create_engine",
"Engine",
"URL",
"make_url",
"Row",
"sqltypes",
"Select",
"Executable",
"TextClause",
"TypeEngine",
"CompileError",
"sa",
]
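A brief usage note (not part of the diff): the new __all__ makes this module an explicit re-export shim, so downstream dlt code imports SQLAlchemy symbols from one place and the MissingDependencyException handling stays centralized. A minimal consumer sketch, with a hypothetical table:

from dlt.common.libs.sql_alchemy import sa, Table, MetaData, IS_SQL_ALCHEMY_20

metadata = MetaData()
users = Table("users", metadata, sa.Column("id", sa.Integer, primary_key=True))
print(IS_SQL_ALCHEMY_20, users.name)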