
Commit

integrate ibis relation into existing code
sh-rp committed Nov 25, 2024
1 parent 34323da commit ff330b6
Showing 6 changed files with 295 additions and 446 deletions.
58 changes: 55 additions & 3 deletions dlt/common/libs/ibis.py
@@ -1,12 +1,14 @@
from typing import cast
from typing import cast, Any

from dlt.common.exceptions import MissingDependencyException

from dlt.common.destination.reference import TDestinationReferenceArg, Destination, JobClientBase
from dlt.common.schema import Schema
from dlt.destinations.sql_client import SqlClientBase

try:
    import ibis  # type: ignore
    from ibis import BaseBackend
    import sqlglot
    from ibis import BaseBackend, Expr
except ModuleNotFoundError:
    raise MissingDependencyException("dlt ibis Helpers", ["ibis"])

@@ -29,6 +31,22 @@
]


# Map dlt data types to ibis data types
DATA_TYPE_MAP = {
    "text": "string",
    "double": "float64",
    "bool": "boolean",
    "timestamp": "timestamp",
    "bigint": "int64",
    "binary": "binary",
    "json": "string",  # Store JSON as string in ibis
    "decimal": "decimal",
    "wei": "int64",  # Wei is a large integer
    "date": "date",
    "time": "time",
}


def create_ibis_backend(
    destination: TDestinationReferenceArg, client: JobClientBase
) -> BaseBackend:
@@ -119,3 +137,37 @@ def create_ibis_backend(
        con = ibis.duckdb.from_connection(duck)

    return con


def create_unbound_ibis_table(
    sql_client: SqlClientBase[Any], schema: Schema, table_name: str
) -> Expr:
    """Create an unbound ibis table from a dlt schema"""

    if table_name not in schema.tables:
        raise Exception(
            f"Table {table_name} not found in schema. Available tables: {schema.tables.keys()}"
        )
    table_schema = schema.tables[table_name]

    # Convert dlt table schema columns to ibis schema
    ibis_schema = {
        sql_client.capabilities.casefold_identifier(col_name): DATA_TYPE_MAP[
            col_info.get("data_type", "string")
        ]
        for col_name, col_info in table_schema.get("columns", {}).items()
    }

    # normalize table name
    table_path = sql_client.make_qualified_table_name_path(table_name, escape=False)

    catalog = None
    if len(table_path) == 3:
        catalog, database, table = table_path
    else:
        database, table = table_path

    # create unbound ibis table and return in dlt wrapper
    unbound_table = ibis.table(schema=ibis_schema, name=table, database=database, catalog=catalog)

    return unbound_table
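
For orientation, a minimal usage sketch of the new helper follows. It is not part of the commit: the pipeline name, destination, and the "items" table are hypothetical, and it assumes the readable dataset exposes the sql_client and schema attributes used in the code above.

# hypothetical sketch, not part of this commit
import dlt
import ibis
from dlt.common.libs.ibis import create_unbound_ibis_table

pipeline = dlt.pipeline("my_pipeline", destination="duckdb")  # illustrative pipeline
dataset = pipeline.dataset()

# build an unbound ibis table expression from the dlt schema; no connection is opened here
expr = create_unbound_ibis_table(dataset.sql_client, dataset.schema, "items")

# the expression composes lazily and can be rendered to SQL for the target dialect
print(ibis.to_sql(expr.limit(10), dialect="duckdb"))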
12 changes: 11 additions & 1 deletion dlt/destinations/dataset/dataset.py
@@ -13,11 +13,11 @@
    WithStateSync,
)

from dlt.common.schema.typing import TTableSchemaColumns
from dlt.destinations.sql_client import SqlClientBase, WithSqlClient
from dlt.common.schema import Schema
from dlt.destinations.dataset.relation import ReadableDBAPIRelation
from dlt.destinations.dataset.utils import get_destination_clients
from dlt.destinations.dataset.ibis_relation import ReadableIbisRelation

if TYPE_CHECKING:
    try:
@@ -27,6 +27,11 @@
else:
    IbisBackend = Any

try:
    from dlt.common.libs.ibis import create_unbound_ibis_table
except MissingDependencyException:
    create_unbound_ibis_table = None


class ReadableDBAPIDataset(SupportsReadableDataset):
"""Access to dataframes and arrowtables in the destination dataset via dbapi"""
@@ -112,6 +117,11 @@ def __call__(self, query: Any) -> ReadableDBAPIRelation:
        return ReadableDBAPIRelation(readable_dataset=self, provided_query=query)  # type: ignore[abstract]

    def table(self, table_name: str) -> SupportsReadableRelation:
        # we can create an ibis powered relation if ibis is available
        if create_unbound_ibis_table:
            unbound_table = create_unbound_ibis_table(self.sql_client, self.schema, table_name)
            return ReadableIbisRelation(readable_dataset=self, expression=unbound_table)  # type: ignore[abstract]

        return ReadableDBAPIRelation(
            readable_dataset=self,
            table_name=table_name,
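
With this change, dataset.table() returns an ibis-backed relation whenever the optional ibis dependency can be imported, and falls back to the plain DBAPI relation otherwise. Gating on the imported symbol keeps ibis strictly optional: environments without it keep the previous behavior. A hedged sketch of the resulting behavior (names are illustrative, not part of the commit):

# hypothetical sketch, not part of this commit
import dlt

pipeline = dlt.pipeline("my_pipeline", destination="duckdb")  # illustrative pipeline
dataset = pipeline.dataset()

relation = dataset.table("items")  # ReadableIbisRelation if ibis is installed, else ReadableDBAPIRelation
print(type(relation).__name__)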
161 changes: 161 additions & 0 deletions dlt/destinations/dataset/ibis_relation.py
@@ -0,0 +1,161 @@
from typing import TYPE_CHECKING, Any, Union, Sequence

from functools import partial

from dlt.common.exceptions import MissingDependencyException
from dlt.destinations.dataset.relation import BaseReadableDBAPIRelation
from dlt.common.schema.typing import TTableSchemaColumns


if TYPE_CHECKING:
    from dlt.destinations.dataset.dataset import ReadableDBAPIDataset
else:
    ReadableDBAPIDataset = Any

try:
    from dlt.common.libs.ibis import Expr
except MissingDependencyException:
    Expr = Any

# map dlt destination to sqlglot dialect
DIALECT_MAP = {
    "dlt.destinations.duckdb": "duckdb",  # works
    "dlt.destinations.motherduck": "duckdb",  # works
    "dlt.destinations.clickhouse": "clickhouse",  # works
    "dlt.destinations.databricks": "databricks",  # works
    "dlt.destinations.bigquery": "bigquery",  # works
    "dlt.destinations.postgres": "postgres",  # works
    "dlt.destinations.redshift": "redshift",  # works
    "dlt.destinations.snowflake": "snowflake",  # works
    "dlt.destinations.mssql": "tsql",  # works
    "dlt.destinations.synapse": "tsql",  # works
    "dlt.destinations.athena": "trino",  # works
    "dlt.destinations.filesystem": "duckdb",  # works
    "dlt.destinations.dremio": "presto",  # works
    # NOTE: can we discover the current dialect in sqlalchemy?
    "dlt.destinations.sqlalchemy": "mysql",  # may work
}

# NOTE: some dialects are not supported by ibis, but by sqlglot; these need to
# be transpiled with an intermediary step
TRANSPILE_VIA_MAP = {
    "tsql": "postgres",
    "databricks": "postgres",
    "clickhouse": "postgres",
    "redshift": "postgres",
    "presto": "postgres",
}


# TODO: provide ibis expression typing for the readable relation
class ReadableIbisRelation(BaseReadableDBAPIRelation):
    def __init__(
        self,
        *,
        readable_dataset: ReadableDBAPIDataset,
        expression: Expr = None,
    ) -> None:
        """Create a lazily evaluated relation for the dataset of a destination"""
        super().__init__(readable_dataset=readable_dataset)
        self._expression = expression

    @property
    def query(self) -> Any:
        """build the query"""

        from dlt.common.libs.ibis import ibis, sqlglot

        destination_type = self._dataset._destination.destination_type
        target_dialect = DIALECT_MAP[destination_type]

        # render sql directly if possible
        if target_dialect not in TRANSPILE_VIA_MAP:
            return ibis.to_sql(self._expression, dialect=target_dialect)

        # here we need to transpile first
        transpile_via = TRANSPILE_VIA_MAP[target_dialect]
        sql = ibis.to_sql(self._expression, dialect=transpile_via)
        sql = sqlglot.transpile(sql, read=transpile_via, write=target_dialect)[0]
        return sql

    @property
    def columns_schema(self) -> TTableSchemaColumns:
        return self.compute_columns_schema()

    @columns_schema.setter
    def columns_schema(self, new_value: TTableSchemaColumns) -> None:
        raise NotImplementedError("columns schema in ReadableDBAPIRelation can only be computed")

    def compute_columns_schema(self) -> TTableSchemaColumns:
        """provide schema columns for the cursor, may be filtered by selected columns"""
        # TODO: provide column lineage tracing with sqlglot lineage
        return None

    def _proxy_expression_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
        """Proxy method calls to the underlying ibis expression, wrapping any resulting expression in a new relation"""

        # Get the method from the expression
        method = getattr(self._expression, method_name)

        # unwrap args and kwargs if they are relations
        args = tuple(
            arg._expression if isinstance(arg, ReadableIbisRelation) else arg for arg in args
        )
        kwargs = {
            k: v._expression if isinstance(v, ReadableIbisRelation) else v
            for k, v in kwargs.items()
        }

        # casefold string params, we assume these are column names
        args = tuple(
            self.sql_client.capabilities.casefold_identifier(arg) if isinstance(arg, str) else arg
            for arg in args
        )
        kwargs = {
            k: self.sql_client.capabilities.casefold_identifier(v) if isinstance(v, str) else v
            for k, v in kwargs.items()
        }

        # Call it with provided args
        result = method(*args, **kwargs)

        # If result is an ibis expression, wrap it in a new relation, else return the raw result
        if isinstance(result, Expr):
            return self.__class__(readable_dataset=self._dataset, expression=result)
        return result

    def __getattr__(self, name: str) -> Any:
        """Wrap all callable attributes of the expression"""

        attr = getattr(self._expression, name, None)

        # try casefolded name for ibis columns access
        if attr is None:
            name = self.sql_client.capabilities.casefold_identifier(name)
            attr = getattr(self._expression, name, None)

        if attr is None:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

        if not callable(attr):
            return attr
        return partial(self._proxy_expression_method, name)

    def __getitem__(self, columns: Union[str, Sequence[str]]) -> "ReadableIbisRelation":
        # casefold column names
        columns = [self.sql_client.capabilities.casefold_identifier(col) for col in columns]
        expr = self._expression[columns]
        return self.__class__(readable_dataset=self._dataset, expression=expr)

    # forward ibis methods defined on interface
    def limit(self, limit: int) -> "ReadableIbisRelation":
        """limit the result to 'limit' items"""
        return self._proxy_expression_method("limit", limit)  # type: ignore

    def head(self, limit: int = 5) -> "ReadableIbisRelation":
        """limit the result to 'limit' items, 5 by default"""
        return self._proxy_expression_method("head", limit)  # type: ignore

    def select(self, *columns: str) -> "ReadableIbisRelation":
        """set which columns will be selected"""
        return self._proxy_expression_method("select", *columns)  # type: ignore
