Skip to content

Commit

Permalink
[FEAT] Add struct get syntactic sugar (#2367)
Browse files Browse the repository at this point in the history
Adds the ability to query struct and map fields by using the dot syntax,
such as `col("a.b")` turning into `col("a").struct.get("b")`.

This PR also includes a minor refactor of agg expression checking and
extraction, just moving it out of the builder and into the
`resolve_expr` and `resolve_aggexpr` functions that also deal with the
syntactic sugar. I changed this since we were talking about how brittle
it would be to do the syntactic sugar conversion in the builder and
realized that it applies to aggregations too.
  • Loading branch information
kevinzwang authored Jun 28, 2024
1 parent 7422f2f commit 50d9b80
Show file tree
Hide file tree
Showing 22 changed files with 479 additions and 228 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1151,6 +1151,7 @@ def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ...
def decimal_lit(sign: bool, digits: tuple[int, ...], exp: int) -> PyExpr: ...
def series_lit(item: PySeries) -> PyExpr: ...
def udf(func: Callable, expressions: list[PyExpr], return_dtype: PyDataType) -> PyExpr: ...
def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ...

class PySeries:
@staticmethod
Expand Down
12 changes: 3 additions & 9 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,7 @@
from daft.api_annotations import DataframePublicAPI
from daft.context import get_context
from daft.convert import InputListType
from daft.daft import (
FileFormat,
IOConfig,
JoinStrategy,
JoinType,
ResourceRequest,
)
from daft.daft import FileFormat, IOConfig, JoinStrategy, JoinType, ResourceRequest, resolve_expr
from daft.dataframe.preview import DataFramePreview
from daft.datatype import DataType
from daft.errors import ExpressionTypeError
Expand Down Expand Up @@ -769,8 +763,8 @@ def __getitem__(self, item: Union[slice, int, str, Iterable[Union[str, int]]]) -
return result
elif isinstance(item, str):
schema = self._builder.schema()
field = schema[item]
return col(field.name)
expr, _ = resolve_expr(col(item)._expr, schema._schema)
return Expression._from_pyexpr(expr)
elif isinstance(item, Iterable):
schema = self._builder.schema()

Expand Down
59 changes: 40 additions & 19 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1735,27 +1735,48 @@ def get(self, key: Expression) -> Expression:
"""Retrieves the value for a key in a map column
Example:
>>> import pyarrrow as pa
>>> import pyarrow as pa
>>> import daft
>>> pa_array = pa.array([[(1, 2)],[],[(2,1)]], type=pa.map_(pa.int64(), pa.int64()))
>>> pa_array = pa.array([[("a", 1)],[],[("b",2)]], type=pa.map_(pa.string(), pa.int64()))
>>> df = daft.from_arrow(pa.table({"map_col": pa_array}))
>>> df = df.with_column("1", df["map_col"].map.get(1))
>>> df.show()
╭───────────────────────────────────────┬───────╮
│ map_col ┆ 1 │
│ --- ┆ --- │
│ Map[Struct[key: Int64, value: Int64]] ┆ Int64 │
╞═══════════════════════════════════════╪═══════╡
│ [{key: 1, ┆ 2 │
│ value: 2, ┆ │
│ }] ┆ │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ [] ┆ None │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ [{key: 2, ┆ None │
│ value: 1, ┆ │
│ }] ┆ │
╰───────────────────────────────────────┴───────╯
>>> df1 = df.with_column("a", df["map_col"].map.get("a"))
>>> df1.show()
╭───────────┬───────╮
│ map_col ┆ a │
│ --- ┆ --- │
│ Map[Utf8] ┆ Int64 │
╞═══════════╪═══════╡
│ [{key: a, ┆ 1 │
│ value: 1, ┆ │
│ }] ┆ │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ [] ┆ None │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ [{key: b, ┆ None │
│ value: 2, ┆ │
│ }] ┆ │
╰───────────┴───────╯
(Showing first 3 of 3 rows)
>>>
>>> # you may also use the "column.key" syntax to get map values
>>> df2 = df.with_column("b", df["map_col.b"])
>>> df2.show()
╭───────────┬───────╮
│ map_col ┆ b │
│ --- ┆ --- │
│ Map[Utf8] ┆ Int64 │
╞═══════════╪═══════╡
│ [{key: a, ┆ None │
│ value: 1, ┆ │
│ }] ┆ │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ [] ┆ None │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ [{key: b, ┆ 2 │
│ value: 2, ┆ │
│ }] ┆ │
╰───────────┴───────╯
(Showing first 3 of 3 rows)
Args:
key: the key to retrieve
Expand Down
1 change: 1 addition & 0 deletions src/daft-dsl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ daft-core = {path = "../daft-core", default-features = false}
daft-io = {path = "../daft-io", default-features = false}
daft-sketch = {path = "../daft-sketch", default-features = false}
indexmap = {workspace = true}
itertools = {workspace = true}
pyo3 = {workspace = true, optional = true}
pyo3-log = {workspace = true, optional = true}
serde = {workspace = true}
Expand Down
Loading

0 comments on commit 50d9b80

Please sign in to comment.