Skip to content

Commit

Permalink
add prototype that replaces our own sql with ibis expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Nov 10, 2024
1 parent 95ca6e6 commit ff13003
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 81 deletions.
13 changes: 0 additions & 13 deletions dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,19 +523,6 @@ def fetchone(self) -> Optional[Tuple[Any, ...]]:
"""fetch first item as python tuple"""
...

# modifying access parameters
def limit(self, limit: int) -> "SupportsReadableRelation":
"""limit the result to 'limit' items"""
...

def head(self, limit: int = 5) -> "SupportsReadableRelation":
"""limit the result to 5 items by default"""
...

def select(self, *columns: str) -> "SupportsReadableRelation":
"""set which columns will be selected"""
...

@overload
def __getitem__(self, column: str) -> "SupportsReadableRelation": ...

Expand Down
158 changes: 90 additions & 68 deletions dlt/destinations/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,20 @@
DestinationClientDwhConfiguration,
)

from functools import partial

from dlt.common.schema.typing import TTableSchemaColumns
from dlt.destinations.sql_client import SqlClientBase, WithSqlClient
from dlt.common.schema import Schema
from dlt.common.exceptions import DltException


# TODO: move ibis dependencies to libs/ibis
import ibis

from ibis import Expr


class DatasetException(DltException):
pass

Expand All @@ -43,29 +51,25 @@ def __init__(self, column_name: str) -> None:
super().__init__(msg)


# TODO: provide ibis expression typing for the readable relation
class ReadableDBAPIRelation(SupportsReadableRelation):
def __init__(
self,
*,
readable_dataset: "ReadableDBAPIDataset",
expression: Expr = None,
provided_query: Any = None,
table_name: str = None,
limit: int = None,
selected_columns: Sequence[str] = None,
) -> None:
"""Create a lazy evaluated relation to for the dataset of a destination"""

# NOTE: we can keep an assertion here, this class will not be created by the user
assert bool(table_name) != bool(
provided_query
), "Please provide either an sql query OR a table_name"
assert (
expression != None
) or provided_query, "Please provide either an sql query OR an ibis expression"

self._dataset = readable_dataset

self._provided_query = provided_query
self._table_name = table_name
self._limit = limit
self._selected_columns = selected_columns
self._expression = expression
self._provided_query = None

# wire protocol functions
self.df = self._wrap_func("df") # type: ignore
Expand All @@ -92,27 +96,8 @@ def query(self) -> Any:
if self._provided_query:
return self._provided_query

table_name = self.sql_client.make_qualified_table_name(
self.schema.naming.normalize_path(self._table_name)
)

maybe_limit_clause_1 = ""
maybe_limit_clause_2 = ""
if self._limit:
maybe_limit_clause_1, maybe_limit_clause_2 = self.sql_client._limit_clause_sql(
self._limit
)

selector = "*"
if self._selected_columns:
selector = ",".join(
[
self.sql_client.escape_column_name(self.schema.naming.normalize_path(c))
for c in self._selected_columns
]
)

return f"SELECT {maybe_limit_clause_1} {selector} FROM {table_name} {maybe_limit_clause_2}"
# TODO: select the correct dialect for the destination type
return ibis.to_sql(self._expression, dialect="duckdb")

@property
def columns_schema(self) -> TTableSchemaColumns:
Expand All @@ -125,6 +110,10 @@ def columns_schema(self, new_value: TTableSchemaColumns) -> None:
def compute_columns_schema(self) -> TTableSchemaColumns:
"""provide schema columns for the cursor, may be filtered by selected columns"""

# TODO: if there were no joins, we can return the schema from the table
# for the prototype disable this
return None

columns_schema = (
self.schema.tables.get(self._table_name, {}).get("columns", {}) if self.schema else {}
)
Expand All @@ -143,6 +132,41 @@ def compute_columns_schema(self) -> TTableSchemaColumns:

return filtered_columns

