Skip to content

Commit

Permalink
casefold identifiers for ibis wrapper calss
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Nov 19, 2024
1 parent 92c6ef9 commit 7762611
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 20 deletions.
32 changes: 26 additions & 6 deletions dlt/destinations/ibis_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,26 @@ def _proxy_expression_method(self, method_name: str, *args: Any, **kwargs: Any)
method = getattr(self._expression, method_name)

# unwrap args and kwargs if they are relations
unwrapped_args = [
args = tuple(
arg._expression if isinstance(arg, ReadableIbisRelation) else arg for arg in args
]
unwrapped_kwargs = {
)
kwargs = {
k: v._expression if isinstance(v, ReadableIbisRelation) else v
for k, v in kwargs.items()
}

# casefold string params, this may break some methods..
args = tuple(
self.sql_client.capabilities.casefold_identifier(arg) if isinstance(arg, str) else arg
for arg in args
)
kwargs = {
k: self.sql_client.capabilities.casefold_identifier(v) if isinstance(v, str) else v
for k, v in kwargs.items()
}

# Call it with provided args
result = method(*unwrapped_args, **unwrapped_kwargs)
result = method(*args, **kwargs)

# If result is an ibis expression, wrap it in a new relation else return raw result
if isinstance(result, Expr):
Expand All @@ -184,14 +194,24 @@ def _proxy_expression_method(self, method_name: str, *args: Any, **kwargs: Any)

def __getattr__(self, name: str) -> Any:
"""Wrap all callable attributes of the expression"""
if not hasattr(self._expression, name):

attr = getattr(self._expression, name, None)

# try casefolded name for ibis columns access
if attr is None:
name = self.sql_client.capabilities.casefold_identifier(name)
attr = getattr(self._expression, name, None)

if attr is None:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
attr = getattr(self._expression, name)

if not callable(attr):
return attr
return partial(self._proxy_expression_method, name)

def __getitem__(self, columns: Union[str, Sequence[str]]) -> "SupportsReadableRelation":
# casefold column-names
columns = [self.sql_client.capabilities.casefold_identifier(col) for col in columns]
expr = self._expression[columns]
return self.__class__(readable_dataset=self._dataset, expression=expr)

Expand Down
23 changes: 9 additions & 14 deletions tests/load/test_read_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,10 +408,6 @@ def test_ibis_expression_relation(populated_pipeline: Pipeline) -> None:
items_table = dataset.table("items")
double_items_table = dataset.table("double_items")

map_i = lambda x: x
if populated_pipeline.destination.destination_type == "dlt.destinations.snowflake":
map_i = lambda x: x.upper()

# check full table access
df = items_table.df()
assert len(df.index) == total_records
Expand All @@ -425,11 +421,10 @@ def test_ibis_expression_relation(populated_pipeline: Pipeline) -> None:

# check chained expression with join, column selection, order by and limit
joined_table = (
items_table.join(
double_items_table,
getattr(items_table, map_i("id")) == getattr(double_items_table, map_i("id")),
)[[map_i("id"), map_i("double_id")]]
.order_by(map_i("id"))
items_table.join(double_items_table, items_table.id == double_items_table.id)[
["id", "double_id"]
]
.order_by("id")
.limit(20)
)
table = joined_table.fetchall()
Expand All @@ -439,13 +434,13 @@ def test_ibis_expression_relation(populated_pipeline: Pipeline) -> None:
assert list(table[10]) == [10, 20]

# check aggregate of first 20 items
agg_table = (
items_table.order_by(map_i("id"))
.limit(20)
.aggregate(sum_id=getattr(items_table, map_i("id")).sum())
)
agg_table = items_table.order_by("id").limit(20).aggregate(sum_id=items_table.id.sum())
assert agg_table.fetchone()[0] == reduce(lambda a, b: a + b, range(20))

# check filtering
filtered_table = items_table.filter(items_table.id < 10)
assert len(filtered_table.fetchall()) == 10

# # NOTE: here we test that dlt column type resolution still works
# # hints should also be preserved via computed reduced schema
# expected_decimal_precision = 10
Expand Down

0 comments on commit 7762611

Please sign in to comment.