diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 1f3a6c1120..769ddf0203 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -523,19 +523,6 @@ def fetchone(self) -> Optional[Tuple[Any, ...]]: """fetch first item as python tuple""" ... - # modifying access parameters - def limit(self, limit: int) -> "SupportsReadableRelation": - """limit the result to 'limit' items""" - ... - - def head(self, limit: int = 5) -> "SupportsReadableRelation": - """limit the result to 5 items by default""" - ... - - def select(self, *columns: str) -> "SupportsReadableRelation": - """set which columns will be selected""" - ... - @overload def __getitem__(self, column: str) -> "SupportsReadableRelation": ... diff --git a/dlt/destinations/dataset.py b/dlt/destinations/dataset.py index cffdc0f059..ad211235cc 100644 --- a/dlt/destinations/dataset.py +++ b/dlt/destinations/dataset.py @@ -16,12 +16,20 @@ DestinationClientDwhConfiguration, ) +from functools import partial + from dlt.common.schema.typing import TTableSchemaColumns from dlt.destinations.sql_client import SqlClientBase, WithSqlClient from dlt.common.schema import Schema from dlt.common.exceptions import DltException +# TODO: move ibis dependencies to libs/ibis +import ibis + +from ibis import Expr + + class DatasetException(DltException): pass @@ -43,29 +51,25 @@ def __init__(self, column_name: str) -> None: super().__init__(msg) +# TODO: provide ibis expression typing for the readable relation class ReadableDBAPIRelation(SupportsReadableRelation): def __init__( self, *, readable_dataset: "ReadableDBAPIDataset", + expression: Expr = None, provided_query: Any = None, - table_name: str = None, - limit: int = None, - selected_columns: Sequence[str] = None, ) -> None: """Create a lazy evaluated relation to for the dataset of a destination""" # NOTE: we can keep an assertion here, this class will not be created by the user - assert bool(table_name) != 
bool(
-            provided_query
-        ), "Please provide either an sql query OR a table_name"
+        assert (
+            expression is not None
+        ) or provided_query, "Please provide either an sql query OR an ibis expression"
 
         self._dataset = readable_dataset
-
-        self._provided_query = provided_query
-        self._table_name = table_name
-        self._limit = limit
-        self._selected_columns = selected_columns
+        self._expression = expression
+        self._provided_query = provided_query
 
         # wire protocol functions
         self.df = self._wrap_func("df")  # type: ignore
@@ -92,27 +96,8 @@ def query(self) -> Any:
         if self._provided_query:
             return self._provided_query
 
-        table_name = self.sql_client.make_qualified_table_name(
-            self.schema.naming.normalize_path(self._table_name)
-        )
-
-        maybe_limit_clause_1 = ""
-        maybe_limit_clause_2 = ""
-        if self._limit:
-            maybe_limit_clause_1, maybe_limit_clause_2 = self.sql_client._limit_clause_sql(
-                self._limit
-            )
-
-        selector = "*"
-        if self._selected_columns:
-            selector = ",".join(
-                [
-                    self.sql_client.escape_column_name(self.schema.naming.normalize_path(c))
-                    for c in self._selected_columns
-                ]
-            )
-
-        return f"SELECT {maybe_limit_clause_1} {selector} FROM {table_name} {maybe_limit_clause_2}"
+        # TODO: select the correct dialect for the destination type
+        return ibis.to_sql(self._expression, dialect="duckdb")
 
     @property
     def columns_schema(self) -> TTableSchemaColumns:
@@ -125,6 +110,10 @@ def columns_schema(self, new_value: TTableSchemaColumns) -> None:
 
     def compute_columns_schema(self) -> TTableSchemaColumns:
         """provide schema columns for the cursor, may be filtered by selected columns"""
+        # TODO: if there were no joins, we can return the schema from the table
+        # for the prototype disable this
+        return None
+
         columns_schema = (
             self.schema.tables.get(self._table_name, {}).get("columns", {}) if self.schema else {}
        )
@@ -143,6 +132,41 @@ def compute_columns_schema(self) -> TTableSchemaColumns:
 
         return filtered_columns
 
+    def _proxy_expression_method(self, method_name: str, *args: Any, **kwargs: Any) 
-> Any:
+        """Proxy method calls to the underlying ibis expression, allowing to wrap the resulting expression in a new relation"""
+        # Get the method from the expression
+        method = getattr(self._expression, method_name)
+        # if any of the args is a relation, we need to unwrap it
+        unwrapped_args = [
+            arg._expression if isinstance(arg, ReadableDBAPIRelation) else arg for arg in args
+        ]
+        # if any of the kwargs is a relation, we need to unwrap it
+        unwrapped_kwargs = {
+            k: v._expression if isinstance(v, ReadableDBAPIRelation) else v
+            for k, v in kwargs.items()
+        }
+        # Call it with provided args
+        result = method(*unwrapped_args, **unwrapped_kwargs)
+        # If result is an ibis expression, wrap it in a new relation
+        if isinstance(result, Expr):
+            # NOTE(review): debug print removed
+            return self.__class__(readable_dataset=self._dataset, expression=result)
+        # Otherwise return the raw result
+        return result
+
+    def __getattr__(self, name: str) -> Any:
+        """Forward any unknown attributes/methods to the underlying expression"""
+        if not hasattr(self._expression, name):
+            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
+        attr = getattr(self._expression, name)
+        if not callable(attr):
+            return attr
+        return partial(self._proxy_expression_method, name)
+
+    def __getitem__(self, columns: Union[str, Sequence[str]]) -> "SupportsReadableRelation":
+        expr = self._expression[columns]
+        return self.__class__(readable_dataset=self._dataset, expression=expr)
+
     @contextmanager
     def cursor(self) -> Generator[SupportsReadableRelation, Any, Any]:
         """Gets a DBApiCursor for the current relation"""