def _proxy_expression_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
    """Proxy a method call to the underlying ibis expression.

    Any ReadableDBAPIRelation passed in args/kwargs is unwrapped to its raw
    ibis expression before the call; if the call returns an ibis expression,
    it is wrapped back into a new relation so that chaining stays lazy.
    Non-expression results are returned unchanged.
    """
    # resolve the method on the wrapped ibis expression
    method = getattr(self._expression, method_name)
    # unwrap relation arguments to their underlying ibis expressions
    unwrapped_args = [
        arg._expression if isinstance(arg, ReadableDBAPIRelation) else arg for arg in args
    ]
    unwrapped_kwargs = {
        k: v._expression if isinstance(v, ReadableDBAPIRelation) else v
        for k, v in kwargs.items()
    }
    result = method(*unwrapped_args, **unwrapped_kwargs)
    # wrap a resulting ibis expression in a new relation to keep evaluation lazy
    if isinstance(result, Expr):
        return self.__class__(readable_dataset=self._dataset, expression=result)
    # plain values (schemas, scalars, ...) pass through as-is
    return result

def __getattr__(self, name: str) -> Any:
    """Forward unknown attributes/methods to the underlying ibis expression.

    Plain attributes are returned directly; callables are deferred through
    ``_proxy_expression_method`` so relation arguments/results can be
    (un)wrapped.
    """
    # look up _expression via __dict__ to avoid infinite recursion when it is
    # not yet set (e.g. during unpickling, before __init__ has run)
    expression = self.__dict__.get("_expression")
    if expression is None or not hasattr(expression, name):
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
    attr = getattr(expression, name)
    if not callable(attr):
        return attr
    return partial(self._proxy_expression_method, name)

def __getitem__(self, columns: Union[str, Sequence[str]]) -> "SupportsReadableRelation":
    """Select column(s) via bracket notation, returning a new lazy relation."""
    selected = self._expression[columns]
    return self.__class__(readable_dataset=self._dataset, expression=selected)

@contextmanager
def cursor(self) -> Generator[SupportsReadableRelation, Any, Any]:
"""Gets a DBApiCursor for the current relation"""
Expand Down Expand Up @@ -178,40 +202,10 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
def __copy__(self) -> "ReadableDBAPIRelation":
return self.__class__(
readable_dataset=self._dataset,
expression=self._expression,
provided_query=self._provided_query,
table_name=self._table_name,
limit=self._limit,
selected_columns=self._selected_columns,
)

def limit(self, limit: int) -> "ReadableDBAPIRelation":
if self._provided_query:
raise ReadableRelationHasQueryException("limit")
rel = self.__copy__()
rel._limit = limit
return rel

def select(self, *columns: str) -> "ReadableDBAPIRelation":
if self._provided_query:
raise ReadableRelationHasQueryException("select")
rel = self.__copy__()
rel._selected_columns = columns
# NOTE: the line below will ensure that no unknown columns are selected if
# schema is known
rel.compute_columns_schema()
return rel

def __getitem__(self, columns: Union[str, Sequence[str]]) -> "SupportsReadableRelation":
if isinstance(columns, str):
return self.select(columns)
elif isinstance(columns, Sequence):
return self.select(*columns)
else:
raise TypeError(f"Invalid argument type: {type(columns).__name__}")

def head(self, limit: int = 5) -> "ReadableDBAPIRelation":
return self.limit(limit)


class ReadableDBAPIDataset(SupportsReadableDataset):
"""Access to dataframes and arrowtables in the destination dataset via dbapi"""
Expand Down Expand Up @@ -287,18 +281,46 @@ def __call__(self, query: Any) -> ReadableDBAPIRelation:
return ReadableDBAPIRelation(readable_dataset=self, provided_query=query) # type: ignore[abstract]

