From 2eefca6fc78b8868b481536ace732c1150ec9755 Mon Sep 17 00:00:00 2001
From: Clark Zinzow
Date: Wed, 31 Jan 2024 14:20:18 -0800
Subject: [PATCH] [FEAT] [Join Optimizations] Add sort-merge join. (#1755)

This PR adds a sort-merge join implementation as a new join strategy, where each side of the join is sorted on the join keys, and then the sorted tables are merged. The sort-merge join strategy is chosen automatically by the query planner if it is expected to be faster than the hash join and the broadcast join.

Similar to Spark's ability to specify a join algorithm hint, this PR also exposes a new optional `strategy` arg for `df.join()`, which allows users (and our Python-level tests) to manually specify a join algorithm; currently `"hash"`, `"sort_merge"`, and `"broadcast"` are supported, with the default `None` resulting in the query planner choosing a join algorithm automatically.

```python
df = left.join(right, on="foo", strategy="sort_merge")
```

Closes #1776

## Query Planning

The query planner chooses the sort-merge join as its join strategy if the larger side of the join is range-partitioned, or if the smaller side of the join is range-partitioned and the larger side is not partitioned (a paraphrase of this heuristic is sketched at the end of this description).

In the future, we will also want to use a sort-merge join:

1. If a downstream operation requires the table to be sorted on the join key.
2. If neither side of the join is partitioned AND we determine that the sort-merge join is faster than the hash join on unpartitioned data (pending benchmarking).

**NOTE:** We currently only support sort-merge joins on primitive join keys; see the TODO below.

## Query Scheduling

### Full sort-merge join

All partitions for both sides of the join are materialized, after which we calculate sort boundaries from samples of both sides of the join. These combined sort boundaries are used to sort each side of the join. Once each side is sorted with the same sort boundaries and the same number of partitions, we perform a merge join, which merges overlapping pairs of partitions (the merge step is sketched at the end of this description).

### Sort-eliding merge join for presorted (but unaligned) dataframes

A merge join is performed directly on the two already-sorted sides of the join, merging overlapping pairs of partitions. The partition boundaries of the two sides of the join do not need to line up: partition overlap is determined in the scheduler, and merge-join tasks are emitted for overlapping partitions (the overlap check is sketched at the end of this description).

## TODOs

- [x] Test coverage
- [x] Fix dispatch logic for adjacent partitions with intersecting bounds (broadcast on matching bounds instead of naive zip).
- [x] Add support for sort-merge joins on sorted dataframes with unaligned partition boundaries.
- [ ] Add support for non-primitive join keys.
- [ ] Benchmarking to validate + tweak the heuristics used by the query planner to choose whether to use the sort-merge join.
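## Appendix: illustrative sketches

To make the Query Planning heuristic above concrete, here is a minimal Python paraphrase. This is an illustrative sketch, not the actual implementation (the real logic lives in `src/daft-plan/src/planner.rs`); the function name and the boolean parameters are hypothetical, and only the 10 MiB broadcast threshold default is taken from this PR's config.

```python
from enum import Enum


class JoinStrategy(Enum):
    HASH = "hash"
    SORT_MERGE = "sort_merge"
    BROADCAST = "broadcast"


def choose_join_strategy(
    smaller_size_bytes: int | None,
    larger_is_range_partitioned: bool,
    smaller_is_range_partitioned: bool,
    larger_is_partitioned: bool,
    broadcast_threshold: int = 10 * 1024 * 1024,  # 10 MiB, the config default in this PR.
) -> JoinStrategy:
    # Broadcast join wins when one side is small enough to ship to every partition.
    if smaller_size_bytes is not None and smaller_size_bytes <= broadcast_threshold:
        return JoinStrategy.BROADCAST
    # Sort-merge join wins when existing range partitioning lets us skip (some of) the sort.
    if larger_is_range_partitioned or (smaller_is_range_partitioned and not larger_is_partitioned):
        return JoinStrategy.SORT_MERGE
    # Otherwise, fall back to the hash join.
    return JoinStrategy.HASH
```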
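The merge step itself is implemented in Rust (`src/daft-table/src/ops/joins/merge_join.rs`); the following self-contained Python sketch is only meant to illustrate the core idea of an inner merge join on sorted keys: a two-pointer walk that emits the cross product of each run of equal keys.

```python
def merge_join_indices(left_keys: list, right_keys: list) -> list:
    """Inner merge join of two individually sorted key columns.

    Returns (left_idx, right_idx) pairs for all equal keys; duplicate keys on
    both sides produce their full cross product, as an inner join requires.
    """
    out = []
    li, ri = 0, 0
    while li < len(left_keys) and ri < len(right_keys):
        if left_keys[li] < right_keys[ri]:
            li += 1
        elif left_keys[li] > right_keys[ri]:
            ri += 1
        else:
            # Found a run of equal keys: find the extent of the run on each side,
            # emit the cross product of the two runs, then skip past both runs.
            key = left_keys[li]
            li_end, ri_end = li, ri
            while li_end < len(left_keys) and left_keys[li_end] == key:
                li_end += 1
            while ri_end < len(right_keys) and right_keys[ri_end] == key:
                ri_end += 1
            out.extend((i, j) for i in range(li, li_end) for j in range(ri, ri_end))
            li, ri = li_end, ri_end
    return out


# Duplicate key 2 on both sides yields a 2x2 cross product.
assert merge_join_indices([1, 2, 2, 5], [2, 2, 3, 5]) == [
    (1, 0), (1, 1), (2, 0), (2, 1), (3, 3),
]
```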
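For the sort-eliding path, the scheduler decides which partition pairs get merge-join tasks by intersecting partition boundaries (see `Boundaries.intersects` in `daft/runners/partitioning.py` in this diff). Below is a simplified single-key sketch of that check, where `None` stands for an unbounded end; the real implementation handles composite sort keys and fully-null bounds.

```python
from typing import Any, Optional, Tuple

# (lower, upper) bounds of a partition on a single sort key; None means unbounded.
Bounds = Tuple[Optional[Any], Optional[Any]]


def bounds_intersect(a: Bounds, b: Bounds) -> bool:
    """Whether two sorted partitions' key ranges can contain equal keys."""
    a_lo, a_hi = a
    b_lo, b_hi = b
    # Each comparison is skipped if either end involved is unbounded.
    lower_ok = a_lo is None or b_hi is None or a_lo <= b_hi
    upper_ok = b_lo is None or a_hi is None or b_lo <= a_hi
    return lower_ok and upper_ok


# Only overlapping pairs get a merge-join task: (0, 10) and (11, 20) do not overlap.
assert bounds_intersect((0, 10), (5, 15))
assert not bounds_intersect((0, 10), (11, 20))
assert bounds_intersect((None, 10), (5, None))
```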
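Finally, an end-to-end usage sketch tying the new API surface together: the `strategy` hint on `df.join()` and the new `sort_merge_join_sort_with_aligned_boundaries` execution config added in this PR. The tiny `daft.from_pydict` inputs are just for illustration.

```python
import daft
from daft.context import set_execution_config

# Opt into the aligned-boundaries sort variant (off by default): a potentially
# faster merge phase, at the cost of skew (and thus OOM risk) in the sorted inputs.
set_execution_config(sort_merge_join_sort_with_aligned_boundaries=True)

left = daft.from_pydict({"foo": [1, 2, 3], "x": ["a", "b", "c"]})
right = daft.from_pydict({"foo": [2, 3, 4], "y": [10.0, 20.0, 30.0]})

# Force the sort-merge strategy; with strategy=None (the default), the planner chooses.
joined = left.join(right, on="foo", strategy="sort_merge")
joined.collect()
```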
--- daft/context.py | 5 + daft/daft.pyi | 44 +- daft/dataframe/dataframe.py | 17 +- daft/execution/execution_step.py | 96 ++- daft/execution/physical_plan.py | 658 +++++++++++++++++- daft/execution/rust_physical_plan_shim.py | 42 ++ daft/logical/builder.py | 3 + .../plan_scheduler/physical_plan_scheduler.py | 4 + daft/runners/partitioning.py | 78 +++ daft/table/micropartition.py | 31 +- daft/table/table.py | 29 +- src/common/daft-config/src/lib.rs | 2 + src/common/daft-config/src/python.rs | 12 + src/daft-core/src/kernels/search_sorted.rs | 33 + src/daft-micropartition/src/ops/join.rs | 62 +- src/daft-micropartition/src/python.rs | 27 +- src/daft-plan/src/builder.rs | 7 +- src/daft-plan/src/join.rs | 64 +- src/daft-plan/src/lib.rs | 3 +- src/daft-plan/src/logical_ops/join.rs | 11 +- src/daft-plan/src/logical_plan.rs | 2 +- .../optimization/rules/push_down_filter.rs | 9 +- src/daft-plan/src/physical_ops/mod.rs | 2 + .../src/physical_ops/sort_merge_join.rs | 44 ++ src/daft-plan/src/physical_plan.rs | 68 +- src/daft-plan/src/planner.rs | 187 +++-- src/daft-table/src/lib.rs | 1 + src/daft-table/src/ops/joins/merge_join.rs | 243 +++++++ src/daft-table/src/ops/joins/mod.rs | 62 +- src/daft-table/src/python.rs | 27 +- tests/conftest.py | 17 + tests/cookbook/test_joins.py | 30 +- tests/dataframe/test_joins.py | 287 +++++++- tests/table/test_joins.py | 66 +- 34 files changed, 2119 insertions(+), 154 deletions(-) create mode 100644 src/daft-plan/src/physical_ops/sort_merge_join.rs create mode 100644 src/daft-table/src/ops/joins/merge_join.rs diff --git a/daft/context.py b/daft/context.py index 2d512d056c..ba4423074c 100644 --- a/daft/context.py +++ b/daft/context.py @@ -206,6 +206,7 @@ def set_execution_config( merge_scan_tasks_min_size_bytes: int | None = None, merge_scan_tasks_max_size_bytes: int | None = None, broadcast_join_size_bytes_threshold: int | None = None, + sort_merge_join_sort_with_aligned_boundaries: bool | None = None, sample_size_for_sort: int | None = None, num_preview_rows: int | None = None, parquet_target_filesize: int | None = None, @@ -228,6 +229,9 @@ def set_execution_config( fewer partitions. (Defaults to 512 MiB) broadcast_join_size_bytes_threshold: If one side of a join is smaller than this threshold, a broadcast join will be used. Default is 10 MiB. + sort_merge_join_sort_with_aligned_boundaries: Whether to use a specialized algorithm for sorting both sides of a + sort-merge join such that they have aligned boundaries. This can lead to a faster merge-join at the cost of + more skewed sorted join inputs, increasing the risk of OOMs. sample_size_for_sort: number of elements to sample from each partition when running sort, Default is 20. num_preview_rows: number of rows to when showing a dataframe preview, @@ -245,6 +249,7 @@ def set_execution_config( merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes, merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes, broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold, + sort_merge_join_sort_with_aligned_boundaries=sort_merge_join_sort_with_aligned_boundaries, sample_size_for_sort=sample_size_for_sort, num_preview_rows=num_preview_rows, parquet_target_filesize=parquet_target_filesize, diff --git a/daft/daft.pyi b/daft/daft.pyi index 89348cc66f..2355c81b11 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -89,11 +89,32 @@ class JoinType(Enum): Args: join_type: String representation of the join type. This is the same as the enum - attribute name, e.g. 
``JoinType.from_join_type_str("Inner")`` would + attribute name (but snake-case), e.g. ``JoinType.from_join_type_str("inner")`` would return ``JoinType.Inner``. """ ... +class JoinStrategy(Enum): + """ + Join strategy (algorithm) to use. + """ + + Hash: int + SortMerge: int + Broadcast: int + + @staticmethod + def from_join_strategy_str(join_strategy: str) -> JoinStrategy: + """ + Create a JoinStrategy from its string representation. + + Args: + join_strategy: String representation of the join strategy. This is the same as the enum + attribute name (but snake-case), e.g. ``JoinStrategy.from_join_strategy_str("sort_merge")`` would + return ``JoinStrategy.SortMerge``. + """ + ... + class CountMode(Enum): """ Supported count modes for Daft's count aggregation. @@ -967,7 +988,10 @@ class PyTable: def sort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PyTable: ... def argsort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PySeries: ... def agg(self, to_agg: list[PyExpr], group_by: list[PyExpr]) -> PyTable: ... - def join(self, right: PyTable, left_on: list[PyExpr], right_on: list[PyExpr]) -> PyTable: ... + def hash_join(self, right: PyTable, left_on: list[PyExpr], right_on: list[PyExpr]) -> PyTable: ... + def sort_merge_join( + self, right: PyTable, left_on: list[PyExpr], right_on: list[PyExpr], is_sorted: bool + ) -> PyTable: ... def explode(self, to_explode: list[PyExpr]) -> PyTable: ... def head(self, num: int) -> PyTable: ... def sample_by_fraction(self, fraction: float, with_replacement: bool, seed: int | None) -> PyTable: ... @@ -1022,7 +1046,10 @@ class PyMicroPartition: def sort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PyMicroPartition: ... def argsort(self, sort_keys: list[PyExpr], descending: list[bool]) -> PySeries: ... def agg(self, to_agg: list[PyExpr], group_by: list[PyExpr]) -> PyMicroPartition: ... - def join(self, right: PyMicroPartition, left_on: list[PyExpr], right_on: list[PyExpr]) -> PyMicroPartition: ... + def hash_join(self, right: PyMicroPartition, left_on: list[PyExpr], right_on: list[PyExpr]) -> PyMicroPartition: ... + def sort_merge_join( + self, right: PyMicroPartition, left_on: list[PyExpr], right_on: list[PyExpr], is_sorted: bool + ) -> PyMicroPartition: ... def explode(self, to_explode: list[PyExpr]) -> PyMicroPartition: ... def head(self, num: int) -> PyMicroPartition: ... def sample_by_fraction(self, fraction: float, with_replacement: bool, seed: int | None) -> PyMicroPartition: ... @@ -1090,6 +1117,7 @@ class PhysicalPlanScheduler: """ def num_partitions(self) -> int: ... + def partition_spec(self) -> PartitionSpec: ... def to_partition_tasks( self, psets: dict[str, list[PartitionT]], is_ray_runner: bool ) -> physical_plan.InProgressPhysicalPlan: ... @@ -1129,7 +1157,12 @@ class LogicalPlanBuilder: def sample(self, fraction: float, with_replacement: bool, seed: int | None) -> LogicalPlanBuilder: ... def aggregate(self, agg_exprs: list[PyExpr], groupby_exprs: list[PyExpr]) -> LogicalPlanBuilder: ... def join( - self, right: LogicalPlanBuilder, left_on: list[PyExpr], right_on: list[PyExpr], join_type: JoinType + self, + right: LogicalPlanBuilder, + left_on: list[PyExpr], + right_on: list[PyExpr], + join_type: JoinType, + strategy: JoinStrategy | None = None, + ) -> LogicalPlanBuilder: ... def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ...
def table_write( @@ -1151,6 +1184,7 @@ class PyDaftExecutionConfig: merge_scan_tasks_min_size_bytes: int | None = None, merge_scan_tasks_max_size_bytes: int | None = None, broadcast_join_size_bytes_threshold: int | None = None, + sort_merge_join_sort_with_aligned_boundaries: bool | None = None, sample_size_for_sort: int | None = None, num_preview_rows: int | None = None, parquet_target_filesize: int | None = None, @@ -1166,6 +1200,8 @@ class PyDaftExecutionConfig: @property def broadcast_join_size_bytes_threshold(self) -> int: ... @property + def sort_merge_join_sort_with_aligned_boundaries(self) -> bool: ... + @property def sample_size_for_sort(self) -> int: ... @property def num_preview_rows(self) -> int: ... diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 6060c586f9..2da78403ff 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -25,7 +25,14 @@ from daft.api_annotations import DataframePublicAPI from daft.context import get_context from daft.convert import InputListType -from daft.daft import FileFormat, IOConfig, JoinType, PartitionScheme, ResourceRequest +from daft.daft import ( + FileFormat, + IOConfig, + JoinStrategy, + JoinType, + PartitionScheme, + ResourceRequest, +) from daft.dataframe.preview import DataFramePreview from daft.datatype import DataType from daft.errors import ExpressionTypeError @@ -724,6 +731,7 @@ def join( left_on: Optional[Union[List[ColumnInputType], ColumnInputType]] = None, right_on: Optional[Union[List[ColumnInputType], ColumnInputType]] = None, how: str = "inner", + strategy: Optional[str] = None, ) -> "DataFrame": """Column-wise join of the current DataFrame with an ``other`` DataFrame, similar to a SQL ``JOIN`` @@ -737,6 +745,8 @@ def join( left_on (Optional[Union[List[ColumnInputType], ColumnInputType]], optional): key or keys to join on left DataFrame.. Defaults to None. right_on (Optional[Union[List[ColumnInputType], ColumnInputType]], optional): key or keys to join on right DataFrame. Defaults to None. how (str, optional): what type of join to performing, currently only `inner` is supported. Defaults to "inner". + strategy (Optional[str]): The join strategy (algorithm) to use; currently "hash", "sort_merge", "broadcast", and None are supported, where None + chooses the join strategy automatically during query optimization. The default is None. Raises: ValueError: if `on` is passed in and `left_on` or `right_on` is not None. 
@@ -756,10 +766,13 @@ def join( join_type = JoinType.from_join_type_str(how) if join_type != JoinType.Inner: raise ValueError(f"Only inner joins are currently supported, but got: {how}") + join_strategy = JoinStrategy.from_join_strategy_str(strategy) if strategy is not None else None left_exprs = self.__column_input_to_expression(tuple(left_on) if isinstance(left_on, list) else (left_on,)) right_exprs = self.__column_input_to_expression(tuple(right_on) if isinstance(right_on, list) else (right_on,)) - builder = self._builder.join(other._builder, left_on=left_exprs, right_on=right_exprs, how=join_type) + builder = self._builder.join( + other._builder, left_on=left_exprs, right_on=right_exprs, how=join_type, strategy=join_strategy + ) return DataFrame(builder) @DataframePublicAPI diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 09ee607775..072584ec6c 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -27,6 +27,7 @@ from daft.logical.map_partition_ops import MapPartitionOp from daft.logical.schema import Schema from daft.runners.partitioning import ( + Boundaries, MaterializedResult, PartialPartitionMetadata, PartitionMetadata, @@ -54,6 +55,7 @@ class PartitionTask(Generic[PartitionT]): resource_request: ResourceRequest num_results: int stage_id: int + partial_metadatas: list[PartialPartitionMetadata] _id: int = field(default_factory=lambda: next(ID_GEN)) def id(self) -> str: @@ -71,6 +73,10 @@ def set_result(self, result: list[MaterializedResult[PartitionT]]) -> None: """Set the result of this Task. For use by the Task executor.""" raise NotImplementedError + def is_empty(self) -> bool: + """Whether this partition task is guaranteed to result in an empty partition.""" + return len(self.partial_metadatas) > 0 and all(meta.num_rows == 0 for meta in self.partial_metadatas) + def __str__(self) -> str: return ( f"{self.id()}\n" @@ -116,6 +122,10 @@ def add_instruction( self.num_results = instruction.num_outputs() return self + def is_empty(self) -> bool: + """Whether this partition task is guaranteed to result in an empty partition.""" + return len(self.partial_metadatas) > 0 and all(meta.num_rows == 0 for meta in self.partial_metadatas) + def finalize_partition_task_single_output(self, stage_id: int) -> SingleOutputPartitionTask[PartitionT]: """Create a SingleOutputPartitionTask from this PartitionTaskBuilder. @@ -135,6 +145,7 @@ def finalize_partition_task_single_output(self, stage_id: int) -> SingleOutputPa instructions=self.instructions, num_results=1, resource_request=resource_request_final_cpu, + partial_metadatas=self.partial_metadatas, ) def finalize_partition_task_multi_output(self, stage_id: int) -> MultiOutputPartitionTask[PartitionT]: @@ -154,6 +165,7 @@ def finalize_partition_task_multi_output(self, stage_id: int) -> MultiOutputPart instructions=self.instructions, num_results=self.num_results, resource_request=resource_request_final_cpu, + partial_metadatas=self.partial_metadatas, ) def __str__(self) -> str: @@ -198,7 +210,8 @@ def partition_metadata(self) -> PartitionMetadata: (Avoids retrieving the actual partition itself if possible.) """ - return self.result().metadata() + [partial_metadata] = self.partial_metadatas + return self.result().metadata().merge_with_partial(partial_metadata) def vpartition(self) -> MicroPartition: """Get the raw vPartition of the result.""" @@ -243,7 +256,10 @@ def partition_metadatas(self) -> list[PartitionMetadata]: (Avoids retrieving the actual partition itself if possible.) 
""" assert self._results is not None - return [result.metadata() for result in self._results] + return [ + result.metadata().merge_with_partial(partial_metadata) + for result, partial_metadata in zip(self._results, self.partial_metadatas) + ] def vpartition(self, index: int) -> MicroPartition: """Get the raw vPartition of the result.""" @@ -461,11 +477,12 @@ def _filter(self, inputs: list[MicroPartition]) -> list[MicroPartition]: return [input.filter(self.predicate)] def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]: - # Can't derive anything. + [input_meta] = input_metadatas return [ PartialPartitionMetadata( num_rows=None, size_bytes=None, + boundaries=input_meta.boundaries, ) ] @@ -483,14 +500,38 @@ def _project(self, inputs: list[MicroPartition]) -> list[MicroPartition]: def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]: [input_meta] = input_metadatas + boundaries = input_meta.boundaries + if boundaries is not None: + boundaries = _prune_boundaries(boundaries, self.projection) return [ PartialPartitionMetadata( num_rows=input_meta.num_rows, size_bytes=None, + boundaries=boundaries, ) ] +def _prune_boundaries(boundaries: Boundaries, projection: ExpressionsProjection) -> Boundaries | None: + """ + If projection expression is a nontrivial computation (i.e. not a direct col() reference and not an alias) on top of a boundary + expression, then invalidate the boundary. + """ + proj_all_names = projection.to_name_set() + proj_names_needing_compute = proj_all_names - projection.input_mapping().keys() + for i, e in enumerate(boundaries.sort_by): + if e.name() in proj_names_needing_compute: + # Found a sort expression that is no longer valid, so we invalidate that sort expression and all that follow it. + sort_by = boundaries.sort_by[:i] + if not sort_by: + return None + boundaries_ = boundaries.bounds.eval_expression_list( + ExpressionsProjection([col(e.name()) for e in sort_by]) + ) + return Boundaries(sort_by, boundaries_) + return boundaries + + @dataclass(frozen=True) class LocalCount(SingleOutputInstruction): schema: Schema @@ -530,6 +571,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) PartialPartitionMetadata( num_rows=(min(self.limit, input_meta.num_rows) if input_meta.num_rows is not None else None), size_bytes=None, + boundaries=input_meta.boundaries, ) ] @@ -616,16 +658,16 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) @dataclass(frozen=True) -class Join(SingleOutputInstruction): +class HashJoin(SingleOutputInstruction): left_on: ExpressionsProjection right_on: ExpressionsProjection how: JoinType is_swapped: bool def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: - return self._join(inputs) + return self._hash_join(inputs) - def _join(self, inputs: list[MicroPartition]) -> list[MicroPartition]: + def _hash_join(self, inputs: list[MicroPartition]) -> list[MicroPartition]: # All inputs except for the last are the left side of the join, in order to support left-broadcasted joins. *lefts, right = inputs if len(lefts) > 1: @@ -637,7 +679,7 @@ def _join(self, inputs: list[MicroPartition]) -> list[MicroPartition]: # Swap left/right back. # We don't need to swap left_on and right_on since those were never swapped in the first place. 
left, right = right, left - result = left.join( + result = left.hash_join( right, left_on=self.left_on, right_on=self.right_on, @@ -655,6 +697,43 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) ] +@dataclass(frozen=True) +class MergeJoin(SingleOutputInstruction): + left_on: ExpressionsProjection + right_on: ExpressionsProjection + how: JoinType + preserve_left_bounds: bool + + def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: + return self._join(inputs) + + def _join(self, inputs: list[MicroPartition]) -> list[MicroPartition]: + left, right = inputs + result = left.sort_merge_join( + right, + left_on=self.left_on, + right_on=self.right_on, + how=self.how, + is_sorted=True, + ) + return [result] + + def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]: + [left_meta, right_meta] = input_metadatas + # If the boundaries of the left and right partitions don't intersect, then the merge-join will result in an empty partition. + if left_meta.boundaries is None or right_meta.boundaries is None: + is_nonempty = True + else: + is_nonempty = left_meta.boundaries.intersects(right_meta.boundaries) + return [ + PartialPartitionMetadata( + num_rows=None if is_nonempty else 0, + size_bytes=None, + boundaries=left_meta.boundaries if self.preserve_left_bounds else right_meta.boundaries, + ) + ] + + class ReduceInstruction(SingleOutputInstruction): ... @@ -682,6 +761,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) class ReduceMergeAndSort(ReduceInstruction): sort_by: ExpressionsProjection descending: list[bool] + bounds: MicroPartition def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: return self._reduce_merge_and_sort(inputs) @@ -697,6 +777,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) PartialPartitionMetadata( num_rows=sum(input_rows) if all(_ is not None for _ in input_rows) else None, size_bytes=sum(input_sizes) if all(_ is not None for _ in input_sizes) else None, + boundaries=Boundaries(list(self.sort_by), self.bounds), ) ] @@ -834,6 +915,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) PartialPartitionMetadata( num_rows=num_rows, size_bytes=None, + boundaries=input_meta.boundaries, ) ) diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index fb2788c8ad..61982a880a 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -13,11 +13,13 @@ from __future__ import annotations +import collections +import itertools import logging import math import pathlib from collections import deque -from typing import Generator, Iterator, TypeVar, Union +from typing import Generator, Generic, Iterable, Iterator, TypeVar, Union from daft.context import get_context from daft.daft import ( @@ -44,6 +46,7 @@ PartialPartitionMetadata, PartitionT, ) +from daft.table.micropartition import MicroPartition logger = logging.getLogger(__name__) @@ -194,7 +197,7 @@ def hash_join( right_on: ExpressionsProjection, how: JoinType, ) -> InProgressPhysicalPlan[PartitionT]: - """Pairwise join the partitions from `left_child_plan` and `right_child_plan` together.""" + """Hash-based pairwise join the partitions from `left_child_plan` and `right_child_plan` together.""" # Materialize the steps from the left and right sources to get partitions. # As the materializations complete, emit new steps to join each left and right partition. 
@@ -230,7 +233,7 @@ def hash_join( partial_metadatas=[next_left.partition_metadata(), next_right.partition_metadata()], resource_request=ResourceRequest(memory_bytes=size_bytes), ).add_instruction( - instruction=execution_step.Join( + instruction=execution_step.HashJoin( left_on=left_on, right_on=right_on, how=how, @@ -273,7 +276,7 @@ def hash_join( return -def _create_join_step( +def _create_broadcast_join_step( broadcaster_parts: deque[SingleOutputPartitionTask[PartitionT]], receiver_part: SingleOutputPartitionTask[PartitionT], left_on: ExpressionsProjection, @@ -320,7 +323,7 @@ def _create_join_step( partial_metadatas=list(broadcaster_partition_metadatas + [receiver_part.partition_metadata()]), resource_request=ResourceRequest(memory_bytes=size_bytes), ).add_instruction( - instruction=execution_step.Join( + instruction=execution_step.HashJoin( left_on=left_on, right_on=right_on, how=how, @@ -376,7 +379,7 @@ def broadcast_join( # Broadcast all broadcaster partitions to each new receiver partition that was materialized on this dispatch loop. while receiver_requests and receiver_requests[0].done(): receiver_part = receiver_requests.popleft() - yield _create_join_step(broadcaster_parts, receiver_part, left_on, right_on, how, is_swapped) + yield _create_broadcast_join_step(broadcaster_parts, receiver_part, left_on, right_on, how, is_swapped) # Execute single child step to pull in more input partitions. try: @@ -396,6 +399,628 @@ def broadcast_join( return +class MergeJoinTaskTracker(Generic[PartitionT]): + """ + Tracks merge-join tasks for each larger-side partition. + + Merge-join tasks are added to the tracker, and the tracker handles empty tasks, finalizing PartitionTaskBuilders, + determining whether tasks are ready to be executed, checking whether tasks are done, and deciding whether a coalesce + is needed. + """ + + def __init__(self, stage_id: int): + # Merge-join tasks that have not yet been finalized or yielded to the runner. We don't finalize a merge-join + # task until we have at least 2 non-empty merge-join tasks, at which point this task will be popped from + # _task_staging, finalized, and put into _finalized_tasks. + self._task_staging: dict[str, PartitionTaskBuilder[PartitionT]] = {} + # Merge-join tasks that have been finalized, but not yet yielded to the runner. + self._finalized_tasks: collections.defaultdict[ + str, deque[SingleOutputPartitionTask[PartitionT]] + ] = collections.defaultdict(deque) + # Merge-join tasks that have been yielded to the runner, and still need to be coalesced. + self._uncoalesced_tasks: collections.defaultdict[ + str, deque[SingleOutputPartitionTask[PartitionT]] + ] = collections.defaultdict(deque) + # Larger-side partitions that have been finalized, i.e. we're guaranteed that no more smaller-side partitions + # will be added to the tracker for this partition. + self._finalized: dict[str, bool] = {} + self._stage_id = stage_id + + def add_task(self, part_id: str, task: PartitionTaskBuilder[PartitionT]) -> None: + """ + Add a merge-join task to the tracker for the provided larger-side partition. + + This task needs to be unfinalized, i.e. a PartitionTaskBuilder. + """ + # If no merge-join tasks have been added to the tracker yet for this partition, or we have an empty task in + # staging, add the unfinalized merge-join task to staging. 
+ if not self._is_contained(part_id) or ( + part_id in self._task_staging and self._task_staging[part_id].is_empty() + ): + self._task_staging[part_id] = task + # Otherwise, we have at least 2 (probably) non-empty merge-join tasks, so we finalize the new task and add it + # to _finalized_tasks. If the new task is empty, then we drop it (we already have at least one task for this + # partition, so no use in keeping an additional empty task around). + elif not task.is_empty(): + # If we have a task in staging, we know from the first if statement that it's non-empty, so we finalize it + # and add it to _finalized_tasks. + if part_id in self._task_staging: + self._finalized_tasks[part_id].append( + self._task_staging.pop(part_id).finalize_partition_task_single_output(self._stage_id) + ) + self._finalized_tasks[part_id].append(task.finalize_partition_task_single_output(self._stage_id)) + + def finalize(self, part_id: str) -> None: + """ + Indicates to the tracker that we are done adding merge-join tasks for this partition. + """ + # All finalized tasks should have been yielded before the tracker.finalize() call. + finalized_tasks = self._finalized_tasks.pop(part_id, deque()) + assert len(finalized_tasks) == 0 + + self._finalized[part_id] = True + + def yield_ready( + self, part_id: str + ) -> Iterator[SingleOutputPartitionTask[PartitionT] | PartitionTaskBuilder[PartitionT]]: + """ + Returns an iterator of all tasks for this partition that are ready for execution. Each merge-join task will be + yielded once, even across multiple calls. + """ + assert self._is_contained(part_id) + if part_id in self._finalized_tasks: + # Yield the finalized tasks and add them to the uncoalesced queue. + while self._finalized_tasks[part_id]: + task = self._finalized_tasks[part_id].popleft() + yield task + self._uncoalesced_tasks[part_id].append(task) + elif self._finalized.get(part_id, False) and part_id in self._task_staging: + # If the tracker has been finalized for this partition, we can yield unfinalized tasks directly from + # staging since no future tasks will be added. + yield self._task_staging.pop(part_id) + + def pop_uncoalesced(self, part_id: str) -> deque[SingleOutputPartitionTask[PartitionT]] | None: + """ + Returns all tasks for this partition that need to be coalesced. If this partition only involved a single + merge-join task (i.e. we don't need to coalesce), then this function will return None. + + NOTE: tracker.finalize(part_id) must be called before this function. + """ + assert self._finalized[part_id] + return self._uncoalesced_tasks.pop(part_id, None) + + def all_tasks_done_for_partition(self, part_id: str) -> bool: + """ + Return whether all merge-join tasks for this partition are done. + """ + assert self._is_contained(part_id) + if part_id in self._task_staging: + # Unfinalized tasks are trivially "done". + return True + return all( + task.done() + for task in itertools.chain( + self._finalized_tasks.get(part_id, deque()), self._uncoalesced_tasks.get(part_id, deque()) + ) + ) + + def all_tasks_done(self) -> bool: + """ + Return whether all merge-join tasks for all partitions are done. + """ + return all( + self.all_tasks_done_for_partition(part_id) + for part_id in itertools.chain( + self._uncoalesced_tasks.keys(), self._finalized_tasks.keys(), self._task_staging.keys() + ) + ) + + def _is_contained(self, part_id: str) -> bool: + """ + Return whether the provided partition is being tracked by this tracker.
+ """ + return part_id in self._task_staging or part_id in self._finalized_tasks or part_id in self._uncoalesced_tasks + + +def _emit_merge_joins_on_window( + next_part: SingleOutputPartitionTask[PartitionT], + other_window: deque[SingleOutputPartitionTask[PartitionT]], + merge_join_task_tracker: MergeJoinTaskTracker[PartitionT], + flipped: bool, + next_is_larger: bool, + left_on: ExpressionsProjection, + right_on: ExpressionsProjection, + how: JoinType, +) -> Iterator[PartitionTaskBuilder[PartitionT] | PartitionTask[PartitionT]]: + """ + Emits merge-join steps of next_part with each partition in other_window. + """ + # Emit a merge-join step for all partitions in the other window that intersect with this new partition. + for other_next_part in other_window: + memory_bytes = _memory_bytes_for_merge(next_part, other_next_part) + inputs = [next_part.partition(), other_next_part.partition()] + partial_metadatas = [ + next_part.partition_metadata().downcast_to_partial(), + other_next_part.partition_metadata().downcast_to_partial(), + ] + # If next, other are flipped (right, left partitions), flip them back. + if flipped: + inputs = list(reversed(inputs)) + partial_metadatas = list(reversed(partial_metadatas)) + join_task = PartitionTaskBuilder[PartitionT]( + inputs=inputs, + partial_metadatas=partial_metadatas, + resource_request=ResourceRequest(memory_bytes=memory_bytes), + ).add_instruction( + instruction=execution_step.MergeJoin( + left_on=left_on, + right_on=right_on, + how=how, + preserve_left_bounds=not flipped, + ) + ) + part_id = next_part.id() if next_is_larger else other_next_part.id() + # Add to new merge-join step to tracked steps for this larger-side partition, and possibly start finalizing + + # emitting non-empty join steps if there are now more than one. + merge_join_task_tracker.add_task(part_id, join_task) + yield from merge_join_task_tracker.yield_ready(part_id) + + +def _memory_bytes_for_merge( + next_left: SingleOutputPartitionTask[PartitionT], next_right: SingleOutputPartitionTask[PartitionT] +) -> int | None: + # Calculate memory request for merge task. + left_size_bytes = next_left.partition_metadata().size_bytes + right_size_bytes = next_right.partition_metadata().size_bytes + if left_size_bytes is None and right_size_bytes is None: + size_bytes = None + elif left_size_bytes is None and right_size_bytes is not None: + # Use 2x the right side as the memory request, assuming that left and right side are ~ the same size. + size_bytes = 2 * right_size_bytes + elif right_size_bytes is None and left_size_bytes is not None: + # Use 2x the left side as the memory request, assuming that left and right side are ~ the same size. + size_bytes = 2 * left_size_bytes + elif left_size_bytes is not None and right_size_bytes is not None: + size_bytes = left_size_bytes + right_size_bytes + return size_bytes + + +def merge_join_sorted( + left_plan: InProgressPhysicalPlan[PartitionT], + right_plan: InProgressPhysicalPlan[PartitionT], + left_on: ExpressionsProjection, + right_on: ExpressionsProjection, + how: JoinType, + left_is_larger: bool, +) -> InProgressPhysicalPlan[PartitionT]: + """ + Merge the sorted partitions from `left_plan` and `right_plan` together. + + This assumes that `left_plan` and `right_plan` are both sorted on the join key(s), although with potentially + different range partitionings (partition boundaries). + """ + + # Large vs. smaller side of join. 
+ larger_plan = left_plan if left_is_larger else right_plan + smaller_plan = right_plan if left_is_larger else left_plan + + stage_id = next(stage_id_counter) + + # In-progress tasks for larger side of join. + larger_requests: deque[SingleOutputPartitionTask[PartitionT]] = deque() + # In-progress tasks for smaller side of join. + smaller_requests: deque[SingleOutputPartitionTask[PartitionT]] = deque() + # Materialized partitions for larger side of join; a larger-side partition isn't dropped until we've emitted all + # join steps with smaller-side partitions that may overlap with it. + larger_window: deque[SingleOutputPartitionTask[PartitionT]] = deque() + # Materialized partitions for smaller side of join; a smaller-side partition isn't dropped until the most recent + # larger-side materialized partition has a higher upper bound, which suggests that this smaller-side partition won't + # be able to intersect with any future larger-side partitions. + smaller_window: deque[SingleOutputPartitionTask[PartitionT]] = deque() + # Tracks merge-join partition tasks emitted for each partition on the larger side of the join. + # Once all merge-join tasks are done, the corresponding output partitions will be coalesced together. + # If only a single merge-join task is emitted for a larger-side partition, it will be an unfinalized + # PartitionTaskBuilder, the coalescing step will be skipped, and this merge-join task will be yielded without + # finalizing in order to allow fusion with downstream tasks; otherwise, the tracker will contain finalized + # PartitionTasks. + merge_join_task_tracker: MergeJoinTaskTracker[PartitionT] = MergeJoinTaskTracker(stage_id) + + yield_smaller = True + smaller_done = False + larger_done = False + + # As partitions materialize from either side of the join, emit new merge-join steps to join overlapping partitions + # together. + while True: + # Emit merge-join steps on newly completed partitions from the smaller side of the join with a window of + # (possibly) intersecting partitions from the larger side. + while smaller_requests and smaller_requests[0].done(): + next_part = smaller_requests.popleft() + yield from _emit_merge_joins_on_window( + next_part, + larger_window, + merge_join_task_tracker, + left_is_larger, + False, + left_on, + right_on, + how, + ) + smaller_window.append(next_part) + # Emit merge-join steps on newly completed partitions from the larger side of the join with a window of + # (possibly) intersecting partitions from the smaller side. + while larger_requests and larger_requests[0].done(): + next_part = larger_requests.popleft() + yield from _emit_merge_joins_on_window( + next_part, + smaller_window, + merge_join_task_tracker, + not left_is_larger, + True, + left_on, + right_on, + how, + ) + larger_window.append(next_part) + # Remove prefix of smaller window that's under the high water mark set by this new larger-side partition, + # since this prefix won't be able to match any future partitions on the larger side of the join. + while ( + # We always leave at least one partition in the smaller-side window in case we need to yield an empty + # merge-join step for a future larger-side partition. + len(smaller_window) > (1 if larger_requests else 0) + and larger_window + and _is_strictly_bounded_above_by(smaller_window[0], larger_window[-1]) + ): + smaller_window.popleft() + # For each partition we remove from the larger window, we launch a coalesce task over all output partitions + # that correspond to that larger partition.
+ # This loop also removes the prefix of larger window that's under the high water mark set by the smaller window, + # since this prefix won't be able to match any future partitions on the smaller side. + while ( + # Must be a larger-side partition whose outputs need finalizing. + larger_window + and ( + # Larger-side partition is bounded above by the most recent smaller-side partition, which means that no + # future smaller-side partition can intersect with this larger-side partition, allowing us to finalize + # the merge-join steps for the larger-side partition. + (smaller_window and _is_strictly_bounded_above_by(larger_window[0], smaller_window[-1])) + # No more smaller partitions left, so we should launch coalesce tasks for all remaining + # larger-side partitions. + or smaller_done + ) + and ( + # Only finalize merge-join tasks for a larger-side partition if all outputs are done OR there's only a + # single finalized output (in which case we yield an unfinalized merge-join task to allow downstream + # fusion with it). + merge_join_task_tracker.all_tasks_done_for_partition(larger_window[0].id()) + ) + ): + done_larger_part = larger_window.popleft() + part_id = done_larger_part.id() + # Indicate to merge-join task tracker that no more merge-join tasks will be added for this partition. + merge_join_task_tracker.finalize(part_id) + # Yield any merge-join tasks that are now ready after finalizing the tracking for this partition (i.e. if + # there was only a single merge-join task added to the tracker for this partition, it will now be yielded + # here). + yield from merge_join_task_tracker.yield_ready(part_id) + # Get merge-join tasks that need to be coalesced. + tasks = merge_join_task_tracker.pop_uncoalesced(part_id) + if tasks is None: + # Only one output partition, so no coalesce needed. + continue + # At least two (probably non-empty) merge-join tasks for this group, so need to coalesce. + # NOTE: We guarantee in _emit_merge_joins_on_window that any group containing 2 or more partition tasks + # will only contain non-guaranteed-empty partitions; i.e., we'll need to execute a task to determine if + # they actually are empty, so we just issue the coalesce task. + # TODO(Clark): Elide coalescing by emitting a single merge-join task per larger-side partition, including as + # input all intersecting partitions from the smaller side of the join. + size_bytes = _memory_bytes_for_coalesce(tasks) + coalesce_task = PartitionTaskBuilder[PartitionT]( + inputs=[task.partition() for task in tasks], + partial_metadatas=[task.partition_metadata() for task in tasks], + resource_request=ResourceRequest(memory_bytes=size_bytes), + ).add_instruction( + instruction=execution_step.ReduceMerge(), + ) + yield coalesce_task + + # Exhausted all ready inputs; execute a single child step to get more join inputs. + # Choose whether to execute from smaller child or larger child (whichever one is furthest behind). + num_smaller_in_flight = len(smaller_requests) + len(smaller_window) + num_larger_in_flight = len(larger_requests) + len(larger_window) + if smaller_done or larger_done or num_smaller_in_flight == num_larger_in_flight: + # Both plans have progressed equally (or the last yielded side is done); alternate between the two plans + # to avoid starving either one.
+ yield_smaller = not yield_smaller + next_plan, next_requests = ( + (smaller_plan, smaller_requests) if yield_smaller else (larger_plan, larger_requests) + ) + elif num_smaller_in_flight < num_larger_in_flight: + # Larger side of join is further along than the smaller side, so pull from the smaller side next. + next_plan, next_requests = smaller_plan, smaller_requests + yield_smaller = True + else: + # Smaller side of join is further along than the larger side, so pull from the larger side next. + next_plan, next_requests = larger_plan, larger_requests + yield_smaller = False + + # Pull from the chosen side of the join. + try: + step = next(next_plan) + if isinstance(step, PartitionTaskBuilder): + step = step.finalize_partition_task_single_output(stage_id=stage_id) + next_requests.append(step) + yield step + + except StopIteration: + # We've exhausted one of the sides of the join. + # If we have active tasks for either side of the join that completed while dispatching intermediate work, + # we continue with another loop so we can process those newly ready inputs. + if (smaller_requests and smaller_requests[0].done()) or (larger_requests and larger_requests[0].done()): + continue + # If we have active tasks for either side of the join that aren't done, tell the runner that we're blocked on inputs. + elif smaller_requests or larger_requests: + logger.debug( + "merge join blocked on completion of sources.\n Left sources: %s\nRight sources: %s", + larger_requests if left_is_larger else smaller_requests, + smaller_requests if left_is_larger else larger_requests, + ) + yield None + # If we just exhausted the smaller side of the join, set the smaller done flag. + elif yield_smaller and not smaller_done: + smaller_done = True + # If we just exhausted the larger side of the join, set the larger done flag. + elif not yield_smaller and not larger_done: + larger_done = True + # We might still be waiting for some merge-join tasks to complete whose output we still need + # to coalesce. + elif not merge_join_task_tracker.all_tasks_done(): + logger.debug( + "merge join blocked on completion of merge join tasks (pre-coalesce).\nMerge-join tasks: %s", + list(merge_join_task_tracker._finalized_tasks.values()), + ) + yield None + # Otherwise, all join inputs are done and all merge-join tasks are done, so we are entirely done emitting + # merge join work. + else: + return + + +def _is_strictly_bounded_above_by( + lower_part: SingleOutputPartitionTask[PartitionT], upper_part: SingleOutputPartitionTask[PartitionT] +) -> bool: + """ + Returns whether lower_part is strictly bounded above by upper_part; i.e., whether lower_part's upper bound is + strictly less than upper_part's upper bound. + """ + lower_boundaries = lower_part.partition_metadata().boundaries + upper_boundaries = upper_part.partition_metadata().boundaries + assert lower_boundaries is not None and upper_boundaries is not None + return lower_boundaries.is_strictly_bounded_above_by(upper_boundaries) + + +def _memory_bytes_for_coalesce(input_parts: Iterable[SingleOutputPartitionTask[PartitionT]]) -> int | None: + # Calculate memory request for task. + size_bytes_per_task = [task.partition_metadata().size_bytes for task in input_parts] + non_null_size_bytes_per_task = [size for size in size_bytes_per_task if size is not None] + non_null_size_bytes = sum(non_null_size_bytes_per_task) + if len(size_bytes_per_task) == len(non_null_size_bytes_per_task): + # If all task size bytes are non-null, directly use the non-null size bytes sum.
+ size_bytes = non_null_size_bytes + elif non_null_size_bytes_per_task: + # If some are null, calculate the non-null mean and assume that null task size bytes + # have that size. + mean_size = math.ceil(non_null_size_bytes / len(non_null_size_bytes_per_task)) + size_bytes = non_null_size_bytes + mean_size * (len(size_bytes_per_task) - len(non_null_size_bytes_per_task)) + else: + # If all null, set to null. + size_bytes = None + return size_bytes + + +def sort_merge_join_aligned_boundaries( + left_plan: InProgressPhysicalPlan[PartitionT], + right_plan: InProgressPhysicalPlan[PartitionT], + left_on: ExpressionsProjection, + right_on: ExpressionsProjection, + how: JoinType, + num_partitions: int, + left_is_larger: bool, +) -> InProgressPhysicalPlan[PartitionT]: + """ + Sort-merge join the partitions from `left_plan` and `right_plan` together. + + This assumes that both `left_plan` and `right_plan` need to be sorted, and will be sorted using the same + partitioning boundaries. + """ + # This algorithm proceeds in the following phases: + # 1. Sort both sides of the join. + # a. Fully materialize left and right child plans. + # b. Sample all partitions from both sides of the join. + # c. Create partitioning boundaries from global samples. + # d. Sort each side of the join using global partitioning boundaries. + # 2. Merge-join the now-sorted sides of the join. + descending = [False] * len(left_on) + # First, materialize the left and right child plans. + left_source_materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque() + right_source_materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque() + stage_id_children = next(stage_id_counter) + for child, source_materializations in [ + (left_plan, left_source_materializations), + (right_plan, right_source_materializations), + ]: + for step in child: + if isinstance(step, PartitionTaskBuilder): + step = step.finalize_partition_task_single_output(stage_id=stage_id_children) + source_materializations.append(step) + yield step + + # Sample all partitions (to be used for calculating sort partitioning boundaries). + left_sample_materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque() + right_sample_materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque() + stage_id_sampling = next(stage_id_counter) + sample_size = get_context().daft_execution_config.sample_size_for_sort + + for source_materializations, on, sample_materializations in [ + (left_source_materializations, left_on, left_sample_materializations), + (right_source_materializations, right_on, right_sample_materializations), + ]: + for source in source_materializations: + while not source.done(): + logger.debug("sort blocked on completion of source: %s", source) + yield None + + sample = ( + PartitionTaskBuilder[PartitionT]( + inputs=[source.partition()], + partial_metadatas=None, + ) + .add_instruction( + instruction=execution_step.Sample(sort_by=on, size=sample_size), + ) + # Rename sample columns so they align with left_on naming, so we can reduce to combined quantiles below. + # NOTE: This instruction will be a no-op for the left side of the sort.
+ .add_instruction( + instruction=execution_step.Project( + projection=ExpressionsProjection( + [ + e.alias(left_name) + for e, left_name in zip(on.to_column_expressions(), [e.name() for e in left_on]) + ] + ), + ) + ) + .finalize_partition_task_single_output(stage_id=stage_id_sampling) + ) + + sample_materializations.append(sample) + yield sample + + # Wait for samples from both child plans to materialize. + for sample_materializations in (left_sample_materializations, right_sample_materializations): + while any(not _.done() for _ in sample_materializations): + logger.debug("sort blocked on completion of all samples: %s", sample_materializations) + yield None + + stage_id_reduce = next(stage_id_counter) + + # Reduce the samples from both child plans to get combined sort partitioning boundaries. + left_boundaries = ( + PartitionTaskBuilder[PartitionT]( + inputs=[ + sample.partition() + for sample in itertools.chain( + consume_deque(left_sample_materializations), consume_deque(right_sample_materializations) + ) + ], + partial_metadatas=None, + ) + .add_instruction( + execution_step.ReduceToQuantiles( + num_quantiles=num_partitions, + sort_by=left_on, + descending=descending, + ), + ) + .finalize_partition_task_single_output(stage_id=stage_id_reduce) + ) + yield left_boundaries + + # Wait for boundaries to materialize. + while not left_boundaries.done(): + logger.debug("sort blocked on completion of boundary partition: %s", left_boundaries) + yield None + + # Project boundaries back to the right-side column names. + # TODO(Clark): Refactor execution model to be able to fuse this with downstream sorting. + right_boundaries = ( + PartitionTaskBuilder[PartitionT]( + inputs=[left_boundaries.partition()], + partial_metadatas=None, + ) + # Rename quantile columns so their original naming is restored, so we can sort each child with their native expression. + .add_instruction( + instruction=execution_step.Project( + projection=ExpressionsProjection( + [ + e.alias(right_name) + for e, right_name in zip(left_on.to_column_expressions(), [e.name() for e in right_on]) + ] + ), + ) + ).finalize_partition_task_single_output(stage_id=stage_id_reduce) + ) + yield right_boundaries + + # Wait for right-side boundaries to materialize. + while not right_boundaries.done(): + logger.debug("sort blocked on completion of boundary partition: %s", right_boundaries) + yield None + + # Sort both children using the combined boundaries. + sorted_plans: list[InProgressPhysicalPlan[PartitionT]] = [] + for on, source_materializations, boundaries in [ + (left_on, left_source_materializations, left_boundaries), + (right_on, right_source_materializations, right_boundaries), + ]: + # NOTE: We need to give reduce() an iter(list), since giving it a generator would result in lazy + # binding in this loop. + range_fanout_plan = [ + PartitionTaskBuilder[PartitionT]( + inputs=[boundaries.partition(), source.partition()], + partial_metadatas=[boundaries.partition_metadata(), source.partition_metadata()], + resource_request=ResourceRequest( + memory_bytes=source.partition_metadata().size_bytes, + ), + ).add_instruction( + instruction=execution_step.FanoutRange[PartitionT]( + _num_outputs=num_partitions, + sort_by=on, + descending=descending, + ), + ) + for source in consume_deque(source_materializations) + ] + + # Execute a sorting reduce on it. 
+ per_partition_bounds = _to_per_partition_bounds(boundaries.vpartition(), num_partitions) + sorted_plans.append( + reduce( + fanout_plan=iter(range_fanout_plan), + reduce_instructions=[ + execution_step.ReduceMergeAndSort( + sort_by=on, + descending=descending, + bounds=per_part_boundaries, + ) + for per_part_boundaries in per_partition_bounds + ], + ) + ) + + left_sorted_plan, right_sorted_plan = sorted_plans + + # Merge-join the two sorted sides of the join. + yield from merge_join_sorted(left_sorted_plan, right_sorted_plan, left_on, right_on, how, left_is_larger) + + +def _to_per_partition_bounds(boundaries: MicroPartition, num_partitions: int) -> list[MicroPartition]: + boundaries_dict = boundaries.to_pydict() + return [ + MicroPartition.from_pydict( + { + col_name: [ + pivots[i - 1] if i > 0 and i - 1 < len(pivots) else None, + pivots[i] if i < len(pivots) else None, + ] + for col_name, pivots in boundaries_dict.items() + } + ) + for i in range(num_partitions) + ] + + def concat( top_plan: InProgressPhysicalPlan[PartitionT], bottom_plan: InProgressPhysicalPlan[PartitionT] ) -> InProgressPhysicalPlan[PartitionT]: @@ -699,7 +1324,7 @@ def coalesce( def reduce( fanout_plan: InProgressPhysicalPlan[PartitionT], - reduce_instruction: ReduceInstruction, + reduce_instructions: ReduceInstruction | list[ReduceInstruction], ) -> InProgressPhysicalPlan[PartitionT]: """Reduce the result of fanout_plan. @@ -728,6 +1353,10 @@ def reduce( inputs_to_reduce = [deque(_.partitions()) for _ in materializations] metadatas = [deque(_.partition_metadatas()) for _ in materializations] del materializations + if not isinstance(reduce_instructions, list): + reduce_instructions = [reduce_instructions] * len(inputs_to_reduce[0]) + reduce_instructions_ = deque(reduce_instructions) + del reduce_instructions # Yield all the reduces in order. while len(inputs_to_reduce[0]) > 0: @@ -739,7 +1368,7 @@ def reduce( resource_request=ResourceRequest( memory_bytes=sum(metadata.size_bytes for metadata in metadata_batch), ), - ).add_instruction(reduce_instruction) + ).add_instruction(reduce_instructions_.popleft()) def sort( @@ -829,14 +1458,19 @@ def sort( ) for source in consume_deque(source_materializations) ) + per_partition_bounds = _to_per_partition_bounds(boundaries.vpartition(), num_partitions) # Execute a sorting reduce on it. 
yield from reduce( fanout_plan=range_fanout_plan, - reduce_instruction=execution_step.ReduceMergeAndSort( - sort_by=sort_by, - descending=descending, - ), + reduce_instructions=[ + execution_step.ReduceMergeAndSort( + sort_by=sort_by, + descending=descending, + bounds=per_part_boundaries, + ) + for per_part_boundaries in per_partition_bounds + ], ) diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 8c90a1122a..25db7bec7a 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -246,6 +246,48 @@ def hash_join( ) +def merge_join_sorted( + input: physical_plan.InProgressPhysicalPlan[PartitionT], + right: physical_plan.InProgressPhysicalPlan[PartitionT], + left_on: list[PyExpr], + right_on: list[PyExpr], + join_type: JoinType, + left_is_larger: bool, +) -> physical_plan.InProgressPhysicalPlan[PartitionT]: + left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on]) + right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on]) + return physical_plan.merge_join_sorted( + left_plan=input, + right_plan=right, + left_on=left_on_expr_proj, + right_on=right_on_expr_proj, + how=join_type, + left_is_larger=left_is_larger, + ) + + +def sort_merge_join_aligned_boundaries( + input: physical_plan.InProgressPhysicalPlan[PartitionT], + right: physical_plan.InProgressPhysicalPlan[PartitionT], + left_on: list[PyExpr], + right_on: list[PyExpr], + join_type: JoinType, + num_partitions: int, + left_is_larger: bool, +) -> physical_plan.InProgressPhysicalPlan[PartitionT]: + left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on]) + right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on]) + return physical_plan.sort_merge_join_aligned_boundaries( + left_plan=input, + right_plan=right, + left_on=left_on_expr_proj, + right_on=right_on_expr_proj, + how=join_type, + num_partitions=num_partitions, + left_is_larger=left_is_larger, + ) + + def broadcast_join( broadcaster: physical_plan.InProgressPhysicalPlan[PartitionT], receiver: physical_plan.InProgressPhysicalPlan[PartitionT], diff --git a/daft/logical/builder.py b/daft/logical/builder.py index 7fa2def97a..f6f9b8359a 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -9,6 +9,7 @@ FileFormatConfig, FileInfos, IOConfig, + JoinStrategy, JoinType, ) from daft.daft import LogicalPlanBuilder as _LogicalPlanBuilder @@ -187,6 +188,7 @@ def join( # type: ignore[override] left_on: list[Expression], right_on: list[Expression], how: JoinType = JoinType.Inner, + strategy: JoinStrategy | None = None, ) -> LogicalPlanBuilder: if how == JoinType.Left: raise NotImplementedError("Left join not implemented.") @@ -198,6 +200,7 @@ def join( # type: ignore[override] [expr._expr for expr in left_on], [expr._expr for expr in right_on], how, + strategy, ) return LogicalPlanBuilder(builder) else: diff --git a/daft/plan_scheduler/physical_plan_scheduler.py b/daft/plan_scheduler/physical_plan_scheduler.py index 0668597b19..a45d208828 100644 --- a/daft/plan_scheduler/physical_plan_scheduler.py +++ b/daft/plan_scheduler/physical_plan_scheduler.py @@ -1,5 +1,6 @@ from __future__ import annotations +from daft.daft import PartitionSpec from daft.daft import PhysicalPlanScheduler as _PhysicalPlanScheduler from daft.execution import physical_plan from daft.runners.partitioning import PartitionT @@ -16,6 +17,9 @@ def __init__(self, 
scheduler: _PhysicalPlanScheduler): def num_partitions(self) -> int: return self._scheduler.num_partitions() + def partition_spec(self) -> PartitionSpec: + return self._scheduler.partition_spec() + def to_partition_tasks( self, psets: dict[str, list[PartitionT]], is_ray_runner: bool ) -> physical_plan.MaterializedPhysicalPlan: diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index 45e86c0d54..e43d6eaee1 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -10,6 +10,7 @@ import pyarrow as pa from daft.datatype import TimeUnit +from daft.expressions.expressions import Expression from daft.logical.schema import Schema from daft.table import MicroPartition @@ -74,20 +75,97 @@ class TableParseParquetOptions: class PartialPartitionMetadata: num_rows: None | int size_bytes: None | int + boundaries: None | Boundaries = None @dataclass(frozen=True) class PartitionMetadata(PartialPartitionMetadata): num_rows: int size_bytes: int | None + boundaries: Boundaries | None = None @classmethod def from_table(cls, table: MicroPartition) -> PartitionMetadata: return PartitionMetadata( num_rows=len(table), size_bytes=table.size_bytes(), + boundaries=None, ) + def merge_with_partial(self, partial_metadata: PartialPartitionMetadata) -> PartitionMetadata: + num_rows = self.num_rows + size_bytes = self.size_bytes + boundaries = self.boundaries + if boundaries is None: + boundaries = partial_metadata.boundaries + return PartitionMetadata(num_rows, size_bytes, boundaries) + + def downcast_to_partial(self) -> PartialPartitionMetadata: + return PartialPartitionMetadata(self.num_rows, self.size_bytes, self.boundaries) + + +def _is_bound_null(bound_row: list[Any | None]) -> bool: + return all(bound is None for bound in bound_row) + + +# TODO(Clark): Port this to the Rust side. +@dataclass(frozen=True) +class Boundaries: + sort_by: list[Expression] + bounds: MicroPartition + + def __post_init__(self): + assert len(self.sort_by) > 0 + assert len(self.bounds) == 2 + assert self.bounds.column_names() == [e.name() for e in self.sort_by] + + def intersects(self, other: Boundaries) -> bool: + if self.is_trivial_bounds() or other.is_trivial_bounds(): + return True + self_bounds = self.bounds.to_pylist() + other_bounds = other.bounds.to_pylist() + self_lower = list(self_bounds[0].values()) + self_upper = list(self_bounds[1].values()) + other_lower = list(other_bounds[0].values()) + other_upper = list(other_bounds[1].values()) + if _is_bound_null(self_lower): + return _is_bound_null(other_lower) or other_lower <= self_upper + if _is_bound_null(other_lower): + return self_lower <= other_upper + if _is_bound_null(self_upper): + return _is_bound_null(other_upper) or other_upper >= self_lower + if _is_bound_null(other_upper): + return self_upper >= other_lower + return (self_lower <= other_lower and self_upper >= other_lower) or ( + self_lower > other_lower and other_upper >= self_lower + ) + + def is_disjointly_bounded_above_by(self, other: Boundaries) -> bool: + # Check that upper of self is less than lower of other. 
+ self_upper = list(self.bounds.to_pylist()[1].values()) + if _is_bound_null(self_upper): + return False + other_lower = list(other.bounds.to_pylist()[0].values()) + if _is_bound_null(other_lower): + return False + return self_upper < other_lower + + def is_trivial_bounds(self) -> bool: + bounds = self.bounds.to_pylist() + lower = list(bounds[0].values()) + upper = list(bounds[1].values()) + return _is_bound_null(lower) and _is_bound_null(upper) + + def is_strictly_bounded_above_by(self, other: Boundaries) -> bool: + # Check that upper of self is less than upper of other. + self_upper = list(self.bounds.to_pylist()[1].values()) + if _is_bound_null(self_upper): + return False + other_upper = list(other.bounds.to_pylist()[1].values()) + if _is_bound_null(other_upper): + return True + return self_upper < other_upper + PartitionT = TypeVar("PartitionT") diff --git a/daft/table/micropartition.py b/daft/table/micropartition.py index 94d7ebfad7..5782b4810a 100644 --- a/daft/table/micropartition.py +++ b/daft/table/micropartition.py @@ -230,7 +230,7 @@ def explode(self, columns: ExpressionsProjection) -> MicroPartition: to_explode_pyexprs = [e._expr for e in columns] return MicroPartition._from_pymicropartition(self._micropartition.explode(to_explode_pyexprs)) - def join( + def hash_join( self, right: MicroPartition, left_on: ExpressionsProjection, @@ -251,7 +251,34 @@ def join( right_exprs = [e._expr for e in right_on] return MicroPartition._from_pymicropartition( - self._micropartition.join(right._micropartition, left_on=left_exprs, right_on=right_exprs) + self._micropartition.hash_join(right._micropartition, left_on=left_exprs, right_on=right_exprs) + ) + + def sort_merge_join( + self, + right: MicroPartition, + left_on: ExpressionsProjection, + right_on: ExpressionsProjection, + how: JoinType = JoinType.Inner, + is_sorted: bool = False, + ) -> MicroPartition: + if how != JoinType.Inner: + raise NotImplementedError("TODO: [RUST] Implement Other Join types") + if len(left_on) != len(right_on): + raise ValueError( + f"Mismatch of number of join keys, left_on: {len(left_on)}, right_on: {len(right_on)}\nleft_on {left_on}\nright_on {right_on}" + ) + + if not isinstance(right, MicroPartition): + raise TypeError(f"Expected a MicroPartition for `right` in join but got {type(right)}") + + left_exprs = [e._expr for e in left_on] + right_exprs = [e._expr for e in right_on] + + return MicroPartition._from_pymicropartition( + self._micropartition.sort_merge_join( + right._micropartition, left_on=left_exprs, right_on=right_exprs, is_sorted=is_sorted + ) ) def partition_by_hash(self, exprs: ExpressionsProjection, num_partitions: int) -> list[MicroPartition]: diff --git a/daft/table/table.py b/daft/table/table.py index 6f7e0e972c..787e56aa0a 100644 --- a/daft/table/table.py +++ b/daft/table/table.py @@ -314,7 +314,7 @@ def explode(self, columns: ExpressionsProjection) -> Table: to_explode_pyexprs = [e._expr for e in columns] return Table._from_pytable(self._table.explode(to_explode_pyexprs)) - def join( + def hash_join( self, right: Table, left_on: ExpressionsProjection, @@ -334,7 +334,32 @@ def join( left_exprs = [e._expr for e in left_on] right_exprs = [e._expr for e in right_on] - return Table._from_pytable(self._table.join(right._table, left_on=left_exprs, right_on=right_exprs)) + return Table._from_pytable(self._table.hash_join(right._table, left_on=left_exprs, right_on=right_exprs)) + + def sort_merge_join( + self, + right: Table, + left_on: ExpressionsProjection, + right_on: ExpressionsProjection, + 
how: JoinType = JoinType.Inner,
+        is_sorted: bool = False,
+    ) -> Table:
+        if how != JoinType.Inner:
+            raise NotImplementedError("TODO: [RUST] Implement Other Join types")
+        if len(left_on) != len(right_on):
+            raise ValueError(
+                f"Mismatch of number of join keys, left_on: {len(left_on)}, right_on: {len(right_on)}\nleft_on {left_on}\nright_on {right_on}"
+            )
+
+        if not isinstance(right, Table):
+            raise TypeError(f"Expected a Table for `right` in join but got {type(right)}")
+
+        left_exprs = [e._expr for e in left_on]
+        right_exprs = [e._expr for e in right_on]
+
+        return Table._from_pytable(
+            self._table.sort_merge_join(right._table, left_on=left_exprs, right_on=right_exprs, is_sorted=is_sorted)
+        )

     def partition_by_hash(self, exprs: ExpressionsProjection, num_partitions: int) -> list[Table]:
         if not isinstance(num_partitions, int):
diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs
index 2d312f45f5..59141dba0d 100644
--- a/src/common/daft-config/src/lib.rs
+++ b/src/common/daft-config/src/lib.rs
@@ -25,6 +25,7 @@ pub struct DaftExecutionConfig {
     pub merge_scan_tasks_min_size_bytes: usize,
     pub merge_scan_tasks_max_size_bytes: usize,
     pub broadcast_join_size_bytes_threshold: usize,
+    pub sort_merge_join_sort_with_aligned_boundaries: bool,
     pub sample_size_for_sort: usize,
     pub num_preview_rows: usize,
     pub parquet_target_filesize: usize,
@@ -40,6 +41,7 @@ impl Default for DaftExecutionConfig {
             merge_scan_tasks_min_size_bytes: 64 * 1024 * 1024, // 64MB
             merge_scan_tasks_max_size_bytes: 512 * 1024 * 1024, // 512MB
             broadcast_join_size_bytes_threshold: 10 * 1024 * 1024, // 10 MiB
+            sort_merge_join_sort_with_aligned_boundaries: false,
             sample_size_for_sort: 20,
             num_preview_rows: 8,
             parquet_target_filesize: 512 * 1024 * 1024, // 512MB
diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs
index 0a18523fc6..1428bbe3b5 100644
--- a/src/common/daft-config/src/python.rs
+++ b/src/common/daft-config/src/python.rs
@@ -80,6 +80,7 @@ impl PyDaftExecutionConfig {
         merge_scan_tasks_min_size_bytes: Option<usize>,
         merge_scan_tasks_max_size_bytes: Option<usize>,
         broadcast_join_size_bytes_threshold: Option<usize>,
+        sort_merge_join_sort_with_aligned_boundaries: Option<bool>,
         sample_size_for_sort: Option<usize>,
         num_preview_rows: Option<usize>,
         parquet_target_filesize: Option<usize>,
@@ -99,6 +100,12 @@ impl PyDaftExecutionConfig {
         if let Some(broadcast_join_size_bytes_threshold) = broadcast_join_size_bytes_threshold {
             config.broadcast_join_size_bytes_threshold = broadcast_join_size_bytes_threshold;
         }
+        if let Some(sort_merge_join_sort_with_aligned_boundaries) =
+            sort_merge_join_sort_with_aligned_boundaries
+        {
+            config.sort_merge_join_sort_with_aligned_boundaries =
+                sort_merge_join_sort_with_aligned_boundaries;
+        }
         if let Some(sample_size_for_sort) = sample_size_for_sort {
             config.sample_size_for_sort = sample_size_for_sort;
         }
@@ -141,6 +148,11 @@
         Ok(self.config.broadcast_join_size_bytes_threshold)
     }

+    #[getter]
+    fn get_sort_merge_join_sort_with_aligned_boundaries(&self) -> PyResult<bool> {
+        Ok(self.config.sort_merge_join_sort_with_aligned_boundaries)
+    }
+
     #[getter]
     fn get_sample_size_for_sort(&self) -> PyResult<usize> {
         Ok(self.config.sample_size_for_sort)
diff --git a/src/daft-core/src/kernels/search_sorted.rs b/src/daft-core/src/kernels/search_sorted.rs
index ebe9ce2d45..d46d23f8a7 100644
--- a/src/daft-core/src/kernels/search_sorted.rs
+++ b/src/daft-core/src/kernels/search_sorted.rs
@@ -324,6 +324,39 @@ pub fn build_compare_with_nulls(
     }
 }

+/// Compare the values at two arbitrary indices in two arrays.
+pub type DynPartialComparator = Box<dyn Fn(usize, usize) -> Option<Ordering> + Send + Sync>;
+
+pub fn build_partial_compare_with_nulls(
+    left: &dyn Array,
+    right: &dyn Array,
+    reversed: bool,
+) -> Result<DynPartialComparator> {
+    let comparator = build_compare_with_nan(left, right)?;
+    let left_is_valid = build_is_valid(left);
+    let right_is_valid = build_is_valid(right);
+
+    if reversed {
+        Ok(Box::new(move |i: usize, j: usize| {
+            match (left_is_valid(i), right_is_valid(j)) {
+                (true, true) => Some(comparator(i, j).reverse()),
+                (false, true) => Some(Ordering::Less),
+                (true, false) => Some(Ordering::Greater),
+                (false, false) => None,
+            }
+        }))
+    } else {
+        Ok(Box::new(move |i: usize, j: usize| {
+            match (left_is_valid(i), right_is_valid(j)) {
+                (true, true) => Some(comparator(i, j)),
+                (false, true) => Some(Ordering::Greater),
+                (true, false) => Some(Ordering::Less),
+                (false, false) => None,
+            }
+        }))
+    }
+}
+
 pub fn search_sorted_multi_array(
     sorted_arrays: &Vec<&dyn Array>,
     key_arrays: &Vec<&dyn Array>,
diff --git a/src/daft-micropartition/src/ops/join.rs b/src/daft-micropartition/src/ops/join.rs
index 86338a2380..46690b6717 100644
--- a/src/daft-micropartition/src/ops/join.rs
+++ b/src/daft-micropartition/src/ops/join.rs
@@ -9,8 +9,8 @@ use crate::micropartition::MicroPartition;

 use daft_stats::TruthValue;

 impl MicroPartition {
-    pub fn join(&self, right: &Self, left_on: &[Expr], right_on: &[Expr]) -> DaftResult<Self> {
-        let io_stats = IOStatsContext::new("MicroPartition::join");
+    pub fn hash_join(&self, right: &Self, left_on: &[Expr], right_on: &[Expr]) -> DaftResult<Self> {
+        let io_stats = IOStatsContext::new("MicroPartition::hash_join");

         let join_schema = infer_join_schema(&self.schema, &right.schema, left_on, right_on)?;
         if self.len() == 0 || right.len() == 0 {
@@ -48,7 +48,63 @@ impl MicroPartition {
         match (lt.as_slice(), rt.as_slice()) {
             ([], _) | (_, []) => Ok(Self::empty(Some(join_schema.into()))),
             ([lt], [rt]) => {
-                let joined_table = lt.join(rt, left_on, right_on)?;
+                let joined_table = lt.hash_join(rt, left_on, right_on)?;
+                Ok(MicroPartition::new_loaded(
+                    join_schema.into(),
+                    vec![joined_table].into(),
+                    None,
+                ))
+            }
+            _ => unreachable!(),
+        }
+    }
+
+    pub fn sort_merge_join(
+        &self,
+        right: &Self,
+        left_on: &[Expr],
+        right_on: &[Expr],
+        is_sorted: bool,
+    ) -> DaftResult<Self> {
+        let io_stats = IOStatsContext::new("MicroPartition::sort_merge_join");
+        let join_schema = infer_join_schema(&self.schema, &right.schema, left_on, right_on)?;
+
+        if self.len() == 0 || right.len() == 0 {
+            return Ok(Self::empty(Some(join_schema.into())));
+        }
+
+        let tv = match (&self.statistics, &right.statistics) {
+            (_, None) => TruthValue::Maybe,
+            (None, _) => TruthValue::Maybe,
+            (Some(l), Some(r)) => {
+                let l_eval_stats = l.eval_expression_list(left_on, &self.schema)?;
+                let r_eval_stats = r.eval_expression_list(right_on, &right.schema)?;
+                let mut curr_tv = TruthValue::Maybe;
+                for (lc, rc) in l_eval_stats
+                    .columns
+                    .values()
+                    .zip(r_eval_stats.columns.values())
+                {
+                    if let TruthValue::False = lc.equal(rc)?.to_truth_value() {
+                        curr_tv = TruthValue::False;
+                        break;
+                    }
+                }
+                curr_tv
+            }
+        };
+        if let TruthValue::False = tv {
+            return Ok(Self::empty(Some(join_schema.into())));
+        }
+
+        // TODO(Clark): Elide concatenations where possible by doing a chunk-aware local table join.
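+        // concat_or_get coalesces each side's chunks into at most one contiguous Table,
+        // so the table-level sort_merge_join below can assume a single chunk per side.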
+ let lt = self.concat_or_get(io_stats.clone())?; + let rt = right.concat_or_get(io_stats)?; + + match (lt.as_slice(), rt.as_slice()) { + ([], _) | (_, []) => Ok(Self::empty(Some(join_schema.into()))), + ([lt], [rt]) => { + let joined_table = lt.sort_merge_join(rt, left_on, right_on, is_sorted)?; Ok(MicroPartition::new_loaded( join_schema.into(), vec![joined_table].into(), diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index a1e1f0aae5..7c8800035c 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -222,7 +222,7 @@ impl PyMicroPartition { }) } - pub fn join( + pub fn hash_join( &self, py: Python, right: &Self, @@ -234,7 +234,30 @@ impl PyMicroPartition { py.allow_threads(|| { Ok(self .inner - .join(&right.inner, left_exprs.as_slice(), right_exprs.as_slice())? + .hash_join(&right.inner, left_exprs.as_slice(), right_exprs.as_slice())? + .into()) + }) + } + + pub fn sort_merge_join( + &self, + py: Python, + right: &Self, + left_on: Vec, + right_on: Vec, + is_sorted: bool, + ) -> PyResult { + let left_exprs: Vec = left_on.into_iter().map(|e| e.into()).collect(); + let right_exprs: Vec = right_on.into_iter().map(|e| e.into()).collect(); + py.allow_threads(|| { + Ok(self + .inner + .sort_merge_join( + &right.inner, + left_exprs.as_slice(), + right_exprs.as_slice(), + is_sorted, + )? .into()) }) } diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 4fd2b5ec06..ec28cf5440 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -10,7 +10,7 @@ use crate::{ ExternalInfo as ExternalSourceInfo, FileInfos as InputFileInfos, LegacyExternalInfo, SourceInfo, }, - JoinType, PartitionScheme, PhysicalPlanScheduler, ResourceRequest, + JoinStrategy, JoinType, PartitionScheme, PhysicalPlanScheduler, ResourceRequest, }; use common_error::{DaftError, DaftResult}; use common_io_config::IOConfig; @@ -212,6 +212,7 @@ impl LogicalPlanBuilder { left_on: Vec, right_on: Vec, join_type: JoinType, + join_strategy: Option, ) -> DaftResult { let logical_plan: LogicalPlan = logical_ops::Join::try_new( self.plan.clone(), @@ -219,6 +220,7 @@ impl LogicalPlanBuilder { left_on, right_on, join_type, + join_strategy, )? .into(); Ok(logical_plan.into()) @@ -416,6 +418,7 @@ impl PyLogicalPlanBuilder { left_on: Vec, right_on: Vec, join_type: JoinType, + join_strategy: Option, ) -> PyResult { let left_on = left_on .iter() @@ -427,7 +430,7 @@ impl PyLogicalPlanBuilder { .collect::>(); Ok(self .builder - .join(&right.builder, left_on, right_on, join_type)? + .join(&right.builder, left_on, right_on, join_type, join_strategy)? 
.into()) } diff --git a/src/daft-plan/src/join.rs b/src/daft-plan/src/join.rs index 725d8a000f..0dbf42a9eb 100644 --- a/src/daft-plan/src/join.rs +++ b/src/daft-plan/src/join.rs @@ -61,7 +61,7 @@ impl FromStr for JoinType { "left" => Ok(Left), "right" => Ok(Right), _ => Err(DaftError::TypeError(format!( - "Join type {} is not supported; only the following modes are supported: {:?}", + "Join type {} is not supported; only the following types are supported: {:?}", join_type, JoinType::iterator().as_slice() ))), @@ -75,3 +75,65 @@ impl Display for JoinType { write!(f, "{:?}", self) } } + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] +#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] +pub enum JoinStrategy { + Hash, + SortMerge, + Broadcast, +} + +#[cfg(feature = "python")] +#[pymethods] +impl JoinStrategy { + /// Create a JoinStrategy from its string representation. + /// + /// Args: + /// join_strategy: String representation of the join strategy, e.g. "hash", "sort_merge", or "broadcast". + #[staticmethod] + pub fn from_join_strategy_str(join_strategy: &str) -> PyResult { + Self::from_str(join_strategy).map_err(|e| PyValueError::new_err(e.to_string())) + } + + pub fn __str__(&self) -> PyResult { + Ok(self.to_string()) + } +} + +impl_bincode_py_state_serialization!(JoinStrategy); + +impl JoinStrategy { + pub fn iterator() -> std::slice::Iter<'static, JoinStrategy> { + use JoinStrategy::*; + + static JOIN_STRATEGIES: [JoinStrategy; 3] = [Hash, SortMerge, Broadcast]; + JOIN_STRATEGIES.iter() + } +} + +impl FromStr for JoinStrategy { + type Err = DaftError; + + fn from_str(join_strategy: &str) -> DaftResult { + use JoinStrategy::*; + + match join_strategy { + "hash" => Ok(Hash), + "sort_merge" => Ok(SortMerge), + "broadcast" => Ok(Broadcast), + _ => Err(DaftError::TypeError(format!( + "Join strategy {} is not supported; only the following strategies are supported: {:?}", + join_strategy, + JoinStrategy::iterator().as_slice() + ))), + } + } +} + +impl Display for JoinStrategy { + fn fmt(&self, f: &mut Formatter) -> Result { + // Leverage Debug trait implementation, which will already return the enum variant as a string. 
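+        // E.g. JoinStrategy::SortMerge is rendered as "SortMerge".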
+ write!(f, "{:?}", self) + } +} diff --git a/src/daft-plan/src/lib.rs b/src/daft-plan/src/lib.rs index ec8b4866a3..494e38a83c 100644 --- a/src/daft-plan/src/lib.rs +++ b/src/daft-plan/src/lib.rs @@ -25,7 +25,7 @@ use daft_scan::{ }, storage_config::{NativeStorageConfig, PyStorageConfig}, }; -pub use join::JoinType; +pub use join::{JoinStrategy, JoinType}; pub use logical_plan::LogicalPlan; pub use partitioning::{PartitionScheme, PartitionSpec}; pub use physical_plan::PhysicalPlanScheduler; @@ -48,6 +48,7 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; + parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; diff --git a/src/daft-plan/src/logical_ops/join.rs b/src/daft-plan/src/logical_ops/join.rs index 60f4d9d8f3..8f6af21526 100644 --- a/src/daft-plan/src/logical_ops/join.rs +++ b/src/daft-plan/src/logical_ops/join.rs @@ -10,7 +10,7 @@ use snafu::ResultExt; use crate::{ logical_plan::{self, CreationSnafu}, - JoinType, LogicalPlan, + JoinStrategy, JoinType, LogicalPlan, }; #[derive(Clone, Debug, PartialEq, Eq)] @@ -22,6 +22,7 @@ pub struct Join { pub left_on: Vec, pub right_on: Vec, pub join_type: JoinType, + pub join_strategy: Option, pub output_schema: SchemaRef, // Joins may rename columns from the right input; this struct tracks those renames. @@ -36,6 +37,7 @@ impl std::hash::Hash for Join { std::hash::Hash::hash(&self.left_on, state); std::hash::Hash::hash(&self.right_on, state); std::hash::Hash::hash(&self.join_type, state); + std::hash::Hash::hash(&self.join_strategy, state); std::hash::Hash::hash(&self.output_schema, state); state.write_u64(hash_index_map(&self.right_input_mapping)) } @@ -48,6 +50,7 @@ impl Join { left_on: Vec, right_on: Vec, join_type: JoinType, + join_strategy: Option, ) -> logical_plan::Result { for (on_exprs, schema) in [(&left_on, left.schema()), (&right_on, right.schema())] { let on_fields = on_exprs @@ -101,6 +104,7 @@ impl Join { left_on, right_on, join_type, + join_strategy, output_schema, right_input_mapping, }) @@ -109,6 +113,11 @@ impl Join { pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!("Join: Type = {}", self.join_type)); + res.push(format!( + "Strategy = {}", + self.join_strategy + .map_or_else(|| "Auto".to_string(), |s| s.to_string()) + )); if !self.left_on.is_empty() && !self.right_on.is_empty() && self.left_on == self.right_on { res.push(format!( "On = {}", diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs index f1ef10a3c9..a74d025bcb 100644 --- a/src/daft-plan/src/logical_plan.rs +++ b/src/daft-plan/src/logical_plan.rs @@ -162,7 +162,7 @@ impl LogicalPlan { [input1, input2] => match self { Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"), Self::Concat(_) => Self::Concat(Concat::try_new(input1.clone(), input2.clone()).unwrap()), - Self::Join(Join { left_on, right_on, join_type, .. }) => Self::Join(Join::try_new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), *join_type).unwrap()), + Self::Join(Join { left_on, right_on, join_type, join_strategy, .. 
}) => Self::Join(Join::try_new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), *join_type, *join_strategy).unwrap()), _ => panic!("Logical op {} has one input, but got two", self), }, _ => panic!("Logical ops should never have more than 2 inputs, but got: {}", children.len()) diff --git a/src/daft-plan/src/optimization/rules/push_down_filter.rs b/src/daft-plan/src/optimization/rules/push_down_filter.rs index 9ef921e33f..b80a0f0ab0 100644 --- a/src/daft-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-plan/src/optimization/rules/push_down_filter.rs @@ -537,11 +537,12 @@ mod tests { vec![col("b")], vec![col("b")], JoinType::Inner, + None, )? .filter(col("a").lt(&lit(2)))? .build(); let expected = "\ - Join: Type = Inner, On = col(b), Output schema = a (Int64), b (Utf8), c (Float64)\ + Join: Type = Inner, Strategy = Auto, On = col(b), Output schema = a (Int64), b (Utf8), c (Float64)\ \n Filter: col(a) < lit(2)\ \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig { buffer_size: None, chunk_size: None }), Storage config = Native(NativeStorageConfig { io_config: None, multithreaded_io: true }), Output schema = a (Int64), b (Utf8)\ \n Source: Json, File paths = [/foo], File schema = b (Utf8), c (Float64), Format-specific config = Json(JsonSourceConfig { buffer_size: None, chunk_size: None }), Storage config = Native(NativeStorageConfig { io_config: None, multithreaded_io: true }), Output schema = b (Utf8), c (Float64)"; @@ -564,11 +565,12 @@ mod tests { vec![col("b")], vec![col("b")], JoinType::Inner, + None, )? .filter(col("c").lt(&lit(2.0)))? .build(); let expected = "\ - Join: Type = Inner, On = col(b), Output schema = a (Int64), b (Utf8), c (Float64)\ + Join: Type = Inner, Strategy = Auto, On = col(b), Output schema = a (Int64), b (Utf8), c (Float64)\ \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig { buffer_size: None, chunk_size: None }), Storage config = Native(NativeStorageConfig { io_config: None, multithreaded_io: true }), Output schema = a (Int64), b (Utf8)\ \n Filter: col(c) < lit(2.0)\ \n Source: Json, File paths = [/foo], File schema = b (Utf8), c (Float64), Format-specific config = Json(JsonSourceConfig { buffer_size: None, chunk_size: None }), Storage config = Native(NativeStorageConfig { io_config: None, multithreaded_io: true }), Output schema = b (Utf8), c (Float64)"; @@ -589,11 +591,12 @@ mod tests { vec![col("b")], vec![col("b")], JoinType::Inner, + None, )? .filter(col("b").lt(&lit(2)))? 
.build(); let expected = "\ - Join: Type = Inner, On = col(b), Output schema = a (Int64), b (Int64), c (Float64)\ + Join: Type = Inner, Strategy = Auto, On = col(b), Output schema = a (Int64), b (Int64), c (Float64)\ \n Filter: col(b) < lit(2)\ \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Int64), c (Float64), Format-specific config = Json(JsonSourceConfig { buffer_size: None, chunk_size: None }), Storage config = Native(NativeStorageConfig { io_config: None, multithreaded_io: true }), Output schema = a (Int64), b (Int64), c (Float64)\ \n Filter: col(b) < lit(2)\ diff --git a/src/daft-plan/src/physical_ops/mod.rs b/src/daft-plan/src/physical_ops/mod.rs index 214bcdb3df..08ba46ea34 100644 --- a/src/daft-plan/src/physical_ops/mod.rs +++ b/src/daft-plan/src/physical_ops/mod.rs @@ -19,6 +19,7 @@ mod reduce; mod sample; mod scan; mod sort; +mod sort_merge_join; mod split; pub use agg::Aggregate; @@ -42,4 +43,5 @@ pub use reduce::ReduceMerge; pub use sample::Sample; pub use scan::TabularScan; pub use sort::Sort; +pub use sort_merge_join::SortMergeJoin; pub use split::Split; diff --git a/src/daft-plan/src/physical_ops/sort_merge_join.rs b/src/daft-plan/src/physical_ops/sort_merge_join.rs new file mode 100644 index 0000000000..ec0c381c4d --- /dev/null +++ b/src/daft-plan/src/physical_ops/sort_merge_join.rs @@ -0,0 +1,44 @@ +use std::sync::Arc; + +use daft_dsl::Expr; + +use crate::{physical_plan::PhysicalPlan, JoinType}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SortMergeJoin { + // Upstream node. + pub left: Arc, + pub right: Arc, + pub left_on: Vec, + pub right_on: Vec, + pub join_type: JoinType, + pub num_partitions: usize, + pub left_is_larger: bool, + pub needs_presort: bool, +} + +impl SortMergeJoin { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + left: Arc, + right: Arc, + left_on: Vec, + right_on: Vec, + join_type: JoinType, + num_partitions: usize, + left_is_larger: bool, + needs_presort: bool, + ) -> Self { + Self { + left, + right, + left_on, + right_on, + join_type, + num_partitions, + left_is_larger, + needs_presort, + } + } +} diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 21bf768a17..bd7b56c7ac 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -55,6 +55,7 @@ pub enum PhysicalPlan { Aggregate(Aggregate), Concat(Concat), HashJoin(HashJoin), + SortMergeJoin(SortMergeJoin), BroadcastJoin(BroadcastJoin), TabularWriteParquet(TabularWriteParquet), TabularWriteJson(TabularWriteJson), @@ -167,6 +168,20 @@ impl PhysicalPlan { Self::BroadcastJoin(BroadcastJoin { receiver: right, .. }) => right.partition_spec(), + Self::SortMergeJoin(SortMergeJoin { + left, + right, + left_on, + .. + }) => PartitionSpec::new_internal( + PartitionScheme::Range, + max( + left.partition_spec().num_partitions, + right.partition_spec().num_partitions, + ), + Some(left_on.clone()), + ) + .into(), Self::TabularWriteParquet(TabularWriteParquet { input, .. }) => input.partition_spec(), Self::TabularWriteCsv(TabularWriteCsv { input, .. }) => input.partition_spec(), Self::TabularWriteJson(TabularWriteJson { input, .. }) => input.partition_spec(), @@ -220,7 +235,8 @@ impl PhysicalPlan { receiver: right, .. }) - | Self::HashJoin(HashJoin { left, right, .. }) => { + | Self::HashJoin(HashJoin { left, right, .. }) + | Self::SortMergeJoin(SortMergeJoin { left, right, .. 
}) => { left.approximate_size_bytes().and_then(|left_size| { right .approximate_size_bytes() @@ -255,6 +271,9 @@ impl PhysicalPlanScheduler { pub fn num_partitions(&self) -> PyResult { self.plan.partition_spec().get_num_partitions() } + pub fn partition_spec(&self) -> PyResult { + Ok(self.plan.partition_spec().as_ref().clone()) + } /// Converts the contained physical plan into an iterator of executable partition tasks. pub fn to_partition_tasks( &self, @@ -719,6 +738,53 @@ impl PhysicalPlan { ))?; Ok(py_iter.into()) } + PhysicalPlan::SortMergeJoin(SortMergeJoin { + left, + right, + left_on, + right_on, + join_type, + num_partitions, + left_is_larger, + needs_presort, + }) => { + let left_iter = left.to_partition_tasks(py, psets, is_ray_runner)?; + let right_iter = right.to_partition_tasks(py, psets, is_ray_runner)?; + let left_on_pyexprs: Vec = left_on + .iter() + .map(|expr| PyExpr::from(expr.clone())) + .collect(); + let right_on_pyexprs: Vec = right_on + .iter() + .map(|expr| PyExpr::from(expr.clone())) + .collect(); + // TODO(Clark): Elide sorting one side of the join if already range-partitioned, where we'd use that side's boundaries to sort the other side. + let py_iter = if *needs_presort { + py.import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? + .getattr(pyo3::intern!(py, "sort_merge_join_aligned_boundaries"))? + .call1(( + left_iter, + right_iter, + left_on_pyexprs, + right_on_pyexprs, + *join_type, + *num_partitions, + *left_is_larger, + ))? + } else { + py.import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? + .getattr(pyo3::intern!(py, "merge_join_sorted"))? + .call1(( + left_iter, + right_iter, + left_on_pyexprs, + right_on_pyexprs, + *join_type, + *left_is_larger, + ))? + }; + Ok(py_iter.into()) + } PhysicalPlan::BroadcastJoin(BroadcastJoin { broadcaster: left, receiver: right, diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 5a92d3248c..d6b06322bf 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -5,6 +5,7 @@ use std::{cmp::max, collections::HashMap}; use common_daft_config::DaftExecutionConfig; use common_error::DaftResult; use daft_core::count_mode::CountMode; +use daft_core::DataType; use daft_dsl::Expr; use daft_scan::file_format::FileFormatConfig; use daft_scan::ScanExternalInfo; @@ -19,7 +20,7 @@ use crate::logical_plan::LogicalPlan; use crate::physical_plan::PhysicalPlan; use crate::sink_info::{OutputFileInfo, SinkInfo}; use crate::source_info::{ExternalInfo as ExternalSourceInfo, LegacyExternalInfo, SourceInfo}; -use crate::{physical_ops::*, PartitionSpec}; +use crate::{physical_ops::*, JoinStrategy, PartitionSpec}; use crate::{FileFormat, PartitionScheme}; #[cfg(feature = "python")] @@ -511,6 +512,8 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftRe left_on, right_on, join_type, + join_strategy, + output_schema, .. 
}) => { let mut left_physical = plan(left, cfg.clone())?; @@ -519,25 +522,45 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftRe let left_pspec = left_physical.partition_spec(); let right_pspec = right_physical.partition_spec(); let num_partitions = max(left_pspec.num_partitions, right_pspec.num_partitions); - let new_left_pspec = Arc::new(PartitionSpec::new_internal( + let new_left_hash_pspec = Arc::new(PartitionSpec::new_internal( PartitionScheme::Hash, num_partitions, Some(left_on.clone()), )); - let new_right_pspec = Arc::new(PartitionSpec::new_internal( + let new_right_hash_pspec = Arc::new(PartitionSpec::new_internal( PartitionScheme::Hash, num_partitions, Some(right_on.clone()), )); - let is_left_partitioned = left_pspec == new_left_pspec; - let is_right_partitioned = right_pspec == new_right_pspec; + let is_left_hash_partitioned = left_pspec == new_left_hash_pspec; + let is_right_hash_partitioned = right_pspec == new_right_hash_pspec; - // If either the left or right side of the join are very small tables, perform a broadcast join with the - // entire smaller table broadcast to each of the partitions of the larger table. + // Left-side of join is considered to be sort-partitioned on the join key if it is sort-partitioned on a + // sequence of expressions that has the join key as a prefix. + let is_left_sort_partitioned = matches!(left_pspec.scheme, PartitionScheme::Range) + && left_pspec + .by + .as_ref() + .map(|e| { + e.len() >= left_on.len() + && e.iter().zip(left_on.iter()).all(|(e1, e2)| e1 == e2) + }) + .unwrap_or(false); + // Right-side of join is considered to be sort-partitioned on the join key if it is sort-partitioned on a + // sequence of expressions that has the join key as a prefix. + let is_right_sort_partitioned = matches!(right_pspec.scheme, PartitionScheme::Range) + && right_pspec + .by + .as_ref() + .map(|e| { + e.len() >= right_on.len() + && e.iter().zip(right_on.iter()).all(|(e1, e2)| e1 == e2) + }) + .unwrap_or(false); - // Ensure that the left side of the join is the smaller side. - let (smaller_size_bytes, do_swap) = match ( + // For broadcast joins, ensure that the left side of the join is the smaller side. + let (smaller_size_bytes, left_is_larger) = match ( left_physical.approximate_size_bytes(), right_physical.approximate_size_bytes(), ) { @@ -552,46 +575,120 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftRe (None, Some(right_size_bytes)) => (Some(right_size_bytes), true), (None, None) => (None, false), }; - let is_larger_partitioned = if do_swap { - is_left_partitioned + let is_larger_partitioned = if left_is_larger { + is_left_hash_partitioned || is_left_sort_partitioned } else { - is_right_partitioned + is_right_hash_partitioned || is_right_sort_partitioned }; - // If larger table is not already partitioned on the join key AND the smaller table is under broadcast size threshold, use broadcast join. - if !is_larger_partitioned && let Some(smaller_size_bytes) = smaller_size_bytes && smaller_size_bytes <= cfg.broadcast_join_size_bytes_threshold { - if do_swap { - // These will get swapped back when doing the actual local joins. 
-                    (left_physical, right_physical) = (right_physical, left_physical);
+                let join_strategy = join_strategy.unwrap_or_else(|| {
+                    let is_primitive = |exprs: &Vec<Expr>| exprs.iter().map(|e| e.name().unwrap()).all(|col| {
+                        let dtype = &output_schema.get_field(col).unwrap().dtype;
+                        dtype.is_integer() || dtype.is_floating() || matches!(dtype, DataType::Utf8 | DataType::Binary | DataType::Boolean)
+                    });
+                    // If larger table is not already partitioned on the join key AND the smaller table is under broadcast size threshold, use broadcast join.
+                    if !is_larger_partitioned && let Some(smaller_size_bytes) = smaller_size_bytes && smaller_size_bytes <= cfg.broadcast_join_size_bytes_threshold {
+                        JoinStrategy::Broadcast
+                    // Larger side of join is range-partitioned on the join column, so we use a sort-merge join.
+                    // TODO(Clark): Support non-primitive dtypes for sort-merge join (e.g. temporal types).
+                    // TODO(Clark): Also do a sort-merge join if a downstream op needs the table to be sorted on the join key.
+                    // TODO(Clark): Look into defaulting to sort-merge join over hash join under more input partitioning setups.
+                    } else if is_primitive(left_on) && is_primitive(right_on) && (is_left_sort_partitioned || is_right_sort_partitioned)
+                        && (!is_larger_partitioned
+                            || (left_is_larger && is_left_sort_partitioned
+                                || !left_is_larger && is_right_sort_partitioned)) {
+                        JoinStrategy::SortMerge
+                    // Otherwise, use a hash join.
+                    } else {
+                        JoinStrategy::Hash
+                    }
+                });
+                match join_strategy {
+                    JoinStrategy::Broadcast => {
+                        // If either the left or right side of the join is a very small table, perform a broadcast
+                        // join with the entire smaller table broadcast to each of the partitions of the larger table.
+                        if left_is_larger {
+                            // These will get swapped back when doing the actual local joins.
+                            (left_physical, right_physical) = (right_physical, left_physical);
+                        }
+                        Ok(PhysicalPlan::BroadcastJoin(BroadcastJoin::new(
+                            left_physical.into(),
+                            right_physical.into(),
+                            left_on.clone(),
+                            right_on.clone(),
+                            *join_type,
+                            left_is_larger,
+                        )))
+                    }
+                    JoinStrategy::SortMerge => {
+                        let needs_presort = if cfg.sort_merge_join_sort_with_aligned_boundaries {
+                            // Use the special-purpose presorting that ensures join inputs are sorted with aligned
+                            // boundaries, allowing for a more efficient downstream merge-join (~one-to-one zip).
+                            !is_left_sort_partitioned || !is_right_sort_partitioned
+                        } else {
+                            // Manually insert presorting ops for each side of the join that needs it.
+                            // Note that these merge-join inputs will most likely not have aligned boundaries, which could
+                            // result in less efficient merge-joins (~all-to-all broadcast).
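+                            // Hypothetical illustration: with aligned boundaries, both sides
+                            // are split at, say, keys {10, 20}, and each merge task zips left
+                            // partition i with right partition i; with unaligned boundaries
+                            // (e.g. the right side split at {15, 25}), one left partition can
+                            // overlap several right partitions, so a merge-join task must be
+                            // emitted per overlapping pair.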
+ if !is_left_sort_partitioned { + left_physical = PhysicalPlan::Sort(Sort::new( + left_physical.into(), + left_on.clone(), + std::iter::repeat(false).take(left_on.len()).collect(), + num_partitions, + )) + } + if !is_right_sort_partitioned { + right_physical = PhysicalPlan::Sort(Sort::new( + right_physical.into(), + right_on.clone(), + std::iter::repeat(false).take(right_on.len()).collect(), + num_partitions, + )) + } + false + }; + Ok(PhysicalPlan::SortMergeJoin(SortMergeJoin::new( + left_physical.into(), + right_physical.into(), + left_on.clone(), + right_on.clone(), + *join_type, + num_partitions, + left_is_larger, + needs_presort, + ))) + } + JoinStrategy::Hash => { + if (num_partitions > 1 || left_pspec.num_partitions != num_partitions) + && !is_left_hash_partitioned + { + let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( + left_physical.into(), + num_partitions, + left_on.clone(), + )); + left_physical = + PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op.into())); + } + if (num_partitions > 1 || right_pspec.num_partitions != num_partitions) + && !is_right_hash_partitioned + { + let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( + right_physical.into(), + num_partitions, + right_on.clone(), + )); + right_physical = + PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op.into())); + } + Ok(PhysicalPlan::HashJoin(HashJoin::new( + left_physical.into(), + right_physical.into(), + left_on.clone(), + right_on.clone(), + *join_type, + ))) } - return Ok(PhysicalPlan::BroadcastJoin(BroadcastJoin::new(left_physical.into(), right_physical.into(), left_on.clone(), right_on.clone(), *join_type, do_swap))); - } - if (num_partitions > 1 || left_pspec.num_partitions != num_partitions) - && !is_left_partitioned - { - let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( - left_physical.into(), - num_partitions, - left_on.clone(), - )); - left_physical = PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op.into())); - } - if (num_partitions > 1 || right_pspec.num_partitions != num_partitions) - && !is_right_partitioned - { - let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( - right_physical.into(), - num_partitions, - right_on.clone(), - )); - right_physical = PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op.into())); } - Ok(PhysicalPlan::HashJoin(HashJoin::new( - left_physical.into(), - right_physical.into(), - left_on.clone(), - right_on.clone(), - *join_type, - ))) } LogicalPlan::Sink(LogicalSink { schema, diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 9cb4bb96a7..ef2a4d2a8a 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -1,4 +1,5 @@ #![feature(hash_raw_entry)] +#![feature(let_chains)] use std::borrow::Cow; use std::collections::HashSet; diff --git a/src/daft-table/src/ops/joins/merge_join.rs b/src/daft-table/src/ops/joins/merge_join.rs new file mode 100644 index 0000000000..1cd3795819 --- /dev/null +++ b/src/daft-table/src/ops/joins/merge_join.rs @@ -0,0 +1,243 @@ +use std::cmp::Ordering; + +use daft_core::{ + array::ops::full::FullNull, + datatypes::{DataType, UInt64Array}, + kernels::search_sorted::build_partial_compare_with_nulls, + series::{IntoSeries, Series}, +}; + +use crate::Table; +use common_error::{DaftError, DaftResult}; + +/// A state machine for the below merge-join algorithm. +/// +/// This state machine is in reference to a pair of left/right pointers into left/right tables of the merge-join. 
+/// +/// Valid state transitions (initial state is Mismatch): +/// +/// ANY -> BothNull +/// {Mismatch, BothNull, StagedLeftEqualRun, StagedRightEqualRun} -> Mismatch +/// ANY -> {LeftEqualRun, RightEqualRun} +/// LeftEqualRun -> StagedLeftEqualRun +/// RightEqualRun -> StagedRightEqualRun +#[derive(Debug)] +enum MergeJoinState { + // Previous pair of rows were not equal. + Mismatch, + // Previous pair of rows were incomparable, i.e. for one or more join keys, they both had null values. + BothNull, + // Currently on a left-side equality run, where the left-side run started at the stored index and is relative to + // a fixed right-side row. + LeftEqualRun(usize), + // Currently on a right-side equality run, where the right-side run started at the stored index and is relative to + // a fixed left-side row. + RightEqualRun(usize), + // A staged left-side equality run starting at the stored index and ending at the current left pointer; this run + // may be equal to one or more future right-side rows. + StagedLeftEqualRun(usize), + // A staged right-side equality run starting at the stored index and ending at the current right pointer; this run + // may be equal to one or more future left-side rows. + StagedRightEqualRun(usize), +} + +pub fn merge_inner_join(left: &Table, right: &Table) -> DaftResult<(Series, Series)> { + if left.num_columns() != right.num_columns() { + return Err(DaftError::ValueError(format!( + "Mismatch of join on clauses: left: {:?} vs right: {:?}", + left.num_columns(), + right.num_columns() + ))); + } + if left.num_columns() == 0 { + return Err(DaftError::ValueError( + "No columns were passed in to join on".to_string(), + )); + } + + // Short-circuit if any of the join keys are all-null (i.e. have the null dtype). + let has_null_type = left.columns.iter().any(|s| s.data_type().is_null()) + || right.columns.iter().any(|s| s.data_type().is_null()); + if has_null_type { + return Ok(( + UInt64Array::empty("left_indices", &DataType::UInt64).into_series(), + UInt64Array::empty("right_indices", &DataType::UInt64).into_series(), + )); + } + let types_not_match = left + .columns + .iter() + .zip(right.columns.iter()) + .any(|(l, r)| l.data_type() != r.data_type()); + if types_not_match { + return Err(DaftError::SchemaMismatch( + "Types between left and right do not match".to_string(), + )); + } + + // TODO(Clark): If one of the tables is much larger than the other, iterate through smaller table while doing + // binary search on the larger table. + + // Construct comparator over all join keys. + let mut cmp_list = Vec::with_capacity(left.num_columns()); + for (left_series, right_series) in left.columns.iter().zip(right.columns.iter()) { + cmp_list.push(build_partial_compare_with_nulls( + left_series.to_arrow().as_ref(), + right_series.to_arrow().as_ref(), + false, + )?); + } + let combined_comparator = |a_idx: usize, b_idx: usize| -> Option { + for comparator in cmp_list.iter() { + match comparator(a_idx, b_idx) { + Some(Ordering::Equal) => continue, + other => return other, + } + } + Some(Ordering::Equal) + }; + + // Short-circuit if tables are empty or range-wise disjoint on join keys. 
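+    // E.g. (hypothetical keys): if the left table's maximum key is 5 and the right
+    // table's minimum key is 7, no row pair can compare equal, so we return empty
+    // take-index vectors without scanning either side.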
+    if left.is_empty()
+        || right.is_empty()
+        || matches!(combined_comparator(left.len() - 1, 0), Some(Ordering::Less))
+        || matches!(
+            combined_comparator(0, right.len() - 1),
+            Some(Ordering::Greater)
+        )
+    {
+        return Ok((
+            UInt64Array::empty("left_indices", &DataType::UInt64).into_series(),
+            UInt64Array::empty("right_indices", &DataType::UInt64).into_series(),
+        ));
+    }
+
+    // Perform the merge by building up left-side and right-side take index vectors.
+    let mut left_indices = vec![];
+    let mut right_indices = vec![];
+    let mut left_idx = 0;
+    let mut right_idx = 0;
+
+    // The current state of the merge-join.
+    let mut state = MergeJoinState::Mismatch;
+    while left_idx < left.len() && right_idx < right.len() {
+        match combined_comparator(left_idx, right_idx) {
+            // Left row is less than right row, so need to move to next left row for potential match.
+            Some(Ordering::Less) => {
+                state = match state {
+                    // If we previously had a right-side run of rows equal to a fixed left-side row, we move the right
+                    // pointer back to the last row of that run, and stage the run for the comparison of said last row
+                    // of the run with the next left-side row.
+                    // If that next comparison comes out to be equal, we will do a bulk push of the right-side run with
+                    // the new left-side row without having to compare the new left-side row with every row in the
+                    // right-side run.
+                    MergeJoinState::RightEqualRun(start_right_idx) => {
+                        right_idx -= 1;
+                        MergeJoinState::StagedRightEqualRun(start_right_idx)
+                    }
+                    // Any other previous states shouldn't matter going forward, so this is a plain mismatch.
+                    _ => MergeJoinState::Mismatch,
+                };
+                left_idx += 1;
+            }
+            // Right row is less than left row, so need to move to next right row for potential match.
+            Some(Ordering::Greater) => {
+                state = match state {
+                    // If we previously had a left-side run of rows equal to a fixed right-side row, we move the left
+                    // pointer back to the last row of that run, and stage the run for the comparison of said last
+                    // row of the run with the next right-side row.
+                    // If that next comparison comes out to be equal, we will do a bulk push of the left-side run with
+                    // the new right-side row without having to compare the new right-side row with every row in the
+                    // left-side run.
+                    MergeJoinState::LeftEqualRun(start_left_idx) => {
+                        left_idx -= 1;
+                        MergeJoinState::StagedLeftEqualRun(start_left_idx)
+                    }
+                    // Any other previous states shouldn't matter going forward, so this is a plain mismatch.
+                    _ => MergeJoinState::Mismatch,
+                };
+                right_idx += 1;
+            }
+            // Left row is equal to the right row, so we need to add this pair of indices to the output indices.
+            Some(Ordering::Equal) => {
+                // First, handle past equal runs in bulk as a comparison-eliding optimization.
+                match state {
+                    // If there was a staged left-side run, then we know that all rows in the run are equal to this
+                    // new right-side row, so we add all such pairs to the output indices without any extra comparisons.
+                    MergeJoinState::StagedLeftEqualRun(start_left_idx) => {
+                        left_indices.extend((start_left_idx..left_idx).map(|i| i as u64));
+                        right_indices.extend(
+                            std::iter::repeat(right_idx as u64).take(left_idx - start_left_idx),
+                        );
+                    }
+                    // If there was a staged right-side run, then we know that all rows in the run are equal to this
+                    // new left-side row, so we add all such pairs to the output indices without any extra comparisons.
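+                    // Hypothetical example: the right side holds the key run [7, 7, 7]; once
+                    // that run is staged, the next equal left-side 7 emits all three
+                    // (left, right) index pairs via the two bulk extend() calls below,
+                    // instead of re-comparing the new row against every row in the run.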
+                    MergeJoinState::StagedRightEqualRun(start_right_idx) => {
+                        left_indices.extend(
+                            std::iter::repeat(left_idx as u64).take(right_idx - start_right_idx),
+                        );
+                        right_indices.extend((start_right_idx..right_idx).map(|i| i as u64));
+                    }
+                    _ => {}
+                }
+                // Add current pointer pair to output indices.
+                left_indices.push(left_idx as u64);
+                right_indices.push(right_idx as u64);
+                // Update state.
+                state = match state {
+                    // If already in a left-side equality run but we've reached the end of the left-side table,
+                    // we can't extend the run anymore, so we stage it.
+                    MergeJoinState::LeftEqualRun(start_left_idx) if left_idx == left.len() - 1 => {
+                        MergeJoinState::StagedLeftEqualRun(start_left_idx)
+                    }
+                    // If already in a right-side equality run but we've reached the end of the right-side table,
+                    // we can't extend the run anymore, so we stage it.
+                    MergeJoinState::RightEqualRun(start_right_idx)
+                        if right_idx == right.len() - 1 =>
+                    {
+                        MergeJoinState::StagedRightEqualRun(start_right_idx)
+                    }
+                    // If already in or just used a left-/right-side equality run and we're not at the end of the
+                    // corresponding table, this equal comparison extends a current run or suggests we can keep
+                    // applying a staged run; in either case, the state is unchanged.
+                    MergeJoinState::LeftEqualRun(_)
+                    | MergeJoinState::RightEqualRun(_)
+                    | MergeJoinState::StagedLeftEqualRun(_)
+                    | MergeJoinState::StagedRightEqualRun(_) => state,
+                    // If coming from a non-active equality run, start a new run.
+                    MergeJoinState::Mismatch | MergeJoinState::BothNull => {
+                        // We assume that larger tables will have longer equality runs.
+                        if left.len() >= right.len() {
+                            MergeJoinState::LeftEqualRun(left_idx)
+                        } else {
+                            MergeJoinState::RightEqualRun(right_idx)
+                        }
+                    }
+                };
+                // Move the pointer forward for the appropriate side of the join.
+                match state {
+                    // If extending a left-side run or propagating an existing right-side run, move left pointer forward.
+                    MergeJoinState::LeftEqualRun(_) | MergeJoinState::StagedRightEqualRun(_) => {
+                        left_idx += 1
+                    }
+                    // If extending a right-side run or propagating an existing left-side run, move right pointer forward.
+                    MergeJoinState::RightEqualRun(_) | MergeJoinState::StagedLeftEqualRun(_) => {
+                        right_idx += 1
+                    }
+                    _ => unreachable!(),
+                }
+            }
+            // Rows are not comparable, i.e. both rows are null for at least one of the join keys.
+            None => {
+                // Update the state.
+                state = MergeJoinState::BothNull;
+                // Advance past the nulls.
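+                // For an inner join, rows that are null on a join key can never match,
+                // so both sides can safely skip them (the comparator returns None only
+                // when both rows are null for some key).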
+ left_idx += 1; + right_idx += 1; + } + } + } + let left_series = UInt64Array::from(("left_indices", left_indices)); + let right_series = UInt64Array::from(("right_indices", right_indices)); + Ok((left_series.into_series(), right_series.into_series())) +} diff --git a/src/daft-table/src/ops/joins/mod.rs b/src/daft-table/src/ops/joins/mod.rs index 081b3aa52a..bf69bd11c4 100644 --- a/src/daft-table/src/ops/joins/mod.rs +++ b/src/daft-table/src/ops/joins/mod.rs @@ -1,6 +1,6 @@ use std::collections::{HashMap, HashSet}; -use daft_core::{schema::Schema, utils::supertype::try_get_supertype}; +use daft_core::{schema::Schema, utils::supertype::try_get_supertype, Series}; use common_error::{DaftError, DaftResult}; use daft_dsl::Expr; @@ -8,6 +8,7 @@ use daft_dsl::Expr; use crate::Table; mod hash_join; +mod merge_join; fn match_types_for_tables(left: &Table, right: &Table) -> DaftResult<(Table, Table)> { let mut lseries = vec![]; @@ -103,15 +104,64 @@ pub fn infer_join_schema( } impl Table { - pub fn join(&self, right: &Self, left_on: &[Expr], right_on: &[Expr]) -> DaftResult { + pub fn hash_join(&self, right: &Self, left_on: &[Expr], right_on: &[Expr]) -> DaftResult { + self.join(right, left_on, right_on, hash_join::hash_inner_join) + } + + pub fn sort_merge_join( + &self, + right: &Self, + left_on: &[Expr], + right_on: &[Expr], + is_sorted: bool, + ) -> DaftResult { + if is_sorted { + self.join(right, left_on, right_on, merge_join::merge_inner_join) + } else { + if left_on.is_empty() { + return Err(DaftError::ValueError( + "No columns were passed in to join on".to_string(), + )); + } + let left = self.sort( + left_on, + std::iter::repeat(false) + .take(left_on.len()) + .collect::>() + .as_slice(), + )?; + if right_on.is_empty() { + return Err(DaftError::ValueError( + "No columns were passed in to join on".to_string(), + )); + } + let right = right.sort( + right_on, + std::iter::repeat(false) + .take(right_on.len()) + .collect::>() + .as_slice(), + )?; + left.join(&right, left_on, right_on, merge_join::merge_inner_join) + } + } + + fn join( + &self, + right: &Self, + left_on: &[Expr], + right_on: &[Expr], + inner_join: impl Fn(&Table, &Table) -> DaftResult<(Series, Series)>, + ) -> DaftResult { let join_schema = infer_join_schema(&self.schema, &right.schema, left_on, right_on)?; + if self.is_empty() || right.is_empty() { + return Self::empty(Some(join_schema.into())); + } let ltable = self.eval_expression_list(left_on)?; let rtable = right.eval_expression_list(right_on)?; let (ltable, rtable) = match_types_for_tables(<able, &rtable)?; - - let (lidx, ridx) = hash_join::hash_inner_join(<able, &rtable)?; - + let (lidx, ridx) = inner_join(<able, &rtable)?; let mut join_fields = ltable .column_names() .iter() @@ -131,6 +181,7 @@ impl Table { names_so_far.insert(f.name.clone()); }); + // TODO(Clark): Parallelize with rayon. for field in self.schema.fields.values() { if names_so_far.contains(&field.name) { continue; @@ -154,6 +205,7 @@ impl Table { let right_to_left_keys: HashMap<&str, &str> = HashMap::from_iter(zipped_names.iter().copied()); + // TODO(Clark): Parallelize with Rayon. 
for field in right.schema.fields.values() { // Skip fields if they were used in the join and have the same name as the corresponding left field match right_to_left_keys.get(field.name.as_str()) { diff --git a/src/daft-table/src/python.rs b/src/daft-table/src/python.rs index 24b2a568fe..98601adf38 100644 --- a/src/daft-table/src/python.rs +++ b/src/daft-table/src/python.rs @@ -95,7 +95,7 @@ impl PyTable { }) } - pub fn join( + pub fn hash_join( &self, py: Python, right: &Self, @@ -107,7 +107,30 @@ impl PyTable { py.allow_threads(|| { Ok(self .table - .join(&right.table, left_exprs.as_slice(), right_exprs.as_slice())? + .hash_join(&right.table, left_exprs.as_slice(), right_exprs.as_slice())? + .into()) + }) + } + + pub fn sort_merge_join( + &self, + py: Python, + right: &Self, + left_on: Vec, + right_on: Vec, + is_sorted: bool, + ) -> PyResult { + let left_exprs: Vec = left_on.into_iter().map(|e| e.into()).collect(); + let right_exprs: Vec = right_on.into_iter().map(|e| e.into()).collect(); + py.allow_threads(|| { + Ok(self + .table + .sort_merge_join( + &right.table, + left_exprs.as_slice(), + right_exprs.as_slice(), + is_sorted, + )? .into()) }) } diff --git a/tests/conftest.py b/tests/conftest.py index 0515611cc4..8262d48f9c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,6 +64,23 @@ def data_source(request): return request.param +@pytest.fixture(scope="function") +def join_strategy(request): + # Modifies the join strategy parametrization to toggle a specialized presorting path for sort-merge joins, where + # each side of the join is sorted such that their boundaries will align. + if request.param != "sort_merge_aligned_boundaries": + yield request.param + else: + old_execution_config = daft.context.get_context().daft_execution_config + try: + daft.set_execution_config( + sort_merge_join_sort_with_aligned_boundaries=True, + ) + yield "sort_merge" + finally: + daft.set_execution_config(old_execution_config) + + @pytest.fixture(scope="function") def make_df(data_source, tmp_path) -> daft.Dataframe: """Makes a dataframe when provided with data""" diff --git a/tests/cookbook/test_joins.py b/tests/cookbook/test_joins.py index 4ec962e328..cbc4743d0b 100644 --- a/tests/cookbook/test_joins.py +++ b/tests/cookbook/test_joins.py @@ -1,14 +1,19 @@ from __future__ import annotations +import pytest + from daft.expressions import col from tests.conftest import assert_df_equals -def test_simple_join(daft_df, service_requests_csv_pd_df, repartition_nparts): +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +def test_simple_join(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): daft_df = daft_df.repartition(repartition_nparts) daft_df_left = daft_df.select(col("Unique Key"), col("Borough")) daft_df_right = daft_df.select(col("Unique Key"), col("Created Date")) - daft_df = daft_df_left.join(daft_df_right, col("Unique Key")) + daft_df = daft_df_left.join(daft_df_right, col("Unique Key"), strategy=join_strategy) service_requests_csv_pd_df_left = service_requests_csv_pd_df[["Unique Key", "Borough"]] service_requests_csv_pd_df_right = service_requests_csv_pd_df[["Unique Key", "Created Date"]] @@ -21,11 +26,14 @@ def test_simple_join(daft_df, service_requests_csv_pd_df, repartition_nparts): assert_df_equals(daft_pd_df, service_requests_csv_pd_df) -def test_simple_self_join(daft_df, service_requests_csv_pd_df, repartition_nparts): +@pytest.mark.parametrize( + "join_strategy", [None, 
"hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +def test_simple_self_join(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): daft_df = daft_df.repartition(repartition_nparts) daft_df = daft_df.select(col("Unique Key"), col("Borough")) - daft_df = daft_df.join(daft_df, col("Unique Key")) + daft_df = daft_df.join(daft_df, col("Unique Key"), strategy=join_strategy) service_requests_csv_pd_df = service_requests_csv_pd_df[["Unique Key", "Borough"]] service_requests_csv_pd_df = ( @@ -38,12 +46,15 @@ def test_simple_self_join(daft_df, service_requests_csv_pd_df, repartition_npart assert_df_equals(daft_pd_df, service_requests_csv_pd_df) -def test_simple_join_missing_rvalues(daft_df, service_requests_csv_pd_df, repartition_nparts): +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +def test_simple_join_missing_rvalues(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): daft_df_right = daft_df.sort("Unique Key").limit(25).repartition(repartition_nparts) daft_df_left = daft_df.repartition(repartition_nparts) daft_df_left = daft_df_left.select(col("Unique Key"), col("Borough")) daft_df_right = daft_df_right.select(col("Unique Key"), col("Created Date")).sort(col("Unique Key")) - daft_df = daft_df_left.join(daft_df_right, col("Unique Key")) + daft_df = daft_df_left.join(daft_df_right, col("Unique Key"), strategy=join_strategy) service_requests_csv_pd_df_left = service_requests_csv_pd_df[["Unique Key", "Borough"]] service_requests_csv_pd_df_right = ( @@ -58,12 +69,15 @@ def test_simple_join_missing_rvalues(daft_df, service_requests_csv_pd_df, repart assert_df_equals(daft_pd_df, service_requests_csv_pd_df) -def test_simple_join_missing_lvalues(daft_df, service_requests_csv_pd_df, repartition_nparts): +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +def test_simple_join_missing_lvalues(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): daft_df_right = daft_df.repartition(repartition_nparts) daft_df_left = daft_df.sort(col("Unique Key")).limit(25).repartition(repartition_nparts) daft_df_left = daft_df_left.select(col("Unique Key"), col("Borough")) daft_df_right = daft_df_right.select(col("Unique Key"), col("Created Date")) - daft_df = daft_df_left.join(daft_df_right, col("Unique Key")) + daft_df = daft_df_left.join(daft_df_right, col("Unique Key"), strategy=join_strategy) service_requests_csv_pd_df_left = ( service_requests_csv_pd_df[["Unique Key", "Borough"]].sort_values(by="Unique Key").head(25) diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index 5c1f61539c..99713c149d 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -3,29 +3,43 @@ import pyarrow as pa import pytest -import daft from daft.datatype import DataType from daft.errors import ExpressionTypeError from tests.utils import sort_arrow_table -@pytest.fixture(params=[False, True]) -def broadcast_join_enabled(request): - # Toggles between default broadcast join threshold (10 MiB), and a threshold of 0, which disables broadcast joins. 
- broadcast_threshold = 10 * 1024 * 1024 if request.param else 0 +@pytest.mark.parametrize("n_partitions", [1, 2, 4]) +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +def test_joins(join_strategy, make_df, n_partitions: int): + df = make_df( + { + "A": [1, 2, 3], + "B": ["a", "b", "c"], + }, + repartition=n_partitions, + repartition_columns=["A"], + ) + + joined = df.join(df, on="A", strategy=join_strategy) + # We shouldn't need to sort the joined output if using a sort-merge join. + if join_strategy != "sort_merge": + joined = joined.sort("A") + joined_data = joined.to_pydict() - old_execution_config = daft.context.get_context().daft_execution_config - try: - daft.set_execution_config( - broadcast_join_size_bytes_threshold=broadcast_threshold, - ) - yield - finally: - daft.set_execution_config(old_execution_config) + assert joined_data == { + "A": [1, 2, 3], + "B": ["a", "b", "c"], + "right.B": ["a", "b", "c"], + } @pytest.mark.parametrize("n_partitions", [1, 2, 4]) -def test_multicol_joins(broadcast_join_enabled, make_df, n_partitions: int): +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +def test_multicol_joins(join_strategy, make_df, n_partitions: int): df = make_df( { "A": [1, 2, 3], @@ -36,7 +50,10 @@ def test_multicol_joins(broadcast_join_enabled, make_df, n_partitions: int): repartition_columns=["A", "B"], ) - joined = df.join(df, on=["A", "B"]).sort("A") + joined = df.join(df, on=["A", "B"], strategy=join_strategy) + # We shouldn't need to sort the joined output if using a sort-merge join. + if join_strategy != "sort_merge": + joined = joined.sort("A") joined_data = joined.to_pydict() assert joined_data == { @@ -47,8 +64,172 @@ def test_multicol_joins(broadcast_join_enabled, make_df, n_partitions: int): } +@pytest.mark.parametrize("n_partitions", [1, 2, 4, 8]) +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +def test_dupes_join_key(join_strategy, make_df, n_partitions: int): + df = make_df( + { + "A": [1, 1, 2, 2, 3, 3], + "B": ["a", "b", "c", "d", "e", "f"], + }, + repartition=n_partitions, + repartition_columns=["A"], + ) + + joined = df.join(df, on="A", strategy=join_strategy) + # We shouldn't need to sort the joined output if using a sort-merge join. + if join_strategy != "sort_merge": + joined = joined.sort("A") + joined_data = joined.to_pydict() + + assert joined_data == { + "A": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], + "B": ["a", "b", "a", "b", "c", "d", "c", "d", "e", "f", "e", "f"], + "right.B": ["a", "a", "b", "b", "c", "c", "d", "d", "e", "e", "f", "f"], + } + + +@pytest.mark.parametrize("n_partitions", [1, 2, 4, 8]) +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +def test_multicol_dupes_join_key(join_strategy, make_df, n_partitions: int): + df = make_df( + { + "A": [1, 1, 2, 2, 3, 3], + "B": ["a", "a", "b", "b", "c", "d"], + "C": [True, False, True, False, True, False], + }, + repartition=n_partitions, + repartition_columns=["A", "B"], + ) + + joined = df.join(df, on=["A", "B"], strategy=join_strategy) + # We shouldn't need to sort the joined output if using a sort-merge join. 
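+    # (A sort-merge join range-partitions and sorts both sides on the join keys,
+    # so its output should already be globally sorted on them.)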
+ if join_strategy != "sort_merge": + joined = joined.sort(["A", "B"]) + joined_data = joined.to_pydict() + + assert joined_data == { + "A": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3], + "B": ["a"] * 4 + ["b"] * 4 + ["c", "d"], + "C": [True, False, True, False, True, False, True, False, True, False], + "right.C": [True, True, False, False, True, True, False, False, True, False], + } + + +@pytest.mark.parametrize("n_partitions", [1, 2, 4, 6]) +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +def test_joins_all_same_key(join_strategy, make_df, n_partitions: int): + df = make_df( + { + "A": [1] * 4, + "B": ["a", "b", "c", "d"], + }, + repartition=n_partitions, + repartition_columns=["A"], + ) + + joined = df.join(df, on="A", strategy=join_strategy) + # We shouldn't need to sort the joined output if using a sort-merge join. + if join_strategy != "sort_merge": + joined = joined.sort("A") + joined_data = joined.to_pydict() + + assert joined_data == { + "A": [1] * 16, + "B": ["a", "b", "c", "d"] * 4, + "right.B": ["a"] * 4 + ["b"] * 4 + ["c"] * 4 + ["d"] * 4, + } + + +@pytest.mark.parametrize("n_partitions", [1, 2, 4]) +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +@pytest.mark.parametrize("flip", [False, True]) +def test_joins_no_overlap_disjoint(join_strategy, make_df, n_partitions: int, flip): + df1 = make_df( + { + "A": [1, 2, 3], + "B": ["a", "b", "c"], + }, + repartition=n_partitions, + repartition_columns=["A"], + ) + df2 = make_df( + { + "A": [4, 5, 6], + "B": ["d", "e", "f"], + }, + repartition=n_partitions, + repartition_columns=["A"], + ) + + if flip: + joined = df2.join(df1, on="A", strategy=join_strategy) + else: + joined = df1.join(df2, on="A", strategy=join_strategy) + # We shouldn't need to sort the joined output if using a sort-merge join. + if join_strategy != "sort_merge": + joined = joined.sort("A") + joined_data = joined.to_pydict() + + assert joined_data == { + "A": [], + "B": [], + "right.B": [], + } + + +@pytest.mark.parametrize("n_partitions", [1, 2, 4]) +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +@pytest.mark.parametrize("flip", [False, True]) +def test_joins_no_overlap_interleaved(join_strategy, make_df, n_partitions: int, flip): + df1 = make_df( + { + "A": [1, 3, 5], + "B": ["a", "b", "c"], + }, + repartition=n_partitions, + repartition_columns=["A"], + ) + df2 = make_df( + { + "A": [2, 4, 6], + "B": ["d", "e", "f"], + }, + repartition=n_partitions, + repartition_columns=["A"], + ) + + if flip: + joined = df2.join(df1, on="A", strategy=join_strategy) + else: + joined = df1.join(df2, on="A", strategy=join_strategy) + # We shouldn't need to sort the joined output if using a sort-merge join. 
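+    # (The interleaved key ranges give the two sorted sides intersecting partition
+    # bounds with zero matching rows, so this case should exercise the merge-join
+    # overlap-dispatch path even though the expected result is empty.)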
+ if join_strategy != "sort_merge": + joined = joined.sort("A") + joined_data = joined.to_pydict() + + assert joined_data == { + "A": [], + "B": [], + "right.B": [], + } + + @pytest.mark.parametrize("n_partitions", [1, 2, 4]) -def test_limit_after_join(broadcast_join_enabled, make_df, n_partitions: int): +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +def test_limit_after_join(join_strategy, make_df, n_partitions: int): data = { "A": [1, 2, 3], } @@ -63,7 +244,7 @@ def test_limit_after_join(broadcast_join_enabled, make_df, n_partitions: int): repartition_columns=["A"], ) - joined = df1.join(df2, on="A").limit(1) + joined = df1.join(df2, on="A", strategy=join_strategy).limit(1) joined_data = joined.to_pydict() assert "A" in joined_data assert len(joined_data["A"]) == 1 @@ -75,7 +256,10 @@ def test_limit_after_join(broadcast_join_enabled, make_df, n_partitions: int): @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) -def test_inner_join(broadcast_join_enabled, make_df, repartition_nparts): +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +def test_inner_join(join_strategy, make_df, repartition_nparts): daft_df = make_df( { "id": [1, None, 3], @@ -90,7 +274,7 @@ def test_inner_join(broadcast_join_enabled, make_df, repartition_nparts): }, repartition=repartition_nparts, ) - daft_df = daft_df.join(daft_df2, on="id", how="inner") + daft_df = daft_df.join(daft_df2, on="id", how="inner", strategy=join_strategy) expected = { "id": [1, 3], @@ -103,7 +287,10 @@ def test_inner_join(broadcast_join_enabled, make_df, repartition_nparts): @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) -def test_inner_join_multikey(broadcast_join_enabled, make_df, repartition_nparts): +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +def test_inner_join_multikey(join_strategy, make_df, repartition_nparts): daft_df = make_df( { "id": [1, None, None], @@ -120,7 +307,7 @@ def test_inner_join_multikey(broadcast_join_enabled, make_df, repartition_nparts }, repartition=repartition_nparts, ) - daft_df = daft_df.join(daft_df2, on=["id", "id2"], how="inner") + daft_df = daft_df.join(daft_df2, on=["id", "id2"], how="inner", strategy=join_strategy) expected = { "id": [1], @@ -134,7 +321,52 @@ def test_inner_join_multikey(broadcast_join_enabled, make_df, repartition_nparts @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) -def test_inner_join_all_null(broadcast_join_enabled, make_df, repartition_nparts): +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +def test_inner_join_asymmetric_multikey(join_strategy, make_df, repartition_nparts): + daft_df = make_df( + { + "left_id": [1, None, None], + "left_id2": ["foo1", "foo2", None], + "values_left": ["a1", "b1", "c1"], + }, + repartition=repartition_nparts, + ) + daft_df2 = make_df( + { + "right_id": [None, None, 1], + "right_id2": ["foo2", None, "foo1"], + "values_right": ["a2", "b2", "c2"], + }, + repartition=repartition_nparts, + ) + daft_df = daft_df.join( + daft_df2, + left_on=["left_id", "left_id2"], + right_on=["right_id", "right_id2"], + how="inner", + strategy=join_strategy, + ) + + expected = { + "left_id": [1], + "left_id2": ["foo1"], + "values_left": ["a1"], + "right_id": [1], + "right_id2": 
["foo1"], + "values_right": ["c2"], + } + assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "left_id") == sort_arrow_table( + pa.Table.from_pydict(expected), "left_id" + ) + + +@pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +def test_inner_join_all_null(join_strategy, make_df, repartition_nparts): daft_df = make_df( { "id": [None, None, None], @@ -149,7 +381,9 @@ def test_inner_join_all_null(broadcast_join_enabled, make_df, repartition_nparts }, repartition=repartition_nparts, ) - daft_df = daft_df.with_column("id", daft_df["id"].cast(DataType.int64())).join(daft_df2, on="id", how="inner") + daft_df = daft_df.with_column("id", daft_df["id"].cast(DataType.int64())).join( + daft_df2, on="id", how="inner", strategy=join_strategy + ) expected = { "id": [], @@ -161,7 +395,10 @@ def test_inner_join_all_null(broadcast_join_enabled, make_df, repartition_nparts ) -def test_inner_join_null_type_column(broadcast_join_enabled, make_df): +@pytest.mark.parametrize( + "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True +) +def test_inner_join_null_type_column(join_strategy, make_df): daft_df = make_df( { "id": [None, None, None], @@ -176,4 +413,4 @@ def test_inner_join_null_type_column(broadcast_join_enabled, make_df): ) with pytest.raises((ExpressionTypeError, ValueError)): - daft_df.join(daft_df2, on="id", how="inner") + daft_df.join(daft_df2, on="id", how="inner", strategy=join_strategy) diff --git a/tests/table/test_joins.py b/tests/table/test_joins.py index e770340572..e2df942fbe 100644 --- a/tests/table/test_joins.py +++ b/tests/table/test_joins.py @@ -40,13 +40,16 @@ ], ), ) -def test_table_join_single_column(dtype, data) -> None: +@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) +def test_table_join_single_column(join_impl, dtype, data) -> None: l, r, expected_pairs = data left_table = MicroPartition.from_pydict({"x": l, "x_ind": list(range(len(l)))}).eval_expression_list( [col("x").cast(dtype), col("x_ind")] ) right_table = MicroPartition.from_pydict({"y": r, "y_ind": list(range(len(r)))}) - result_table = left_table.join(right_table, left_on=[col("x")], right_on=[col("y")], how=JoinType.Inner) + result_table = getattr(left_table, join_impl)( + right_table, left_on=[col("x")], right_on=[col("y")], how=JoinType.Inner + ) assert result_table.column_names() == ["x", "x_ind", "y", "y_ind"] @@ -61,7 +64,9 @@ def test_table_join_single_column(dtype, data) -> None: assert result_table.get_column("y").to_pylist() == result_r # make sure the result is the same with right table on left - result_table = right_table.join(left_table, right_on=[col("x")], left_on=[col("y")], how=JoinType.Inner) + result_table = getattr(right_table, join_impl)( + left_table, right_on=[col("x")], left_on=[col("y")], how=JoinType.Inner + ) assert result_table.column_names() == ["y", "y_ind", "x", "x_ind"] @@ -76,12 +81,13 @@ def test_table_join_single_column(dtype, data) -> None: assert result_table.get_column("y").to_pylist() == result_r -def test_table_join_mismatch_column() -> None: +@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) +def test_table_join_mismatch_column(join_impl) -> None: left_table = MicroPartition.from_pydict({"x": [1, 2, 3, 4], "y": [2, 3, 4, 5]}) right_table = MicroPartition.from_pydict({"a": [1, 2, 3, 4], "b": [2, 3, 4, 5]}) with 
pytest.raises(ValueError, match="Mismatch of number of join keys"): - left_table.join(right_table, left_on=[col("x"), col("y")], right_on=[col("a")]) + getattr(left_table, join_impl)(right_table, left_on=[col("x"), col("y")], right_on=[col("a")]) @pytest.mark.parametrize( @@ -98,7 +104,8 @@ def test_table_join_mismatch_column() -> None: {"x": ["banana", "apple"], "y": [3, 4]}, ], ) -def test_table_join_multicolumn_empty_result(left, right) -> None: +@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) +def test_table_join_multicolumn_empty_result(join_impl, left, right) -> None: """Various multicol joins that should all produce an empty result.""" left_table = MicroPartition.from_pydict(left).eval_expression_list( [col("a").cast(DataType.string()), col("b").cast(DataType.int32())] @@ -107,11 +114,12 @@ def test_table_join_multicolumn_empty_result(left, right) -> None: [col("x").cast(DataType.string()), col("y").cast(DataType.int32())] ) - result = left_table.join(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) + result = getattr(left_table, join_impl)(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) assert result.to_pydict() == {"a": [], "b": [], "x": [], "y": []} -def test_table_join_multicolumn_nocross() -> None: +@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) +def test_table_join_multicolumn_nocross(join_impl) -> None: """A multicol join that should produce two rows and no cross product results. Input has duplicate join values and overlapping single-column values, @@ -132,7 +140,7 @@ def test_table_join_multicolumn_nocross() -> None: } ) - result = left_table.join(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) + result = getattr(left_table, join_impl)(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) assert set(utils.freeze(utils.pydict_to_rows(result.to_pydict()))) == set( utils.freeze( [ @@ -143,7 +151,8 @@ def test_table_join_multicolumn_nocross() -> None: ) -def test_table_join_multicolumn_cross() -> None: +@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) +def test_table_join_multicolumn_cross(join_impl) -> None: """A multicol join that should produce a cross product and a non-cross product.""" left_table = MicroPartition.from_pydict( @@ -161,7 +170,7 @@ def test_table_join_multicolumn_cross() -> None: } ) - result = left_table.join(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) + result = getattr(left_table, join_impl)(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) assert set(utils.freeze(utils.pydict_to_rows(result.to_pydict()))) == set( utils.freeze( [ @@ -178,7 +187,8 @@ def test_table_join_multicolumn_cross() -> None: ) -def test_table_join_multicolumn_all_nulls() -> None: +@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) +def test_table_join_multicolumn_all_nulls(join_impl) -> None: left_table = MicroPartition.from_pydict( { "a": Series.from_pylist([None, None, None]).cast(DataType.int64()), @@ -194,23 +204,25 @@ def test_table_join_multicolumn_all_nulls() -> None: } ) - result = left_table.join(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) + result = getattr(left_table, join_impl)(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) assert set(utils.freeze(utils.pydict_to_rows(result.to_pydict()))) == set(utils.freeze([])) -def test_table_join_no_columns() -> None: 
+@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) +def test_table_join_no_columns(join_impl) -> None: left_table = MicroPartition.from_pydict({"x": [1, 2, 3, 4], "y": [2, 3, 4, 5]}) right_table = MicroPartition.from_pydict({"a": [1, 2, 3, 4], "b": [2, 3, 4, 5]}) with pytest.raises(ValueError, match="No columns were passed in to join on"): - left_table.join(right_table, left_on=[], right_on=[]) + getattr(left_table, join_impl)(right_table, left_on=[], right_on=[]) -def test_table_join_single_column_name_conflicts() -> None: +@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) +def test_table_join_single_column_name_conflicts(join_impl) -> None: left_table = MicroPartition.from_pydict({"x": [0, 1, 2, 3], "y": [2, 3, 4, 5]}) right_table = MicroPartition.from_pydict({"x": [3, 2, 1, 0], "y": [6, 7, 8, 9]}) - result_table = left_table.join(right_table, left_on=[col("x")], right_on=[col("x")]) + result_table = getattr(left_table, join_impl)(right_table, left_on=[col("x")], right_on=[col("x")]) assert result_table.column_names() == ["x", "y", "right.y"] result_sorted = result_table.sort([col("x")]) assert result_sorted.get_column("y").to_pylist() == [2, 3, 4, 5] @@ -218,11 +230,12 @@ def test_table_join_single_column_name_conflicts() -> None: assert result_sorted.get_column("right.y").to_pylist() == [9, 8, 7, 6] -def test_table_join_single_column_name_conflicts_different_named_join() -> None: +@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) +def test_table_join_single_column_name_conflicts_different_named_join(join_impl) -> None: left_table = MicroPartition.from_pydict({"x": [0, 1, 2, 3], "y": [2, 3, 4, 5]}) right_table = MicroPartition.from_pydict({"y": [3, 2, 1, 0], "x": [6, 7, 8, 9]}) - result_table = left_table.join(right_table, left_on=[col("x")], right_on=[col("y")]) + result_table = getattr(left_table, join_impl)(right_table, left_on=[col("x")], right_on=[col("y")]) # NOTE: right.y is not dropped because it has a different name from the corresponding left # column it is joined on, left_table["x"] @@ -233,11 +246,12 @@ def test_table_join_single_column_name_conflicts_different_named_join() -> None: assert result_sorted.get_column("right.x").to_pylist() == [9, 8, 7, 6] -def test_table_join_single_column_name_multiple_conflicts() -> None: +@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) +def test_table_join_single_column_name_multiple_conflicts(join_impl) -> None: left_table = MicroPartition.from_pydict({"x": [0, 1, 2, 3], "y": [2, 3, 4, 5], "right.y": [6, 7, 8, 9]}) right_table = MicroPartition.from_pydict({"x": [3, 2, 1, 0], "y": [10, 11, 12, 13]}) - result_table = left_table.join(right_table, left_on=[col("x")], right_on=[col("x")]) + result_table = getattr(left_table, join_impl)(right_table, left_on=[col("x")], right_on=[col("x")]) assert result_table.column_names() == ["x", "y", "right.y", "right.right.y"] result_sorted = result_table.sort([col("x")]) assert result_sorted.get_column("y").to_pylist() == [2, 3, 4, 5] @@ -246,22 +260,24 @@ def test_table_join_single_column_name_multiple_conflicts() -> None: assert result_sorted.get_column("right.right.y").to_pylist() == [13, 12, 11, 10] -def test_table_join_single_column_name_boolean() -> None: +@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) +def test_table_join_single_column_name_boolean(join_impl) -> None: left_table = MicroPartition.from_pydict({"x": [False, True, None], "y": [0, 1, 2]}) right_table = 
MicroPartition.from_pydict({"x": [None, True, False, None], "y": [0, 1, 2, 3]}) - result_table = left_table.join(right_table, left_on=[col("x")], right_on=[col("x")]) + result_table = getattr(left_table, join_impl)(right_table, left_on=[col("x")], right_on=[col("x")]) assert result_table.column_names() == ["x", "y", "right.y"] result_sorted = result_table.sort([col("x")]) assert result_sorted.get_column("y").to_pylist() == [0, 1] assert result_sorted.get_column("right.y").to_pylist() == [2, 1] -def test_table_join_single_column_name_null() -> None: +@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) +def test_table_join_single_column_name_null(join_impl) -> None: left_table = MicroPartition.from_pydict({"x": [None, None, None], "y": [0, 1, 2]}) right_table = MicroPartition.from_pydict({"x": [None, None, None, None], "y": [0, 1, 2, 3]}) - result_table = left_table.join(right_table, left_on=[col("x")], right_on=[col("x")]) + result_table = getattr(left_table, join_impl)(right_table, left_on=[col("x")], right_on=[col("x")]) assert result_table.column_names() == ["x", "y", "right.y"] result_sorted = result_table.sort([col("x")]) assert result_sorted.get_column("y").to_pylist() == []
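
The dataframe-level tests above all consume a `join_strategy` fixture via `indirect=True`, but the fixture itself is not shown in this part of the diff. Below is a minimal sketch of what such a fixture could look like, not the PR's actual conftest change: the handling of `"sort_merge_aligned_boundaries"` (mapping it to the `"sort_merge"` strategy plus the new `sort_merge_join_sort_with_aligned_boundaries` execution-config flag) is an assumption for illustration, reusing the config save/restore pattern of the removed `broadcast_join_enabled` fixture.

```python
import pytest

import daft


@pytest.fixture(scope="function")
def join_strategy(request):
    if request.param != "sort_merge_aligned_boundaries":
        # None and the plain strategy names pass straight through to
        # df.join(..., strategy=join_strategy).
        yield request.param
    else:
        # Assumed mapping: run as a "sort_merge" join with the aligned-boundaries
        # flag enabled, restoring the previous execution config afterwards.
        old_execution_config = daft.context.get_context().daft_execution_config
        try:
            daft.set_execution_config(sort_merge_join_sort_with_aligned_boundaries=True)
            yield "sort_merge"
        finally:
            daft.set_execution_config(old_execution_config)
```

Under this sketch, the tests' `join_strategy != "sort_merge"` check remains correct for aligned-boundaries runs, since the fixture yields `"sort_merge"` for that parameter.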