Skip to content

Commit

Permalink
add timestamps tests to fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 committed Dec 1, 2023
1 parent e3d8fd0 commit 0d6e299
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 14 deletions.
6 changes: 0 additions & 6 deletions src/daft-core/src/datatypes/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,6 @@ impl DataType {
(s, o) if s.is_physical() && o.is_physical() => {
Ok((Boolean, None, try_physical_supertype(s, o)?))
}
// To maintain existing behaviour. TODO: cleanup
(Date, o) | (o, Date) if o.is_physical() && o.clone() != Boolean => Ok((
Boolean,
None,
try_physical_supertype(&Date.to_physical(), o)?,
)),
(Timestamp(..) | Date, Timestamp(..) | Date) => {
let intermediate_type = try_get_supertype(self, other)?;
let pt = intermediate_type.to_physical();
Expand Down
77 changes: 69 additions & 8 deletions tests/expressions/typing/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import itertools
import sys

import pytz

if sys.version_info < (3, 8):
pass
else:
Expand Down Expand Up @@ -33,12 +35,6 @@
(DataType.bool(), pa.array([True, False, None], type=pa.bool_())),
(DataType.null(), pa.array([None, None, None], type=pa.null())),
(DataType.binary(), pa.array([b"1", b"2", None], type=pa.binary())),
(DataType.date(), pa.array([datetime.date(2021, 1, 1), datetime.date(2021, 1, 2), None], type=pa.date32())),
# TODO(jay): Some of the fixtures are broken/become very complicated when testing against timestamps
# (
# DataType.timestamp(TimeUnit.ms()),
# pa.array([datetime.datetime(2021, 1, 1), datetime.datetime(2021, 1, 2), None], type=pa.timestamp("ms")),
# ),
]

ALL_DATATYPES_BINARY_PAIRS = list(itertools.product(ALL_DTYPES, repeat=2))
Expand All @@ -59,10 +55,75 @@ def binary_data_fixture(request) -> tuple[Series, Series]:
return (s1, s2)


ALL_TEMPORAL_DTYPES = [
(DataType.date(), pa.array([datetime.date(2021, 1, 1), datetime.date(2021, 1, 2), None], type=pa.date32())),
*[
(
DataType.timestamp(unit),
pa.array([datetime.datetime(2021, 1, 1), datetime.datetime(2021, 1, 2), None], type=pa.timestamp(unit)),
)
for unit in ["ns", "us", "ms"]
],
*[
(
DataType.timestamp(unit, "US/Eastern"),
pa.array(
[
datetime.datetime(2021, 1, 1).astimezone(pytz.timezone("US/Eastern")),
datetime.datetime(2021, 1, 2).astimezone(pytz.timezone("US/Eastern")),
None,
],
type=pa.timestamp(unit, "US/Eastern"),
),
)
for unit in ["ns", "us", "ms"]
],
*[
(
DataType.timestamp(unit, "Africa/Accra"),
pa.array(
[
datetime.datetime(2021, 1, 1).astimezone(pytz.timezone("Africa/Accra")),
datetime.datetime(2021, 1, 2).astimezone(pytz.timezone("Africa/Accra")),
None,
],
type=pa.timestamp(unit, "Africa/Accra"),
),
)
for unit in ["ns", "us", "ms"]
],
]

ALL_TEMPORAL_DATATYPES_BINARY_PAIRS = [
((dt1, data1), (dt2, data2))
for (dt1, data1), (dt2, data2) in itertools.product(ALL_TEMPORAL_DTYPES, repeat=2)
if not (
pa.types.is_timestamp(data1.type)
and pa.types.is_timestamp(data2.type)
and (data1.type.tz is None) ^ (data2.type.tz is None)
)
]


@pytest.fixture(
scope="module",
params=ALL_TEMPORAL_DATATYPES_BINARY_PAIRS,
ids=[f"{dt1}-{dt2}" for (dt1, _), (dt2, _) in ALL_TEMPORAL_DATATYPES_BINARY_PAIRS],
)
def binary_temporal_data_fixture(request) -> tuple[Series, Series]:
"""Returns binary permutation of Series' of all DataType pairs"""
(dt1, data1), (dt2, data2) = request.param
s1 = Series.from_arrow(data1, name="lhs")
assert s1.datatype() == dt1
s2 = Series.from_arrow(data2, name="rhs")
assert s2.datatype() == dt2
return (s1, s2)


@pytest.fixture(
scope="module",
params=ALL_DTYPES,
ids=[f"{dt}" for (dt, _) in ALL_DTYPES],
params=ALL_DTYPES + ALL_TEMPORAL_DTYPES,
ids=[f"{dt}" for (dt, _) in ALL_DTYPES + ALL_TEMPORAL_DTYPES],
)
def unary_data_fixture(request) -> Series:
"""Returns unary permutation of Series' of all DataType pairs"""
Expand Down
11 changes: 11 additions & 0 deletions tests/expressions/typing/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,14 @@ def test_comparable(binary_data_fixture, op):
run_kernel=lambda: op(lhs, rhs),
resolvable=comparable_type_validation(lhs.datatype(), rhs.datatype()),
)


@pytest.mark.parametrize("op", [ops.eq, ops.ne, ops.lt, ops.le, ops.gt, ops.ge])
def test_temporal_comparable(binary_temporal_data_fixture, op):
lhs, rhs = binary_temporal_data_fixture
assert_typing_resolve_vs_runtime_behavior(
data=binary_temporal_data_fixture,
expr=op(col(lhs.name()), col(rhs.name())),
run_kernel=lambda: op(lhs, rhs),
resolvable=comparable_type_validation(lhs.datatype(), rhs.datatype()),
)

0 comments on commit 0d6e299

Please sign in to comment.