feat: Improve array(), map(), and struct
fixes ibis-project#8289

This makes a lot of changes. It was hard for me to separate them out as I implemented them, but now that it's all hashed out I can try to split this up into separate commits if you want, though that might be hard in some cases.

One change is adding support for passing None to all of these constructors.
These use the new `ibis.null(<type>)` API to return an `ops.Literal(None, <type>)`.
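
For example (a sketch based on the new tests below):

    import ibis

    # Each returns a typed NULL literal, i.e. ibis.null(<type>):
    ibis.array(None, type="array<string>")
    ibis.map(None, type="map<string, string>")
    ibis.struct(None, type="struct<a: float64, b: float>")

    # Without a type there is nothing to infer, so these raise ValidationError:
    # ibis.array(None); ibis.map(None); ibis.struct(None)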

These constructors are now also idempotent: you can pass existing Expressions into array(), etc. The type argument now always has an effect, not just when passing in Python literals, so it effectively acts as a cast.
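
For example (a sketch mirroring test_array_factory below):

    import ibis

    a = ibis.array([1, 2, 3])
    a2 = ibis.array(a)                       # idempotent: accepts an existing expression
    s = ibis.array(a, type="array<string>")  # type acts like a cast: executes to ["1", "2", "3"]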

A big structural change is that ops.Array now has an optional "dtype" attribute, so when you pass in a 0-length sequence of values the op still knows what dtype it is.
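
For example (a sketch mirroring test_array_factory_empty below):

    import ibis
    import ibis.expr.datatypes as dt

    empty = ibis.array([], type="array<string>")
    assert empty.type() == dt.Array(value_type=dt.string)
    # With no values and no dtype there is nothing to infer,
    # so a bare ibis.array([]) raises ValidationError.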

Several of the backends were always broken here; they just weren't getting caught. I marked them as broken, and we can fix them in a followup.

You can test this locally with e.g.
`pytest -m <backend> -k factory ibis/backends/tests/test_array.py  ibis/backends/tests/test_map.py ibis/backends/tests/test_struct.py`

Also, fix a typing bug: map() can accept ArrayValues, not just ArrayColumns.
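
For example (a sketch from the updated test_map_contains_null below):

    import ibis

    # The values argument only needs to be an ArrayValue, not an ArrayColumn:
    expr = ibis.map(["a"], ibis.array([None], type="array<string>"))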

Also, fix executing Literal(None) on pandas and polars, and 0-length arrays on polars.

Also, fix converting dtypes on clickhouse: Structs should be converted to non-nullable dtypes.

Also, implement ops.StructColumn on pandas and dask.
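
For example, this now works on those backends (a sketch from test_struct_column below; `t` is the test suite's alltypes table):

    expr = t.select(s=ibis.struct(dict(a=t.string_col, b=1, c=t.bigint_col)))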
NickCrews committed May 9, 2024
1 parent 4707c44 commit 32e7636
Showing 13 changed files with 301 additions and 85 deletions.
8 changes: 7 additions & 1 deletion ibis/backends/dask/executor.py
@@ -155,11 +155,17 @@ def mapper(df, cases):
         return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype)
 
     @classmethod
-    def visit(cls, op: ops.Array, exprs):
+    def visit(cls, op: ops.Array, exprs, dtype):
         return cls.rowwise(
             lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object
         )
 
+    @classmethod
+    def visit(cls, op: ops.StructColumn, names, values):
+        return cls.rowwise(
+            lambda row: dict(zip(names, row)), values, name=op.name, dtype=object
+        )
+
     @classmethod
     def visit(cls, op: ops.ArrayConcat, arg):
         dtype = PandasType.from_ibis(op.dtype)
18 changes: 12 additions & 6 deletions ibis/backends/pandas/executor.py
@@ -49,12 +49,14 @@ def visit(cls, op: ops.Node, **kwargs):
 
     @classmethod
     def visit(cls, op: ops.Literal, value, dtype):
+        if value is None:
+            return None
         if dtype.is_interval():
-            value = pd.Timedelta(value, dtype.unit.short)
-        elif dtype.is_array():
-            value = np.array(value)
-        elif dtype.is_date():
-            value = pd.Timestamp(value, tz="UTC").tz_localize(None)
+            return pd.Timedelta(value, dtype.unit.short)
+        if dtype.is_array():
+            return np.array(value)
+        if dtype.is_date():
+            return pd.Timestamp(value, tz="UTC").tz_localize(None)
         return value
 
     @classmethod
@@ -220,9 +222,13 @@ def visit(cls, op: ops.FindInSet, needle, values):
         return pd.Series(result, name=op.name)
 
     @classmethod
-    def visit(cls, op: ops.Array, exprs):
+    def visit(cls, op: ops.Array, exprs, dtype):
         return cls.rowwise(lambda row: np.array(row, dtype=object), exprs)
 
+    @classmethod
+    def visit(cls, op: ops.StructColumn, names, values):
+        return cls.rowwise(lambda row: dict(zip(names, row)), values)
+
     @classmethod
     def visit(cls, op: ops.ArrayConcat, arg):
         return cls.rowwise(lambda row: np.concatenate(row.values), arg)
13 changes: 8 additions & 5 deletions ibis/backends/polars/compiler.py
@@ -86,10 +86,12 @@ def _make_duration(value, dtype):
 def literal(op, **_):
     value = op.value
     dtype = op.dtype
+    typ = PolarsType.from_ibis(dtype)
 
+    if value is None:
+        return pl.lit(None, dtype=typ)
     if dtype.is_array():
         value = pl.Series("", value)
-        typ = PolarsType.from_ibis(dtype)
         val = pl.lit(value, dtype=typ)
         return val.implode()
     elif dtype.is_struct():
@@ -105,7 +107,6 @@ def literal(op, **_):
     elif dtype.is_binary():
         return pl.lit(value)
     else:
-        typ = PolarsType.from_ibis(dtype)
         return pl.lit(op.value, dtype=typ)
@@ -980,9 +981,11 @@ def array_concat(op, **kw):
 
 
 @translate.register(ops.Array)
-def array_column(op, **kw):
-    cols = [translate(col, **kw) for col in op.exprs]
-    return pl.concat_list(cols)
+def array_literal(op, **kw):
+    if len(op.exprs) > 0:
+        return pl.concat_list([translate(col, **kw) for col in op.exprs])
+    else:
+        return pl.lit([], dtype=PolarsType.from_ibis(op.dtype))
 
 
 @translate.register(ops.ArrayCollect)
7 changes: 5 additions & 2 deletions ibis/backends/sql/compiler.py
@@ -970,8 +970,11 @@ def visit_InSubquery(self, op, *, rel, needle):
         query = sg.select(STAR).from_(query)
         return needle.isin(query=query)
 
-    def visit_Array(self, op, *, exprs):
-        return self.f.array(*exprs)
+    def visit_Array(self, op, *, exprs, dtype):
+        result = self.f.array(*exprs)
+        if len(exprs) == 0:
+            return self.cast(result, dtype)
+        return result
 
     def visit_StructColumn(self, op, *, names, values):
         return sge.Struct.from_arg_list(
6 changes: 4 additions & 2 deletions ibis/backends/sql/datatypes.py
@@ -1007,8 +1007,10 @@ class ClickHouseType(SqlglotType):
     def from_ibis(cls, dtype: dt.DataType) -> sge.DataType:
         """Convert a sqlglot type to an ibis type."""
         typ = super().from_ibis(dtype)
-        if dtype.nullable and not (dtype.is_map() or dtype.is_array()):
-            # map cannot be nullable in clickhouse
+        # nested types cannot be nullable in clickhouse
+        if dtype.nullable and not (
+            dtype.is_map() or dtype.is_array() or dtype.is_struct()
+        ):
             return sge.DataType(this=typecode.NULLABLE, expressions=[typ])
         else:
             return typ
38 changes: 38 additions & 0 deletions ibis/backends/tests/test_array.py
@@ -30,6 +30,7 @@
     PySparkAnalysisException,
     TrinoUserError,
 )
+from ibis.common.annotations import ValidationError
 from ibis.common.collections import frozendict
 
 pytestmark = [
@@ -71,6 +72,43 @@
 # list.
 
 
+def test_array_factory(con):
+    a = ibis.array([1, 2, 3])
+    assert con.execute(a) == [1, 2, 3]
+
+    a2 = ibis.array(a)
+    assert con.execute(a2) == [1, 2, 3]
+
+    typed = ibis.array([1, 2, 3], type="array<string>")
+    assert con.execute(typed) == ["1", "2", "3"]
+
+    typed2 = ibis.array(a, type="array<string>")
+    assert con.execute(typed2) == ["1", "2", "3"]
+
+
+@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError)
+def test_array_factory_empty(con):
+    with pytest.raises(ValidationError):
+        ibis.array([])
+
+    empty_typed = ibis.array([], type="array<string>")
+    assert empty_typed.type() == dt.Array(value_type=dt.string)
+    assert con.execute(empty_typed) == []
+
+
+@pytest.mark.notyet(
+    "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
+)
+def test_array_factory_null(con):
+    with pytest.raises(ValidationError):
+        ibis.array(None)
+    with pytest.raises(ValidationError):
+        ibis.array(None, type="int64")
+    none_typed = ibis.array(None, type="array<string>")
+    assert none_typed.type() == dt.Array(value_type=dt.string)
+    assert con.execute(none_typed) is None
+
+
 def test_array_column(backend, alltypes, df):
     expr = ibis.array(
         [alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)]
13 changes: 12 additions & 1 deletion ibis/backends/tests/test_map.py
@@ -11,6 +11,7 @@
 import ibis.common.exceptions as exc
 import ibis.expr.datatypes as dt
 from ibis.backends.tests.errors import PsycoPg2InternalError, Py4JJavaError
+from ibis.common.annotations import ValidationError
 
 pytestmark = [
     pytest.mark.never(
@@ -121,6 +122,16 @@ def test_map_values_nulls(con, map):
     assert con.execute(map.values()) is None
 
 
+def test_map_factory(con):
+    assert con.execute(ibis.map(None, type="map<string, string>")) is None
+    with pytest.raises(ValidationError):
+        ibis.map(None)
+    with pytest.raises(ValidationError):
+        ibis.map(None, type="array<string>")
+    with pytest.raises(ValidationError):
+        ibis.map({1: 2}, type="array<string>")
+
+
 @pytest.mark.notimpl(
     ["risingwave"],
     raises=PsycoPg2InternalError,
@@ -669,6 +680,6 @@ def test_map_keys_unnest(backend):
 
 @mark_notimpl_risingwave_hstore
 def test_map_contains_null(con):
-    expr = ibis.map(["a"], ibis.literal([None], type="array<string>"))
+    expr = ibis.map(["a"], ibis.array([None], type="array<string>"))
     assert con.execute(expr.contains("a"))
     assert not con.execute(expr.contains("b"))
59 changes: 51 additions & 8 deletions ibis/backends/tests/test_struct.py
@@ -21,14 +21,58 @@
     Py4JJavaError,
     PySparkAnalysisException,
 )
-from ibis.common.exceptions import IbisError, OperationNotDefinedError
+from ibis.common.annotations import ValidationError
+from ibis.common.exceptions import IbisError
 
 pytestmark = [
     pytest.mark.never(["mysql", "sqlite", "mssql"], reason="No struct support"),
     pytest.mark.notyet(["impala"]),
     pytest.mark.notimpl(["datafusion", "druid", "oracle", "exasol"]),
 ]
 
+mark_notimpl_postgres_literals = pytest.mark.notimpl(
+    "postgres", reason="struct literals not implemented", raises=PsycoPg2SyntaxError
+)
+
+
+@pytest.mark.broken("postgres", reason="JSON handling is buggy")
+def test_struct_factory(con):
+    s = ibis.struct({"a": 1, "b": 2})
+    assert con.execute(s) == {"a": 1, "b": 2}
+
+    s2 = ibis.struct(s)
+    assert con.execute(s2) == {"a": 1, "b": 2}
+
+    typed = ibis.struct({"a": 1, "b": 2}, type="struct<a: string, b: string>")
+    assert con.execute(typed) == {"a": "1", "b": "2"}
+
+    typed2 = ibis.struct(s, type="struct<a: string, b: string>")
+    assert con.execute(typed2) == {"a": "1", "b": "2"}
+
+
+def test_struct_factory_empty():
+    with pytest.raises(ValidationError):
+        ibis.struct({})
+    with pytest.raises(ValidationError):
+        ibis.struct({}, type="struct<>")
+    with pytest.raises(ValidationError):
+        ibis.struct({}, type="struct<a: float64, b: float64>")
+
+
+@mark_notimpl_postgres_literals
+@pytest.mark.notyet(
+    "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
+)
+@pytest.mark.broken(
+    "polars", reason=r"pl.lit(None, type='struct<a: int64>') gives {'a': None}"
+)
+def test_struct_factory_null(con):
+    with pytest.raises(ValidationError):
+        ibis.struct(None)
+    none_typed = ibis.struct(None, type="struct<a: float64, b: float>")
+    assert none_typed.type() == dt.Struct(fields={"a": dt.float64, "b": dt.float64})
+    assert con.execute(none_typed) is None
+
 
 @pytest.mark.notimpl(["dask"])
 @pytest.mark.parametrize(
@@ -79,6 +123,9 @@ def test_all_fields(struct, struct_df):
 
 @pytest.mark.notimpl(["postgres", "risingwave"])
 @pytest.mark.parametrize("field", ["a", "b", "c"])
+@pytest.mark.notyet(
+    ["flink"], reason="flink doesn't support creating struct columns from literals"
+)
 def test_literal(backend, con, field):
     query = _STRUCT_LITERAL[field]
     dtype = query.type().to_pandas()
@@ -88,7 +135,7 @@ def test_literal(backend, con, field):
     backend.assert_series_equal(result, expected.astype(dtype))
 
 
-@pytest.mark.notimpl(["postgres"])
+@mark_notimpl_postgres_literals
 @pytest.mark.parametrize("field", ["a", "b", "c"])
 @pytest.mark.notyet(
     ["clickhouse"], reason="clickhouse doesn't support nullable nested types"
 )
@@ -101,7 +148,7 @@ def test_null_literal(backend, con, field):
     backend.assert_series_equal(result, expected)
 
 
-@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave"])
+@pytest.mark.notimpl(["postgres", "risingwave"])
 def test_struct_column(alltypes, df):
     t = alltypes
     expr = t.select(s=ibis.struct(dict(a=t.string_col, b=1, c=t.bigint_col)))
@@ -113,7 +160,7 @@ def test_struct_column(alltypes, df):
     tm.assert_frame_equal(result, expected)
 
 
-@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave", "polars"])
+@pytest.mark.notimpl(["postgres", "risingwave", "polars"])
 @pytest.mark.notyet(
     ["flink"], reason="flink doesn't support creating struct columns from collect"
 )
@@ -138,9 +185,6 @@ def test_collect_into_struct(alltypes):
     assert len(val.loc[result.group == "1"].iat[0]["key"]) == 730
 
 
-@pytest.mark.notimpl(
-    ["postgres"], reason="struct literals not implemented", raises=PsycoPg2SyntaxError
-)
 @pytest.mark.notimpl(
     ["risingwave"],
     reason="struct literals not implemented",
@@ -253,7 +297,6 @@ def test_keyword_fields(con, nullable):
     raises=PolarsColumnNotFoundError,
     reason="doesn't seem to support IN-style subqueries on structs",
 )
-@pytest.mark.notimpl(["pandas", "dask"], raises=OperationNotDefinedError)
 @pytest.mark.xfail_version(
     pyspark=["pyspark<3.5"],
     reason="requires pyspark 3.5",
 )
19 changes: 14 additions & 5 deletions ibis/expr/operations/arrays.py
@@ -7,23 +7,32 @@
 import ibis.expr.datashape as ds
 import ibis.expr.datatypes as dt
 import ibis.expr.rules as rlz
-from ibis.common.annotations import attribute
+from ibis.common.annotations import ValidationError, attribute
 from ibis.common.typing import VarTuple  # noqa: TCH001
 from ibis.expr.operations.core import Unary, Value
 
 
 @public
 class Array(Value):
     exprs: VarTuple[Value]
+    dtype: Optional[dt.Array] = None
+
+    def __init__(self, exprs, dtype: dt.Array | None = None):
+        if len(exprs) == 0:
+            if dtype is None:
+                raise ValidationError("If values is empty, dtype must be provided")
+            if not isinstance(dtype, dt.Array):
+                raise ValidationError(f"dtype must be an array, got {dtype}")
+        elif dtype is None:
+            dtype = dt.Array(rlz.highest_precedence_dtype(exprs))
+        super().__init__(exprs=exprs, dtype=dtype)
 
     @attribute
     def shape(self):
+        if len(self.exprs) == 0:
+            return ds.scalar
         return rlz.highest_precedence_shape(self.exprs)
 
-    @attribute
-    def dtype(self):
-        return dt.Array(rlz.highest_precedence_dtype(self.exprs))
-
 
 @public
 class ArrayLength(Unary):
9 changes: 5 additions & 4 deletions ibis/expr/operations/structs.py
@@ -34,7 +34,9 @@ class StructColumn(Value):
 
     shape = rlz.shape_like("values")
 
-    def __init__(self, names, values):
+    def __init__(self, names: VarTuple[str], values: VarTuple[Value]):
+        if len(names) == 0:
+            raise ValidationError("StructColumn must have at least one field")
         if len(names) != len(values):
             raise ValidationError(
                 f"Length of names ({len(names)}) does not match length of "
@@ -43,6 +45,5 @@ def __init__(self, names, values):
         super().__init__(names=names, values=values)
 
     @attribute
-    def dtype(self) -> dt.DataType:
-        dtypes = (value.dtype for value in self.values)
-        return dt.Struct.from_tuples(zip(self.names, dtypes))
+    def dtype(self) -> dt.Struct:
+        return dt.Struct.from_tuples(zip(self.names, (v.dtype for v in self.values)))