# [FEAT] [Join Optimizations] Add sort-merge join. (#1755)
This PR adds a sort-merge join implementation as a new join strategy,
where each side of the join is sorted on the join keys, and then the
sorted tables are merged. The sort-merge join strategy is chosen
automatically by the query planner if it is expected to be faster than
the hash join and the broadcast join.

Similar to Spark's ability to specify a join algorithm hint, this PR
also exposes a new optional `strategy` arg for `df.join()`, which allows
users (and our Python-level tests) to manually specify a join algorithm;
currently `"hash"`, `"sort_merge"`, and `"broadcast"` are supported,
with the default `None` resulting in the query planner choosing a join
algorithm automatically.

```python
df = left.join(right, on="foo", strategy="sort_merge")
```

Closes #1776 

## Query Planning

The query planner chooses the sort-merge join as its join strategy if
the larger side of the join is range-partitioned, or if the smaller side
of the join is range-partitioned and the larger side is not partitioned.
In the future, we will want to do a sort-merge join:
1. If a downstream operation requires the table to be sorted on the join
key.
2. If neither side of the join is partitioned AND we determine that the
sort-merge join is faster on unpartitioned data than the hash join
(pending benchmarking).

**NOTE:** We currently only support sort-merge joins on primitive join
keys; see the TODOs below.
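
To make the heuristic concrete, here is a minimal Python sketch of the strategy selection described above. The `SideInfo` type and its fields are illustrative stand-ins, not Daft types; the real logic lives in the Rust query planner.

```python
from dataclasses import dataclass

@dataclass
class SideInfo:
    """Illustrative summary of one join input (not a real Daft type)."""

    size_bytes: int
    range_partitioned_on_keys: bool  # range-partitioned on the join keys?
    partitioned: bool  # partitioned at all?

def choose_join_strategy(
    left: SideInfo, right: SideInfo, broadcast_threshold: int = 10 * 1024 * 1024
) -> str:
    smaller, larger = sorted((left, right), key=lambda s: s.size_bytes)
    if smaller.size_bytes <= broadcast_threshold:
        return "broadcast"  # small side fits under the broadcast join threshold (10 MiB default)
    if larger.range_partitioned_on_keys:
        return "sort_merge"  # larger side is already range-partitioned on the join keys
    if smaller.range_partitioned_on_keys and not larger.partitioned:
        return "sort_merge"  # smaller side is range-partitioned, larger side is unpartitioned
    return "hash"  # otherwise fall back to the default hash join
```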

## Query Scheduling

### Full sort-merge join

All partitions on both sides of the join are materialized, after which
we compute sort boundaries from samples of both sides. These combined
sort boundaries are used to sort each side of the join. Once each side
is sorted with the same sort boundaries and the same number of
partitions, we perform a merge join, which merges overlapping pairs of
partitions.
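
The flow above, as a toy sketch on lists-of-dicts partitions with a single join key (illustrative only; the actual implementation operates on Daft tables inside the physical plan scheduler):

```python
import bisect

def merge_join(left_rows, right_rows, key):
    """Toy inner merge-join of two key-sorted lists of dict rows."""
    out, j = [], 0
    for lrow in left_rows:
        while j < len(right_rows) and right_rows[j][key] < lrow[key]:
            j += 1  # advance past smaller right keys
        k = j  # leave j in place so duplicate left keys re-match the same run
        while k < len(right_rows) and right_rows[k][key] == lrow[key]:
            out.append({**lrow, **right_rows[k]})
            k += 1
    return out

def full_sort_merge_join(left_parts, right_parts, key, num_partitions):
    # 1. Compute sort boundaries from both sides (here we "sample" every key;
    #    Daft samples a small number of rows per partition).
    samples = sorted(row[key] for part in left_parts + right_parts for row in part)
    step = max(1, len(samples) // num_partitions)
    boundaries = samples[step::step][: num_partitions - 1]

    # 2. Range-partition and sort each side with the same shared boundaries.
    def sort_side(parts):
        buckets = [[] for _ in range(num_partitions)]
        for part in parts:
            for row in part:
                buckets[bisect.bisect_left(boundaries, row[key])].append(row)
        return [sorted(b, key=lambda r: r[key]) for b in buckets]

    left_sorted, right_sorted = sort_side(left_parts), sort_side(right_parts)
    # 3. With identical boundaries and partition counts, partition i on one side
    #    can only overlap partition i on the other, so pairs can simply be zipped.
    return [merge_join(l, r, key) for l, r in zip(left_sorted, right_sorted)]
```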

### Sort-eliding merge-join for presorted (but unaligned) dataframes

A merge join is performed directly on the two sorted sides of the join,
merging overlapping pairs of partitions. The partition boundaries of the
two sides do not need to line up: partition overlap is determined in the
scheduler, and merge-join tasks are emitted for overlapping partition
pairs.
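
A sketch of how the scheduler could enumerate overlapping pairs from per-partition boundary metadata (the `(min_key, max_key)` representation is an assumption for illustration, not Daft's actual API):

```python
def overlapping_pairs(left_bounds, right_bounds):
    """Yield (left_idx, right_idx) for every pair of partitions whose key ranges
    intersect; each *_bounds is a sorted list of (min_key, max_key) tuples."""
    i = j = 0
    while i < len(left_bounds) and j < len(right_bounds):
        (lmin, lmax), (rmin, rmax) = left_bounds[i], right_bounds[j]
        if lmax >= rmin and rmax >= lmin:  # the two key ranges intersect
            yield i, j
        # Advance whichever partition ends first; it cannot overlap anything later.
        if lmax < rmax:
            i += 1
        else:
            j += 1

# Left partitions [(0, 4), (5, 9)] and right partitions [(3, 6), (7, 12)] yield
# merge-join tasks for the pairs (0, 0), (1, 0), and (1, 1).
assert list(overlapping_pairs([(0, 4), (5, 9)], [(3, 6), (7, 12)])) == [(0, 0), (1, 0), (1, 1)]
```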

## TODOs

- [x] Test coverage
- [x] Fix dispatch logic for adjacent partitions with intersecting
bounds (broadcast on matching bounds instead of naive zip).
- [x] Add support for sort-merge joins on sorted dataframes with
unaligned partition boundaries.
- [ ] Add support for non-primitive join keys.
- [ ] Benchmarking to validate + tweak the heuristics used by the query
planner to choose whether to use the sort-merge join.
clarkzinzow authored Jan 31, 2024
1 parent 6cda37a commit 2eefca6
Showing 34 changed files with 2,119 additions and 154 deletions.
### daft/context.py (5 additions, 0 deletions)

```diff
@@ -206,6 +206,7 @@ def set_execution_config(
     merge_scan_tasks_min_size_bytes: int | None = None,
     merge_scan_tasks_max_size_bytes: int | None = None,
     broadcast_join_size_bytes_threshold: int | None = None,
+    sort_merge_join_sort_with_aligned_boundaries: bool | None = None,
     sample_size_for_sort: int | None = None,
     num_preview_rows: int | None = None,
     parquet_target_filesize: int | None = None,
@@ -228,6 +229,9 @@
             fewer partitions. (Defaults to 512 MiB)
         broadcast_join_size_bytes_threshold: If one side of a join is smaller than this threshold, a broadcast join will be used.
             Default is 10 MiB.
+        sort_merge_join_sort_with_aligned_boundaries: Whether to use a specialized algorithm for sorting both sides of a
+            sort-merge join such that they have aligned boundaries. This can lead to a faster merge-join at the cost of
+            more skewed sorted join inputs, increasing the risk of OOMs.
         sample_size_for_sort: number of elements to sample from each partition when running sort,
             Default is 20.
         num_preview_rows: number of rows to when showing a dataframe preview,
@@ -245,6 +249,7 @@
         merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes,
         merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes,
         broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold,
+        sort_merge_join_sort_with_aligned_boundaries=sort_merge_join_sort_with_aligned_boundaries,
         sample_size_for_sort=sample_size_for_sort,
         num_preview_rows=num_preview_rows,
         parquet_target_filesize=parquet_target_filesize,
```
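
For reference, a usage sketch of the new config flag (importing `set_execution_config` from `daft.context`, as defined above):

```python
from daft.context import set_execution_config

# Opt in to sorting both join sides with aligned partition boundaries: a faster
# merge phase, at the cost of potentially skewed (OOM-prone) sorted partitions.
set_execution_config(sort_merge_join_sort_with_aligned_boundaries=True)
```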
### daft/daft.pyi (40 additions, 4 deletions)

```diff
@@ -89,11 +89,32 @@ class JoinType(Enum):
         Args:
             join_type: String representation of the join type. This is the same as the enum
-                attribute name, e.g. ``JoinType.from_join_type_str("Inner")`` would
+                attribute name (but snake-case), e.g. ``JoinType.from_join_type_str("inner")`` would
                 return ``JoinType.Inner``.
         """
         ...
 
+class JoinStrategy(Enum):
+    """
+    Join strategy (algorithm) to use.
+    """
+
+    Hash: int
+    SortMerge: int
+    Broadcast: int
+
+    @staticmethod
+    def from_join_strategy_str(join_strategy: str) -> JoinStrategy:
+        """
+        Create a JoinStrategy from its string representation.
+
+        Args:
+            join_strategy: String representation of the join strategy. This is the same as the enum
+                attribute name (but snake-case), e.g. ``JoinStrategy.from_join_strategy_str("sort_merge")`` would
+                return ``JoinStrategy.SortMerge``.
+        """
+        ...
+
 class CountMode(Enum):
     """
     Supported count modes for Daft's count aggregation.
@@ -967,7 +988,10 @@ class PyTable:
     def sort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PyTable: ...
     def argsort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PySeries: ...
     def agg(self, to_agg: list[PyExpr], group_by: list[PyExpr]) -> PyTable: ...
-    def join(self, right: PyTable, left_on: list[PyExpr], right_on: list[PyExpr]) -> PyTable: ...
+    def hash_join(self, right: PyTable, left_on: list[PyExpr], right_on: list[PyExpr]) -> PyTable: ...
+    def sort_merge_join(
+        self, right: PyTable, left_on: list[PyExpr], right_on: list[PyExpr], is_sorted: bool
+    ) -> PyTable: ...
     def explode(self, to_explode: list[PyExpr]) -> PyTable: ...
     def head(self, num: int) -> PyTable: ...
     def sample_by_fraction(self, fraction: float, with_replacement: bool, seed: int | None) -> PyTable: ...
@@ -1022,7 +1046,10 @@ class PyMicroPartition:
     def sort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PyMicroPartition: ...
     def argsort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PySeries: ...
     def agg(self, to_agg: list[PyExpr], group_by: list[PyExpr]) -> PyMicroPartition: ...
-    def join(self, right: PyMicroPartition, left_on: list[PyExpr], right_on: list[PyExpr]) -> PyMicroPartition: ...
+    def hash_join(self, right: PyMicroPartition, left_on: list[PyExpr], right_on: list[PyExpr]) -> PyMicroPartition: ...
+    def sort_merge_join(
+        self, right: PyMicroPartition, left_on: list[PyExpr], right_on: list[PyExpr], is_sorted: bool
+    ) -> PyMicroPartition: ...
     def explode(self, to_explode: list[PyExpr]) -> PyMicroPartition: ...
     def head(self, num: int) -> PyMicroPartition: ...
     def sample_by_fraction(self, fraction: float, with_replacement: bool, seed: int | None) -> PyMicroPartition: ...
@@ -1090,6 +1117,7 @@ class PhysicalPlanScheduler:
     """
 
     def num_partitions(self) -> int: ...
+    def partition_spec(self) -> PartitionSpec: ...
     def to_partition_tasks(
         self, psets: dict[str, list[PartitionT]], is_ray_runner: bool
     ) -> physical_plan.InProgressPhysicalPlan: ...
@@ -1129,7 +1157,12 @@ class LogicalPlanBuilder:
     def sample(self, fraction: float, with_replacement: bool, seed: int | None) -> LogicalPlanBuilder: ...
     def aggregate(self, agg_exprs: list[PyExpr], groupby_exprs: list[PyExpr]) -> LogicalPlanBuilder: ...
     def join(
-        self, right: LogicalPlanBuilder, left_on: list[PyExpr], right_on: list[PyExpr], join_type: JoinType
+        self,
+        right: LogicalPlanBuilder,
+        left_on: list[PyExpr],
+        right_on: list[PyExpr],
+        join_type: JoinType,
+        strategy: JoinStrategy | None = None,
     ) -> LogicalPlanBuilder: ...
     def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ...
     def table_write(
@@ -1151,6 +1184,7 @@ class PyDaftExecutionConfig:
         merge_scan_tasks_min_size_bytes: int | None = None,
         merge_scan_tasks_max_size_bytes: int | None = None,
         broadcast_join_size_bytes_threshold: int | None = None,
+        sort_merge_join_sort_with_aligned_boundaries: bool | None = None,
         sample_size_for_sort: int | None = None,
         num_preview_rows: int | None = None,
         parquet_target_filesize: int | None = None,
@@ -1166,6 +1200,8 @@
     @property
     def broadcast_join_size_bytes_threshold(self) -> int: ...
     @property
+    def sort_merge_join_sort_with_aligned_boundaries(self) -> bool: ...
+    @property
     def sample_size_for_sort(self) -> int: ...
     @property
     def num_preview_rows(self) -> int: ...
```
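
A small example exercising the new stubs (snake-case strings per the docstrings above, and assuming `JoinStrategy` supports equality comparison like `JoinType` does in the diff below):

```python
from daft.daft import JoinStrategy, JoinType

# Both enums parse their snake-case string forms.
assert JoinType.from_join_type_str("inner") == JoinType.Inner
assert JoinStrategy.from_join_strategy_str("sort_merge") == JoinStrategy.SortMerge
```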
### daft/dataframe/dataframe.py (15 additions, 2 deletions)

```diff
@@ -25,7 +25,14 @@
 from daft.api_annotations import DataframePublicAPI
 from daft.context import get_context
 from daft.convert import InputListType
-from daft.daft import FileFormat, IOConfig, JoinType, PartitionScheme, ResourceRequest
+from daft.daft import (
+    FileFormat,
+    IOConfig,
+    JoinStrategy,
+    JoinType,
+    PartitionScheme,
+    ResourceRequest,
+)
 from daft.dataframe.preview import DataFramePreview
 from daft.datatype import DataType
 from daft.errors import ExpressionTypeError
@@ -724,6 +731,7 @@ def join(
         left_on: Optional[Union[List[ColumnInputType], ColumnInputType]] = None,
         right_on: Optional[Union[List[ColumnInputType], ColumnInputType]] = None,
         how: str = "inner",
+        strategy: Optional[str] = None,
     ) -> "DataFrame":
         """Column-wise join of the current DataFrame with an ``other`` DataFrame, similar to a SQL ``JOIN``
@@ -737,6 +745,8 @@
             left_on (Optional[Union[List[ColumnInputType], ColumnInputType]], optional): key or keys to join on left DataFrame. Defaults to None.
             right_on (Optional[Union[List[ColumnInputType], ColumnInputType]], optional): key or keys to join on right DataFrame. Defaults to None.
             how (str, optional): what type of join to perform; currently only `inner` is supported. Defaults to "inner".
+            strategy (Optional[str]): The join strategy (algorithm) to use; currently "hash", "sort_merge", "broadcast", and None are supported, where None
+                chooses the join strategy automatically during query optimization. The default is None.
 
         Raises:
             ValueError: if `on` is passed in and `left_on` or `right_on` is not None.
@@ -756,10 +766,13 @@
         join_type = JoinType.from_join_type_str(how)
         if join_type != JoinType.Inner:
             raise ValueError(f"Only inner joins are currently supported, but got: {how}")
+        join_strategy = JoinStrategy.from_join_strategy_str(strategy) if strategy is not None else None
 
         left_exprs = self.__column_input_to_expression(tuple(left_on) if isinstance(left_on, list) else (left_on,))
         right_exprs = self.__column_input_to_expression(tuple(right_on) if isinstance(right_on, list) else (right_on,))
-        builder = self._builder.join(other._builder, left_on=left_exprs, right_on=right_exprs, how=join_type)
+        builder = self._builder.join(
+            other._builder, left_on=left_exprs, right_on=right_exprs, how=join_type, strategy=join_strategy
+        )
         return DataFrame(builder)
 
     @DataframePublicAPI
```
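
And an end-to-end exercise of the new `strategy` argument (a sketch using `daft.from_pydict` to build small test inputs):

```python
import daft

left = daft.from_pydict({"foo": [1, 2, 3], "x": ["a", "b", "c"]})
right = daft.from_pydict({"foo": [2, 3, 4], "y": [20.0, 30.0, 40.0]})

# None lets the planner decide; the three explicit strategies can be forced.
for strategy in (None, "hash", "sort_merge", "broadcast"):
    df = left.join(right, on="foo", strategy=strategy)
    assert sorted(df.to_pydict()["foo"]) == [2, 3]  # inner join keeps overlapping keys
```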