@@ -178,40 +202,10 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
     def __copy__(self) -> "ReadableDBAPIRelation":
         return self.__class__(
             readable_dataset=self._dataset,
+            expression=self._expression,
             provided_query=self._provided_query,
-            table_name=self._table_name,
-            limit=self._limit,
-            selected_columns=self._selected_columns,
         )
 
-    def limit(self, limit: int) -> 
"ReadableDBAPIRelation": - if self._provided_query: - raise ReadableRelationHasQueryException("limit") - rel = self.__copy__() - rel._limit = limit - return rel - - def select(self, *columns: str) -> "ReadableDBAPIRelation": - if self._provided_query: - raise ReadableRelationHasQueryException("select") - rel = self.__copy__() - rel._selected_columns = columns - # NOTE: the line below will ensure that no unknown columns are selected if - # schema is known - rel.compute_columns_schema() - return rel - - def __getitem__(self, columns: Union[str, Sequence[str]]) -> "SupportsReadableRelation": - if isinstance(columns, str): - return self.select(columns) - elif isinstance(columns, Sequence): - return self.select(*columns) - else: - raise TypeError(f"Invalid argument type: {type(columns).__name__}") - - def head(self, limit: int = 5) -> "ReadableDBAPIRelation": - return self.limit(limit) - class ReadableDBAPIDataset(SupportsReadableDataset): """Access to dataframes and arrowtables in the destination dataset via dbapi""" @@ -287,10 +281,39 @@ def __call__(self, query: Any) -> ReadableDBAPIRelation: return ReadableDBAPIRelation(readable_dataset=self, provided_query=query) # type: ignore[abstract] def table(self, table_name: str) -> SupportsReadableRelation: - return ReadableDBAPIRelation( - readable_dataset=self, - table_name=table_name, - ) # type: ignore[abstract] + # NOTE: to be able to create an unbound ibis table we need to access the schema + # and if this is not present, this will not be fully lazy bc the dataset needs to be + # queried to get the schema + if table_name not in self.schema.tables: + raise Exception(f"Table {table_name} not found in schema") + table_schema = self.schema.tables[table_name] + + # Convert dlt table schema columns to ibis schema + # TODO: convert all types properly + ibis_schema = {} + for col_name, col_info in table_schema.get("columns", {}).items(): + # + col_mapping = { + "text": "string", + "double": "float64", + "bool": "boolean", + 
"timestamp": "timestamp", + "bigint": "int64", + "binary": "binary", + "json": "string", # Store JSON as string in ibis + "decimal": "decimal", + "wei": "int64", # Wei is a large integer + "date": "date", + "time": "time", + } + + col_type = col_info.get("data_type", "string") + ibis_schema[col_name] = col_mapping.get(col_type, "string") + + print(ibis_schema) + # create unbound ibis table and return in dlt wrapper + unbound_table = ibis.table(schema=ibis_schema, name=table_name) + return ReadableDBAPIRelation(readable_dataset=self, expression=unbound_table) def __getitem__(self, table_name: str) -> SupportsReadableRelation: """access of table via dict notation""" @@ -298,7 +321,6 @@ def __getitem__(self, table_name: str) -> SupportsReadableRelation: def __getattr__(self, table_name: str) -> SupportsReadableRelation: """access of table via property notation""" - return self.table(table_name) def dataset( diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py index c6019ecf2d..2244aba3e8 100644 --- a/tests/load/test_read_interfaces.py +++ b/tests/load/test_read_interfaces.py @@ -467,3 +467,49 @@ def test_standalone_dataset(populated_pipeline: Pipeline) -> None: ) assert dataset.schema.name == "some_other_schema" assert "other_table" in dataset.schema.tables + + +@pytest.mark.no_load +@pytest.mark.essential +@pytest.mark.parametrize( + "populated_pipeline", + configs, + indirect=True, + ids=lambda x: x.name, +) +def test_ibis_expression_relation(populated_pipeline: Pipeline) -> None: + # NOTE: we could generalize this with a context for certain deps + dataset = populated_pipeline._dataset() + total_records = _total_records(populated_pipeline) + + items_table = dataset.table("items") + double_items_table = dataset.table("double_items") + + # check full table access + df = items_table.df() + assert len(df.index) == total_records + + df = double_items_table.df() + assert len(df.index) == total_records + + # check limit + df = 
items_table.limit(5).df() + assert len(df.index) == 5 + + # check chained expression with join, column selection, order by and limit + joined_table = ( + items_table.join(double_items_table, items_table.id == double_items_table.id)[ + ["id", "double_id"] + ] + .order_by("id") + .limit(20) + ) + table = joined_table.fetchall() + assert len(table) == 20 + assert list(table[0]) == [0, 0] + assert list(table[5]) == [5, 10] + assert list(table[10]) == [10, 20] + + # check aggregate of first 20 items + agg_table = items_table.order_by("id").limit(20).aggregate(sum_id=items_table.id.sum()) + assert agg_table.fetchone()[0] == reduce(lambda a, b: a + b, range(20))