Skip to content

Commit

Permalink
Fix lazy binding generator bug, add asymmetric multikey join test.
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkzinzow committed Dec 22, 2023
1 parent 47e829d commit 62f87e1
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
8 changes: 5 additions & 3 deletions daft/execution/physical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,9 @@ def sort_merge_join(
(left_on, left_source_materializations, left_boundaries),
(right_on, right_source_materializations, right_boundaries),
]:
range_fanout_plan = (
# NOTE: We need to give reduce() an iter(list), since giving it a generator would result in lazy
# binding in this loop.
range_fanout_plan = [
PartitionTaskBuilder[PartitionT](
inputs=[boundaries.partition(), source.partition()],
partial_metadatas=[boundaries.partition_metadata(), source.partition_metadata()],
Expand All @@ -640,12 +642,12 @@ def sort_merge_join(
),
)
for source in consume_deque(source_materializations)
)
]

# Execute a sorting reduce on it.
sorted_plans.append(
reduce(
fanout_plan=range_fanout_plan,
fanout_plan=iter(range_fanout_plan),
reduce_instruction=execution_step.ReduceMergeAndSort(
sort_by=on,
descending=descending,
Expand Down
40 changes: 40 additions & 0 deletions tests/dataframe/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,46 @@ def test_inner_join_multikey(join_strategy, make_df, repartition_nparts):
)


@pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
@pytest.mark.parametrize("join_strategy", [None, "hash", "sort_merge", "broadcast"])
def test_inner_join_asymmetric_multikey(join_strategy, make_df, repartition_nparts):
    """Inner join on two key columns where nulls fall in different rows on each side.

    Only the (1, "foo1") key pair is present in both frames, so exactly one
    output row is expected; rows containing a null in either key column must
    not participate in the match.
    """
    left = make_df(
        {
            "left_id": [1, None, None],
            "left_id2": ["foo1", "foo2", None],
            "values_left": ["a1", "b1", "c1"],
        },
        repartition=repartition_nparts,
    )
    right = make_df(
        {
            "right_id": [None, None, 1],
            "right_id2": ["foo2", None, "foo1"],
            "values_right": ["a2", "b2", "c2"],
        },
        repartition=repartition_nparts,
    )

    joined = left.join(
        right,
        left_on=["left_id", "left_id2"],
        right_on=["right_id", "right_id2"],
        how="inner",
        strategy=join_strategy,
    )

    # The sole surviving key pair is (1, "foo1"): left row 0 joined to right row 2.
    expected = {
        "left_id": [1],
        "left_id2": ["foo1"],
        "values_left": ["a1"],
        "right_id": [1],
        "right_id2": ["foo1"],
        "values_right": ["c2"],
    }

    actual_table = pa.Table.from_pydict(joined.to_pydict())
    expected_table = pa.Table.from_pydict(expected)
    # Sort both sides before comparing so partitioning order doesn't matter.
    assert sort_arrow_table(actual_table, "left_id") == sort_arrow_table(expected_table, "left_id")


@pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
@pytest.mark.parametrize("join_strategy", [None, "hash", "sort_merge", "broadcast"])
def test_inner_join_all_null(join_strategy, make_df, repartition_nparts):
Expand Down

0 comments on commit 62f87e1

Please sign in to comment.