def table(self, table_name: str) -> SupportsReadableRelation:
    """Build a lazy relation over *table_name* backed by an unbound ibis table.

    NOTE: creating the unbound ibis table requires access to the dlt schema;
    if the schema is not present locally this is not fully lazy because the
    dataset needs to be queried to obtain it.

    Raises:
        KeyError: if the table is not present in the dataset schema.
    """
    if table_name not in self.schema.tables:
        # KeyError (an Exception subclass) keeps existing broad handlers working
        raise KeyError(f"Table {table_name} not found in schema")
    table_schema = self.schema.tables[table_name]

    # dlt data type -> ibis dtype name; hoisted out of the loop so it is
    # built once per call, not once per column
    # TODO: convert all types properly (precision/scale, nested types, ...)
    type_mapping = {
        "text": "string",
        "double": "float64",
        "bool": "boolean",
        "timestamp": "timestamp",
        "bigint": "int64",
        "binary": "binary",
        "json": "string",  # Store JSON as string in ibis
        "decimal": "decimal",
        "wei": "int64",  # Wei is a large integer
        "date": "date",
        "time": "time",
    }

    # unknown dlt types fall back to string
    ibis_schema = {
        col_name: type_mapping.get(col_info.get("data_type", "string"), "string")
        for col_name, col_info in table_schema.get("columns", {}).items()
    }

    # create unbound ibis table and return in dlt wrapper
    unbound_table = ibis.table(schema=ibis_schema, name=table_name)
    return ReadableDBAPIRelation(readable_dataset=self, expression=unbound_table)

def __getitem__(self, table_name: str) -> SupportsReadableRelation:
    """Access a table as a lazy relation via dict notation: ``dataset["my_table"]``."""
    return self.table(table_name)

def __getattr__(self, table_name: str) -> SupportsReadableRelation:
    """Access a table as a lazy relation via attribute notation: ``dataset.my_table``.

    NOTE(review): any attribute not found on the dataset is treated as a table
    name here, so typos surface as schema lookups instead of AttributeError —
    confirm this is intended.
    """
    return self.table(table_name)


def dataset(
Expand Down
46 changes: 46 additions & 0 deletions tests/load/test_read_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,49 @@ def test_standalone_dataset(populated_pipeline: Pipeline) -> None:
)
assert dataset.schema.name == "some_other_schema"
assert "other_table" in dataset.schema.tables


@pytest.mark.no_load
@pytest.mark.essential
@pytest.mark.parametrize(
    "populated_pipeline",
    configs,
    indirect=True,
    ids=lambda x: x.name,
)
def test_ibis_expression_relation(populated_pipeline: Pipeline) -> None:
    """Exercise ibis-expression-backed relations: full reads, limit, a chained
    join/select/order_by/limit expression, and an aggregate over a limited slice."""
    # NOTE: we could generalize this with a context for certain deps
    dataset = populated_pipeline._dataset()
    total_records = _total_records(populated_pipeline)

    items_table = dataset.table("items")
    double_items_table = dataset.table("double_items")

    # check full table access
    df = items_table.df()
    assert len(df.index) == total_records

    df = double_items_table.df()
    assert len(df.index) == total_records

    # check limit
    df = items_table.limit(5).df()
    assert len(df.index) == 5

    # check chained expression with join, column selection, order by and limit
    joined_table = (
        items_table.join(double_items_table, items_table.id == double_items_table.id)[
            ["id", "double_id"]
        ]
        .order_by("id")
        .limit(20)
    )
    table = joined_table.fetchall()
    assert len(table) == 20
    assert list(table[0]) == [0, 0]
    assert list(table[5]) == [5, 10]
    assert list(table[10]) == [10, 20]

    # check aggregate of first 20 items: ids 0..19 sum up with the builtin,
    # no need for functools.reduce
    agg_table = items_table.order_by("id").limit(20).aggregate(sum_id=items_table.id.sum())
    assert agg_table.fetchone()[0] == sum(range(20))

0 comments on commit ff13003

Please sign in to comment.