From 2739cc06e6a77768fe036a72928b5b22218e8c49 Mon Sep 17 00:00:00 2001
From: Clark Zinzow
Date: Mon, 11 Dec 2023 11:30:56 -0800
Subject: [PATCH] [FEAT] [Join Optimizations] Add broadcast join. (#1706)

This PR adds a broadcast join implementation as a new join strategy: all partitions of a small table are broadcast to each partition of the larger table, and we then do a local (hash) join of the entire small table with each individual partition of the larger table.

## Query Planning

The query planner chooses the broadcast join as its join strategy if one side of the join is smaller than a preconfigured broadcasting threshold (10 MiB by default, and user-configurable).

**Note:** If the smaller side of the join is the right side, we invert the join for planning and scheduling simplicity so that we can always broadcast the left side; we then swap back to the correct join ordering when performing the local joins. This means that we always build the probe table on the left side of the join; a future optimization (applicable to both the broadcast join and the hash join) would be to have local joins build the probe table on the smaller side while preserving the expected column ordering. We would still need to build the probe table on the left side of the join whenever the row ordering of the right side must be preserved, e.g. if the right side of the join is range-partitioned AND we're doing a broadcast join.

## Query Scheduling

All partitions on the broadcasting side of the join are materialized first. Then, as each partition on the receiving side of the join materializes, we dispatch a hash join task that joins all broadcaster partitions with that single receiving-side partition.

## TODOs

- [x] Test coverage.
- [ ] (Follow-up?) TPC-H benchmarking demonstrating speedup due to use of broadcast join.
- [ ] (Follow-up) In local joins, build the probe table on the smaller side of the join.
- [ ] (Follow-up) Add table size approximations for operators that affect cardinality.
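For reference, here is a rough Python sketch of the strategy selection described above; the actual logic lives in `src/daft-plan/src/planner.rs`, and `choose_join_strategy` and its parameters are illustrative names rather than part of the Daft API:

```python
from typing import Optional, Tuple

BROADCAST_JOIN_SIZE_BYTES_THRESHOLD = 10 * 1024 * 1024  # 10 MiB default


def choose_join_strategy(
    left_size_bytes: Optional[int],
    right_size_bytes: Optional[int],
    is_larger_side_partitioned_on_key: bool,
) -> Tuple[str, bool]:
    """Return (strategy, is_swapped); is_swapped=True means the right side is broadcast."""
    # A side with an unknown size approximation is treated as the larger side.
    sides = ((left_size_bytes, False), (right_size_bytes, True))
    known = [(size, swapped) for size, swapped in sides if size is not None]
    if known and not is_larger_side_partitioned_on_key:
        smaller_size_bytes, is_swapped = min(known)
        if smaller_size_bytes <= BROADCAST_JOIN_SIZE_BYTES_THRESHOLD:
            return "broadcast", is_swapped
    # Otherwise, fall back to the existing partitioned hash join.
    return "hash", False
```

When `is_swapped` comes back true, the planner feeds the right child in as the broadcaster and the local `Join` instruction swaps the sides back at execution time, as described in the note above.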
--- daft/context.py | 8 +- daft/daft.pyi | 5 +- daft/dataframe/dataframe.py | 30 +++- daft/execution/execution_step.py | 13 +- daft/execution/physical_plan.py | 130 +++++++++++++++++- daft/execution/rust_physical_plan_shim.py | 24 +++- daft/io/file_path.py | 3 + daft/logical/builder.py | 6 +- daft/runners/partitioning.py | 7 + daft/runners/pyrunner.py | 10 +- daft/runners/ray_runner.py | 55 +++++--- src/common/daft-config/src/lib.rs | 6 +- src/common/daft-config/src/python.rs | 15 +- src/daft-micropartition/src/ops/join.rs | 1 + src/daft-plan/src/builder.rs | 4 + .../src/physical_ops/broadcast_join.rs | 37 +++++ .../physical_ops/{join.rs => hash_join.rs} | 4 +- src/daft-plan/src/physical_ops/mod.rs | 6 +- src/daft-plan/src/physical_plan.rs | 105 +++++++++++++- src/daft-plan/src/planner.rs | 43 +++++- src/daft-plan/src/source_info/mod.rs | 3 + src/daft-table/src/ops/joins/hash_join.rs | 1 + tests/dataframe/test_joins.py | 24 +++- 23 files changed, 482 insertions(+), 58 deletions(-) create mode 100644 src/daft-plan/src/physical_ops/broadcast_join.rs rename src/daft-plan/src/physical_ops/{join.rs => hash_join.rs} (94%) diff --git a/daft/context.py b/daft/context.py index e607eb553d..8b1b16fe7a 100644 --- a/daft/context.py +++ b/daft/context.py @@ -195,6 +195,7 @@ def set_config( config: PyDaftConfig | None = None, merge_scan_tasks_min_size_bytes: int | None = None, merge_scan_tasks_max_size_bytes: int | None = None, + broadcast_join_size_bytes_threshold: int | None = None, ) -> DaftContext: """Globally sets various configuration parameters which control various aspects of Daft execution @@ -203,10 +204,12 @@ def set_config( that the old (current) config should be used. merge_scan_tasks_min_size_bytes: Minimum size in bytes when merging ScanTasks when reading files from storage. Increasing this value will make Daft perform more merging of files into a single partition before yielding, - which leads to bigger but fewer partitions. (Defaults to 64MB) + which leads to bigger but fewer partitions. (Defaults to 64 MiB) merge_scan_tasks_max_size_bytes: Maximum size in bytes when merging ScanTasks when reading files from storage. Increasing this value will increase the upper bound of the size of merged ScanTasks, which leads to bigger but - fewer partitions. (Defaults to 512MB) + fewer partitions. (Defaults to 512 MiB) + broadcast_join_size_bytes_threshold: If one side of a join is smaller than this threshold, a broadcast join will be used. + Default is 10 MiB. """ ctx = get_context() if ctx._runner is not None: @@ -220,6 +223,7 @@ def set_config( new_daft_config = old_daft_config.with_config_values( merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes, merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes, + broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold, ) ctx.daft_config = new_daft_config diff --git a/daft/daft.pyi b/daft/daft.pyi index 47240b6019..7fe6f6a4b4 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -1058,7 +1058,7 @@ class LogicalPlanBuilder: @staticmethod def in_memory_scan( - partition_key: str, cache_entry: PartitionCacheEntry, schema: PySchema, num_partitions: int + partition_key: str, cache_entry: PartitionCacheEntry, schema: PySchema, num_partitions: int, size_bytes: int ) -> LogicalPlanBuilder: ... @staticmethod def table_scan_with_scan_operator(scan_operator: ScanOperatorHandle) -> LogicalPlanBuilder: ... 
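A minimal usage sketch of the configuration knob added above, assuming it is called before any query has started a runner:

```python
import daft

# Raise the broadcast threshold to 32 MiB for this process; a threshold of 0 disables
# broadcast joins entirely (the new join tests use this to toggle the strategy).
ctx = daft.context.set_config(broadcast_join_size_bytes_threshold=32 * 1024 * 1024)
assert ctx.daft_config.broadcast_join_size_bytes_threshold == 32 * 1024 * 1024
```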
@@ -1102,11 +1102,14 @@ class PyDaftConfig: self, merge_scan_tasks_min_size_bytes: int | None = None, merge_scan_tasks_max_size_bytes: int | None = None, + broadcast_join_size_bytes_threshold: int | None = None, ) -> PyDaftConfig: ... @property def merge_scan_tasks_min_size_bytes(self): ... @property def merge_scan_tasks_max_size_bytes(self): ... + @property + def broadcast_join_size_bytes_threshold(self): ... def build_type() -> str: ... def version() -> str: ... diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index ebcd8ddcc2..c4b729e039 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -85,10 +85,16 @@ def _builder(self) -> LogicalPlanBuilder: return self.__builder else: num_partitions = self._result_cache.num_partitions() + size_bytes = self._result_cache.size_bytes() # Partition set should always be set on cache entry. - assert num_partitions is not None, "Partition set should always be set on cache entry" + assert ( + num_partitions is not None and size_bytes is not None + ), "Partition set should always be set on cache entry" return self.__builder.from_in_memory_scan( - self._result_cache, self.__builder.schema(), num_partitions=num_partitions + self._result_cache, + self.__builder.schema(), + num_partitions=num_partitions, + size_bytes=size_bytes, ) def _get_current_builder(self) -> LogicalPlanBuilder: @@ -273,7 +279,11 @@ def _from_tables(cls, *parts: MicroPartition) -> "DataFrame": context = get_context() cache_entry = context.runner().put_partition_set_into_cache(result_pset) - builder = LogicalPlanBuilder.from_in_memory_scan(cache_entry, parts[0].schema(), result_pset.num_partitions()) + size_bytes = result_pset.size_bytes() + assert size_bytes is not None, "In-memory data should always have non-None size in bytes" + builder = LogicalPlanBuilder.from_in_memory_scan( + cache_entry, parts[0].schema(), result_pset.num_partitions(), size_bytes + ) return cls(builder) ### @@ -1233,8 +1243,13 @@ def _from_ray_dataset(cls, ds: "RayDataset") -> "DataFrame": partition_set, schema = ray_runner_io.partition_set_from_ray_dataset(ds) cache_entry = context.runner().put_partition_set_into_cache(partition_set) + size_bytes = partition_set.size_bytes() + assert size_bytes is not None, "In-memory data should always have non-None size in bytes" builder = LogicalPlanBuilder.from_in_memory_scan( - cache_entry, schema=schema, num_partitions=partition_set.num_partitions() + cache_entry, + schema=schema, + num_partitions=partition_set.num_partitions(), + size_bytes=size_bytes, ) return cls(builder) @@ -1300,8 +1315,13 @@ def _from_dask_dataframe(cls, ddf: "dask.DataFrame") -> "DataFrame": partition_set, schema = ray_runner_io.partition_set_from_dask_dataframe(ddf) cache_entry = context.runner().put_partition_set_into_cache(partition_set) + size_bytes = partition_set.size_bytes() + assert size_bytes is not None, "In-memory data should always have non-None size in bytes" builder = LogicalPlanBuilder.from_in_memory_scan( - cache_entry, schema=schema, num_partitions=partition_set.num_partitions() + cache_entry, + schema=schema, + num_partitions=partition_set.num_partitions(), + size_bytes=size_bytes, ) return cls(builder) diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 672aa0def7..cf33efc2a0 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -629,12 +629,23 @@ class Join(SingleOutputInstruction): left_on: ExpressionsProjection right_on: ExpressionsProjection how: JoinType + 
is_swapped: bool def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: return self._join(inputs) def _join(self, inputs: list[MicroPartition]) -> list[MicroPartition]: - [left, right] = inputs + # All inputs except for the last are the left side of the join, in order to support left-broadcasted joins. + *lefts, right = inputs + if len(lefts) > 1: + # NOTE: MicroPartition concats don't concatenate the underlying column arrays, since MicroPartitions are chunked. + left = MicroPartition.concat(lefts) + else: + left = lefts[0] + if self.is_swapped: + # Swap left/right back. + # We don't need to swap left_on and right_on since those were never swapped in the first place. + left, right = right, left result = left.join( right, left_on=self.left_on, diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index a25f3511b4..40868673ec 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -186,7 +186,7 @@ def pipeline_instruction( ) -def join( +def hash_join( left_plan: InProgressPhysicalPlan[PartitionT], right_plan: InProgressPhysicalPlan[PartitionT], left_on: ExpressionsProjection, @@ -233,6 +233,7 @@ def join( left_on=left_on, right_on=right_on, how=how, + is_swapped=False, ) ) yield join_step @@ -271,6 +272,133 @@ def join( return +def _create_join_step( + broadcaster_parts: deque[SingleOutputPartitionTask[PartitionT]], + receiver_part: SingleOutputPartitionTask[PartitionT], + left_on: ExpressionsProjection, + right_on: ExpressionsProjection, + how: JoinType, + is_swapped: bool, +) -> PartitionTaskBuilder[PartitionT]: + # Calculate memory request for task. + broadcaster_size_bytes_ = 0 + broadcaster_partitions = [] + broadcaster_partition_metadatas = [] + null_count = 0 + for next_broadcaster in broadcaster_parts: + next_broadcaster_partition_metadata = next_broadcaster.partition_metadata() + if next_broadcaster_partition_metadata is None or next_broadcaster_partition_metadata.size_bytes is None: + null_count += 1 + else: + broadcaster_size_bytes_ += next_broadcaster_partition_metadata.size_bytes + broadcaster_partitions.append(next_broadcaster.partition()) + broadcaster_partition_metadatas.append(next_broadcaster_partition_metadata) + if null_count == len(broadcaster_parts): + broadcaster_size_bytes = None + elif null_count > 0: + # Impute null size estimates with mean of non-null estimates. + broadcaster_size_bytes = broadcaster_size_bytes_ + math.ceil( + null_count * broadcaster_size_bytes_ / (len(broadcaster_parts) - null_count) + ) + else: + broadcaster_size_bytes = broadcaster_size_bytes_ + receiver_size_bytes = receiver_part.partition_metadata().size_bytes + if broadcaster_size_bytes is None and receiver_size_bytes is None: + size_bytes = None + elif broadcaster_size_bytes is None and receiver_size_bytes is not None: + # Use 1.25x the receiver side as the memory request, assuming that receiver side is ~4x larger than the broadcaster side. + size_bytes = int(1.25 * receiver_size_bytes) + elif receiver_size_bytes is None and broadcaster_size_bytes is not None: + # Use 4x the broadcaster side as the memory request, assuming that receiver side is ~4x larger than the broadcaster side. 
+ size_bytes = 4 * broadcaster_size_bytes + elif broadcaster_size_bytes is not None and receiver_size_bytes is not None: + size_bytes = broadcaster_size_bytes + receiver_size_bytes + + return PartitionTaskBuilder[PartitionT]( + inputs=broadcaster_partitions + [receiver_part.partition()], + partial_metadatas=list(broadcaster_partition_metadatas + [receiver_part.partition_metadata()]), + resource_request=ResourceRequest(memory_bytes=size_bytes), + ).add_instruction( + instruction=execution_step.Join( + left_on=left_on, + right_on=right_on, + how=how, + is_swapped=is_swapped, + ) + ) + + +def broadcast_join( + broadcaster_plan: InProgressPhysicalPlan[PartitionT], + receiver_plan: InProgressPhysicalPlan[PartitionT], + left_on: ExpressionsProjection, + right_on: ExpressionsProjection, + how: JoinType, + is_swapped: bool, +) -> InProgressPhysicalPlan[PartitionT]: + """Broadcast join all partitions from the broadcaster child plan to each partition in the receiver child plan.""" + + # Materialize the steps from the broadcaster and receiver sources to get partitions. + # As the receiver-side materializations complete, emit new steps to join each broadcaster and receiver partition. + stage_id = next(stage_id_counter) + broadcaster_requests: deque[SingleOutputPartitionTask[PartitionT]] = deque() + broadcaster_parts: deque[SingleOutputPartitionTask[PartitionT]] = deque() + + # First, fully materialize the broadcasting side (broadcaster side) of the join. + while True: + # Moved completed partition tasks in the broadcaster side of the join to the materialized partition set. + while broadcaster_requests and broadcaster_requests[0].done(): + broadcaster_parts.append(broadcaster_requests.popleft()) + + # Execute single child step to pull in more broadcaster-side partitions. + try: + step = next(broadcaster_plan) + if isinstance(step, PartitionTaskBuilder): + step = step.finalize_partition_task_single_output(stage_id=stage_id) + broadcaster_requests.append(step) + yield step + except StopIteration: + if broadcaster_requests: + logger.debug( + "broadcast join blocked on completion of broadcasting side of join.\n broadcaster sources: %s", + broadcaster_requests, + ) + yield None + else: + break + + # Second, broadcast materialized partitions to receiver side of join, as it materializes. + receiver_requests: deque[SingleOutputPartitionTask[PartitionT]] = deque() + + while True: + receiver_parts: deque[SingleOutputPartitionTask[PartitionT]] = deque() + # Moved completed partition tasks in the receiver side of the join to the materialized partition set. + while receiver_requests and receiver_requests[0].done(): + receiver_parts.append(receiver_requests.popleft()) + + # Emit join steps for newly materialized partitions. + # Broadcast all broadcaster partitions to each new receiver partition that was materialized on this dispatch loop. + for receiver_part in receiver_parts: + yield _create_join_step(broadcaster_parts, receiver_part, left_on, right_on, how, is_swapped) + + # Execute single child step to pull in more input partitions. 
+ try: + step = next(receiver_plan) + if isinstance(step, PartitionTaskBuilder): + step = step.finalize_partition_task_single_output(stage_id=stage_id) + receiver_requests.append(step) + yield step + except StopIteration: + if receiver_requests: + logger.debug( + "broadcast join blocked on completion of receiver side of join.\n receiver sources: %s", + receiver_requests, + ) + yield None + else: + return + + def concat( top_plan: InProgressPhysicalPlan[PartitionT], bottom_plan: InProgressPhysicalPlan[PartitionT] ) -> InProgressPhysicalPlan[PartitionT]: diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 4e3644523d..bb1016088a 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -188,7 +188,7 @@ def reduce_merge( return physical_plan.reduce(input, reduce_instruction) -def join( +def hash_join( input: physical_plan.InProgressPhysicalPlan[PartitionT], right: physical_plan.InProgressPhysicalPlan[PartitionT], left_on: list[PyExpr], @@ -197,7 +197,7 @@ def join( ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on]) right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on]) - return physical_plan.join( + return physical_plan.hash_join( left_plan=input, right_plan=right, left_on=left_on_expr_proj, @@ -206,6 +206,26 @@ def join( ) +def broadcast_join( + broadcaster: physical_plan.InProgressPhysicalPlan[PartitionT], + receiver: physical_plan.InProgressPhysicalPlan[PartitionT], + left_on: list[PyExpr], + right_on: list[PyExpr], + join_type: JoinType, + is_swapped: bool, +) -> physical_plan.InProgressPhysicalPlan[PartitionT]: + left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on]) + right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on]) + return physical_plan.broadcast_join( + broadcaster_plan=broadcaster, + receiver_plan=receiver, + left_on=left_on_expr_proj, + right_on=right_on_expr_proj, + how=join_type, + is_swapped=is_swapped, + ) + + def write_file( input: physical_plan.InProgressPhysicalPlan[PartitionT], file_format: FileFormat, diff --git a/daft/io/file_path.py b/daft/io/file_path.py index 079fa9ef5a..7fb1905da5 100644 --- a/daft/io/file_path.py +++ b/daft/io/file_path.py @@ -48,9 +48,12 @@ def from_glob_path(path: str, io_config: Optional[IOConfig] = None) -> DataFrame file_infos_table = MicroPartition._from_pytable(file_infos.to_table()) partition = LocalPartitionSet({0: file_infos_table}) cache_entry = context.runner().put_partition_set_into_cache(partition) + size_bytes = partition.size_bytes() + assert size_bytes is not None, "In-memory data should always have non-None size in bytes" builder = LogicalPlanBuilder.from_in_memory_scan( cache_entry, schema=file_infos_table.schema(), num_partitions=partition.num_partitions(), + size_bytes=size_bytes, ) return DataFrame(builder) diff --git a/daft/logical/builder.py b/daft/logical/builder.py index f948f666a6..5f9f59dc08 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -74,9 +74,11 @@ def optimize(self) -> LogicalPlanBuilder: @classmethod def from_in_memory_scan( - cls, partition: PartitionCacheEntry, schema: Schema, num_partitions: int + cls, partition: PartitionCacheEntry, schema: Schema, num_partitions: int, size_bytes: int ) -> LogicalPlanBuilder: - builder = 
_LogicalPlanBuilder.in_memory_scan(partition.key, partition, schema._schema, num_partitions) + builder = _LogicalPlanBuilder.in_memory_scan( + partition.key, partition, schema._schema, num_partitions, size_bytes + ) return cls(builder) @classmethod diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index ec6c04085c..45e86c0d54 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -175,6 +175,10 @@ def has_partition(self, idx: PartID) -> bool: def __len__(self) -> int: raise NotImplementedError() + @abstractmethod + def size_bytes(self) -> int | None: + raise NotImplementedError() + @abstractmethod def num_partitions(self) -> int: raise NotImplementedError() @@ -208,6 +212,9 @@ def __setstate__(self, key): def num_partitions(self) -> int | None: return self.value.num_partitions() if self.value is not None else None + def size_bytes(self) -> int | None: + return self.value.size_bytes() if self.value is not None else None + class PartitionSetCache: def __init__(self) -> None: diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 55076c2156..34c9d5e599 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -64,7 +64,15 @@ def has_partition(self, idx: PartID) -> bool: return idx in self._partitions def __len__(self) -> int: - return sum([len(self._partitions[pid]) for pid in self._partitions]) + return sum(len(partition) for partition in self._partitions.values()) + + def size_bytes(self) -> int | None: + size_bytes_ = [partition.size_bytes() for partition in self._partitions.values()] + size_bytes: list[int] = [size for size in size_bytes_ if size is not None] + if len(size_bytes) != len(size_bytes_): + return None + else: + return sum(size_bytes) def num_partitions(self) -> int: return len(self._partitions) diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index b845296c7a..06fde27c76 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -212,7 +212,15 @@ def has_partition(self, idx: PartID) -> bool: return idx in self._results def __len__(self) -> int: - return sum([self._results[pid].metadata().num_rows for pid in self._results]) + return sum(result.metadata().num_rows for result in self._results.values()) + + def size_bytes(self) -> int | None: + size_bytes_ = [result.metadata().size_bytes for result in self._results.values()] + size_bytes: list[int] = [size for size in size_bytes_ if size is not None] + if len(size_bytes) != len(size_bytes_): + return None + else: + return sum(size_bytes) def num_partitions(self) -> int: return len(self._results) @@ -302,7 +310,7 @@ def partition_set_from_ray_dataset( RayPartitionSet( _daft_config_objref=self.daft_config_objref, _results={ - i: RayMaterializedResult(obj, _daft_config_objref=self.daft_config_objref) + i: RayMaterializedResult(obj, daft_config_objref=self.daft_config_objref) for i, obj in enumerate(daft_vpartitions) }, ), @@ -335,7 +343,7 @@ def partition_set_from_dask_dataframe( RayPartitionSet( _daft_config_objref=self.daft_config_objref, _results={ - i: RayMaterializedResult(obj, _daft_config_objref=self.daft_config_objref) + i: RayMaterializedResult(obj, daft_config_objref=self.daft_config_objref) for i, obj in enumerate(daft_vpartitions) }, ), @@ -412,9 +420,9 @@ def reduce_and_fanout( @ray.remote -def get_meta(daft_config: PyDaftConfig, partition: MicroPartition) -> PartitionMetadata: +def get_metas(daft_config: PyDaftConfig, *partitions: MicroPartition) -> list[PartitionMetadata]: set_config(daft_config) - return 
PartitionMetadata.from_table(partition) + return [PartitionMetadata.from_table(partition) for partition in partitions] def _ray_num_cpus_provider(ttl_seconds: int = 1) -> Generator[int, None, None]: @@ -578,7 +586,7 @@ def place_in_queue(item): assert isinstance(next_step, SingleOutputPartitionTask) next_step.set_result( [ - RayMaterializedResult(partition, _daft_config_objref=self.daft_config_objref) + RayMaterializedResult(partition, daft_config_objref=self.daft_config_objref) for partition in next_step.inputs ] ) @@ -617,7 +625,6 @@ def place_in_queue(item): dispatch = datetime.now() completed_task_ids = [] for wait_for in ("next_one", "next_batch"): - if not is_active(): break @@ -704,10 +711,9 @@ def _build_partitions(daft_config_objref: ray.ObjectRef, task: PartitionTask[ray task.set_result( [ RayMaterializedResult( - _partition=partition, - _daft_config_objref=daft_config_objref, - _metadatas=metadatas_accessor, - _metadata_index=i, + partition=partition, + metadatas=metadatas_accessor, + metadata_idx=i, ) for i, partition in enumerate(partitions) ] @@ -824,7 +830,7 @@ def put_partition_set_into_cache(self, pset: PartitionSet) -> PartitionCacheEntr pset = RayPartitionSet( _daft_config_objref=self.daft_config_objref, _results={ - pid: RayMaterializedResult(ray.put(val), _daft_config_objref=self.daft_config_objref) + pid: RayMaterializedResult(ray.put(val), daft_config_objref=self.daft_config_objref) for pid, val in pset._partitions.items() }, ) @@ -835,12 +841,22 @@ def runner_io(self) -> RayRunnerIO: return RayRunnerIO(daft_config_objref=self.daft_config_objref) -@dataclass(frozen=True) class RayMaterializedResult(MaterializedResult[ray.ObjectRef]): - _partition: ray.ObjectRef - _daft_config_objref: ray.ObjectRef - _metadatas: PartitionMetadataAccessor | None = None - _metadata_index: int | None = None + def __init__( + self, + partition: ray.ObjectRef, + daft_config_objref: ray.ObjectRef | None = None, + metadatas: PartitionMetadataAccessor | None = None, + metadata_idx: int | None = None, + ): + self._partition = partition + if metadatas is None: + assert metadata_idx is None + assert daft_config_objref is not None + metadatas = PartitionMetadataAccessor(get_metas.remote(daft_config_objref, self._partition)) + metadata_idx = 0 + self._metadatas = metadatas + self._metadata_idx = metadata_idx def partition(self) -> ray.ObjectRef: return self._partition @@ -849,10 +865,7 @@ def vpartition(self) -> MicroPartition: return ray.get(self._partition) def metadata(self) -> PartitionMetadata: - if self._metadatas is not None and self._metadata_index is not None: - return self._metadatas.get_index(self._metadata_index) - else: - return ray.get(get_meta.remote(self._daft_config_objref, self._partition)) + return self._metadatas.get_index(self._metadata_idx) def cancel(self) -> None: return ray.cancel(self._partition) diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index b71280ee36..7134f65893 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -4,13 +4,15 @@ use serde::{Deserialize, Serialize}; pub struct DaftConfig { pub merge_scan_tasks_min_size_bytes: usize, pub merge_scan_tasks_max_size_bytes: usize, + pub broadcast_join_size_bytes_threshold: usize, } impl Default for DaftConfig { fn default() -> Self { DaftConfig { - merge_scan_tasks_min_size_bytes: 64 * 1024 * 1024, // 64MB - merge_scan_tasks_max_size_bytes: 512 * 1024 * 1024, // 512MB + merge_scan_tasks_min_size_bytes: 64 * 1024 * 1024, // 64 MiB + 
merge_scan_tasks_max_size_bytes: 512 * 1024 * 1024, // 512 MiB + broadcast_join_size_bytes_threshold: 10 * 1024 * 1024, // 10 MiB } } } diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index f327e5a55d..353858152c 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -19,9 +19,10 @@ impl PyDaftConfig { } fn with_config_values( - &mut self, + &self, merge_scan_tasks_min_size_bytes: Option, merge_scan_tasks_max_size_bytes: Option, + broadcast_join_size_bytes_threshold: Option, ) -> PyResult { let mut config = self.config.as_ref().clone(); @@ -31,22 +32,30 @@ impl PyDaftConfig { if let Some(merge_scan_tasks_min_size_bytes) = merge_scan_tasks_min_size_bytes { config.merge_scan_tasks_min_size_bytes = merge_scan_tasks_min_size_bytes; } + if let Some(broadcast_join_size_bytes_threshold) = broadcast_join_size_bytes_threshold { + config.broadcast_join_size_bytes_threshold = broadcast_join_size_bytes_threshold; + } Ok(PyDaftConfig { config: Arc::new(config), }) } - #[getter(merge_scan_tasks_min_size_bytes)] + #[getter] fn get_merge_scan_tasks_min_size_bytes(&self) -> PyResult { Ok(self.config.merge_scan_tasks_min_size_bytes) } - #[getter(merge_scan_tasks_max_size_bytes)] + #[getter] fn get_merge_scan_tasks_max_size_bytes(&self) -> PyResult { Ok(self.config.merge_scan_tasks_max_size_bytes) } + #[getter] + fn get_broadcast_join_size_bytes_threshold(&self) -> PyResult { + Ok(self.config.broadcast_join_size_bytes_threshold) + } + fn __reduce__(&self, py: Python) -> PyResult<(PyObject, (Vec,))> { let bin_data = bincode::serialize(self.config.as_ref()) .expect("DaftConfig should be serializable to bytes"); diff --git a/src/daft-micropartition/src/ops/join.rs b/src/daft-micropartition/src/ops/join.rs index f8d5efbe5f..86338a2380 100644 --- a/src/daft-micropartition/src/ops/join.rs +++ b/src/daft-micropartition/src/ops/join.rs @@ -41,6 +41,7 @@ impl MicroPartition { return Ok(Self::empty(Some(join_schema.into()))); } + // TODO(Clark): Elide concatenations where possible by doing a chunk-aware local table join. let lt = self.concat_or_get(io_stats.clone())?; let rt = right.concat_or_get(io_stats)?; diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 9e109cde00..874871c4fc 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -57,12 +57,14 @@ impl LogicalPlanBuilder { cache_entry: PyObject, schema: Arc, num_partitions: usize, + size_bytes: usize, ) -> DaftResult { let source_info = SourceInfo::InMemoryInfo(InMemoryInfo::new( schema.clone(), partition_key.into(), cache_entry, num_partitions, + size_bytes, )); let logical_plan: LogicalPlan = logical_ops::Source::new(schema.clone(), source_info.into()).into(); @@ -285,12 +287,14 @@ impl PyLogicalPlanBuilder { cache_entry: &PyAny, schema: PySchema, num_partitions: usize, + size_bytes: usize, ) -> PyResult { Ok(LogicalPlanBuilder::in_memory_scan( partition_key, cache_entry.to_object(cache_entry.py()), schema.into(), num_partitions, + size_bytes, )? 
.into()) } diff --git a/src/daft-plan/src/physical_ops/broadcast_join.rs b/src/daft-plan/src/physical_ops/broadcast_join.rs new file mode 100644 index 0000000000..4b555439d6 --- /dev/null +++ b/src/daft-plan/src/physical_ops/broadcast_join.rs @@ -0,0 +1,37 @@ +use std::sync::Arc; + +use daft_dsl::Expr; + +use crate::{physical_plan::PhysicalPlan, JoinType}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct BroadcastJoin { + // Upstream node. + pub broadcaster: Arc, + pub receiver: Arc, + pub left_on: Vec, + pub right_on: Vec, + pub join_type: JoinType, + pub is_swapped: bool, +} + +impl BroadcastJoin { + pub(crate) fn new( + broadcaster: Arc, + receiver: Arc, + left_on: Vec, + right_on: Vec, + join_type: JoinType, + is_swapped: bool, + ) -> Self { + Self { + broadcaster, + receiver, + left_on, + right_on, + join_type, + is_swapped, + } + } +} diff --git a/src/daft-plan/src/physical_ops/join.rs b/src/daft-plan/src/physical_ops/hash_join.rs similarity index 94% rename from src/daft-plan/src/physical_ops/join.rs rename to src/daft-plan/src/physical_ops/hash_join.rs index 79db5097aa..adfd546e60 100644 --- a/src/daft-plan/src/physical_ops/join.rs +++ b/src/daft-plan/src/physical_ops/hash_join.rs @@ -6,7 +6,7 @@ use crate::{physical_plan::PhysicalPlan, JoinType}; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Join { +pub struct HashJoin { // Upstream node. pub left: Arc, pub right: Arc, @@ -15,7 +15,7 @@ pub struct Join { pub join_type: JoinType, } -impl Join { +impl HashJoin { pub(crate) fn new( left: Arc, right: Arc, diff --git a/src/daft-plan/src/physical_ops/mod.rs b/src/daft-plan/src/physical_ops/mod.rs index 09dc707511..e55aca8dd0 100644 --- a/src/daft-plan/src/physical_ops/mod.rs +++ b/src/daft-plan/src/physical_ops/mod.rs @@ -1,4 +1,5 @@ mod agg; +mod broadcast_join; mod coalesce; mod concat; mod csv; @@ -6,9 +7,9 @@ mod explode; mod fanout; mod filter; mod flatten; +mod hash_join; #[cfg(feature = "python")] mod in_memory; -mod join; mod json; mod limit; mod parquet; @@ -19,6 +20,7 @@ mod sort; mod split; pub use agg::Aggregate; +pub use broadcast_join::BroadcastJoin; pub use coalesce::Coalesce; pub use concat::Concat; pub use csv::{TabularScanCsv, TabularWriteCsv}; @@ -26,9 +28,9 @@ pub use explode::Explode; pub use fanout::{FanoutByHash, FanoutByRange, FanoutRandom}; pub use filter::Filter; pub use flatten::Flatten; +pub use hash_join::HashJoin; #[cfg(feature = "python")] pub use in_memory::InMemoryScan; -pub use join::Join; pub use json::{TabularScanJson, TabularWriteJson}; pub use limit::Limit; pub use parquet::{TabularScanParquet, TabularWriteParquet}; diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 06f0998286..68112897ee 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -52,7 +52,8 @@ pub enum PhysicalPlan { ReduceMerge(ReduceMerge), Aggregate(Aggregate), Concat(Concat), - Join(Join), + HashJoin(HashJoin), + BroadcastJoin(BroadcastJoin), TabularWriteParquet(TabularWriteParquet), TabularWriteJson(TabularWriteJson), TabularWriteCsv(TabularWriteCsv), @@ -135,7 +136,7 @@ impl PhysicalPlan { None, ) .into(), - Self::Join(Join { + Self::HashJoin(HashJoin { left, right, left_on, @@ -159,11 +160,76 @@ impl PhysicalPlan { .into(), } } + Self::BroadcastJoin(BroadcastJoin { + receiver: right, .. + }) => right.partition_spec(), Self::TabularWriteParquet(TabularWriteParquet { input, .. 
}) => input.partition_spec(), Self::TabularWriteCsv(TabularWriteCsv { input, .. }) => input.partition_spec(), Self::TabularWriteJson(TabularWriteJson { input, .. }) => input.partition_spec(), } } + + pub fn approximate_size_bytes(&self) -> Option { + match self { + #[cfg(feature = "python")] + Self::InMemoryScan(InMemoryScan { in_memory_info, .. }) => { + Some(in_memory_info.size_bytes) + } + Self::TabularScan(TabularScan { scan_tasks, .. }) => scan_tasks + .iter() + .map(|scan_task| scan_task.size_bytes()) + .sum::>(), + // Assume no row/column pruning in cardinality-affecting operations. + // TODO(Clark): Estimate row/column pruning to get a better size approximation. + Self::Filter(Filter { input, .. }) + | Self::Limit(Limit { input, .. }) + | Self::Project(Project { input, .. }) => input.approximate_size_bytes(), + // Assume ~the same size in bytes for explodes. + // TODO(Clark): Improve this estimate. + Self::Explode(Explode { input, .. }) => input.approximate_size_bytes(), + // Propagate child approximation for operations that don't affect cardinality. + Self::Coalesce(Coalesce { input, .. }) + | Self::FanoutByHash(FanoutByHash { input, .. }) + | Self::FanoutByRange(FanoutByRange { input, .. }) + | Self::FanoutRandom(FanoutRandom { input, .. }) + | Self::Flatten(Flatten { input, .. }) + | Self::ReduceMerge(ReduceMerge { input, .. }) + | Self::Sort(Sort { input, .. }) + | Self::Split(Split { input, .. }) => input.approximate_size_bytes(), + Self::Concat(Concat { input, other }) => { + input.approximate_size_bytes().and_then(|input_size| { + other + .approximate_size_bytes() + .map(|other_size| input_size + other_size) + }) + } + // Assume a simple sum of the sizes of both sides of the join for the post-join size. + // TODO(Clark): This will double-count join key columns, we should ensure that these are only counted once. + Self::BroadcastJoin(BroadcastJoin { + broadcaster: left, + receiver: right, + .. + }) + | Self::HashJoin(HashJoin { left, right, .. }) => { + left.approximate_size_bytes().and_then(|left_size| { + right + .approximate_size_bytes() + .map(|right_size| left_size + right_size) + }) + } + // TODO(Clark): Approximate post-aggregation sizes via grouping estimates + aggregation type. + Self::Aggregate(_) => None, + // No size approximation support for legacy I/O. + Self::TabularScanParquet(_) | Self::TabularScanCsv(_) | Self::TabularScanJson(_) => { + None + } + // Post-write DataFrame will contain paths to files that were written. + // TODO(Clark): Estimate output size via root directory and estimates for # of partitions given partitioning column. + Self::TabularWriteParquet(_) | Self::TabularWriteCsv(_) | Self::TabularWriteJson(_) => { + None + } + } + } } /// A work scheduler for physical plans. @@ -582,7 +648,7 @@ impl PhysicalPlan { .call1((upstream_input_iter, upstream_other_iter))?; Ok(py_iter.into()) } - PhysicalPlan::Join(Join { + PhysicalPlan::HashJoin(HashJoin { left, right, left_on, @@ -602,13 +668,44 @@ impl PhysicalPlan { .collect(); let py_iter = py .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? - .getattr(pyo3::intern!(py, "join"))? + .getattr(pyo3::intern!(py, "hash_join"))? 
+ .call1(( + upstream_left_iter, + upstream_right_iter, + left_on_pyexprs, + right_on_pyexprs, + *join_type, + ))?; + Ok(py_iter.into()) + } + PhysicalPlan::BroadcastJoin(BroadcastJoin { + broadcaster: left, + receiver: right, + left_on, + right_on, + join_type, + is_swapped, + }) => { + let upstream_left_iter = left.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_right_iter = right.to_partition_tasks(py, psets, is_ray_runner)?; + let left_on_pyexprs: Vec = left_on + .iter() + .map(|expr| PyExpr::from(expr.clone())) + .collect(); + let right_on_pyexprs: Vec = right_on + .iter() + .map(|expr| PyExpr::from(expr.clone())) + .collect(); + let py_iter = py + .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? + .getattr(pyo3::intern!(py, "broadcast_join"))? .call1(( upstream_left_iter, upstream_right_iter, left_on_pyexprs, right_on_pyexprs, *join_type, + *is_swapped, ))?; Ok(py_iter.into()) } diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 9e23a1c627..b8d1018b9e 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -487,6 +487,7 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftResult { let mut left_physical = plan(left, cfg.clone())?; let mut right_physical = plan(right, cfg.clone())?; + let left_pspec = left_physical.partition_spec(); let right_pspec = right_physical.partition_spec(); let num_partitions = max(left_pspec.num_partitions, right_pspec.num_partitions); @@ -500,8 +501,44 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftResult { + if right_size_bytes < left_size_bytes { + (Some(right_size_bytes), true) + } else { + (Some(left_size_bytes), false) + } + } + (Some(left_size_bytes), None) => (Some(left_size_bytes), false), + (None, Some(right_size_bytes)) => (Some(right_size_bytes), true), + (None, None) => (None, false), + }; + let is_larger_partitioned = if do_swap { + is_right_partitioned + } else { + is_left_partitioned + }; + // If larger table is not already partitioned on the join key AND the smaller table is under broadcast size threshold, use broadcast join. + if !is_larger_partitioned && let Some(smaller_size_bytes) = smaller_size_bytes && smaller_size_bytes <= cfg.broadcast_join_size_bytes_threshold { + if do_swap { + // These will get swapped back when doing the actual local joins. 
+ (left_physical, right_physical) = (right_physical, left_physical); + } + return Ok(PhysicalPlan::BroadcastJoin(BroadcastJoin::new(left_physical.into(), right_physical.into(), left_on.clone(), right_on.clone(), *join_type, do_swap))); + } if (num_partitions > 1 || left_pspec.num_partitions != num_partitions) - && left_pspec != new_left_pspec + && !is_left_partitioned { let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( left_physical.into(), @@ -511,7 +548,7 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftResult 1 || right_pspec.num_partitions != num_partitions) - && right_pspec != new_right_pspec + && !is_right_partitioned { let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( right_physical.into(), @@ -520,7 +557,7 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftResult Self { Self { source_schema, cache_key, cache_entry, num_partitions, + size_bytes, } } } diff --git a/src/daft-table/src/ops/joins/hash_join.rs b/src/daft-table/src/ops/joins/hash_join.rs index 68c3b7dfaf..56b80aedb9 100644 --- a/src/daft-table/src/ops/joins/hash_join.rs +++ b/src/daft-table/src/ops/joins/hash_join.rs @@ -43,6 +43,7 @@ pub(super) fn hash_inner_join(left: &Table, right: &Table) -> DaftResult<(Series )); } + // TODO(Clark): Build the probe table on the smaller table, rather than always building it on the left table. let probe_table = left.to_probe_hash_table()?; let r_hashes = right.hash_rows()?; diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index 30bbd16a45..a538bb5c53 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -3,13 +3,25 @@ import pyarrow as pa import pytest +import daft from daft.datatype import DataType from daft.errors import ExpressionTypeError from tests.utils import sort_arrow_table +@pytest.fixture(params=[False, True]) +def broadcast_join_enabled(request): + # Toggles between default broadcast join threshold (10 MiB), and a threshold of 0, which disables broadcast joins. 
+ broadcast_threshold = 10 * 1024 * 1024 if request.param else 0 + old_context = daft.context.pop_context() + try: + yield daft.context.set_config(broadcast_join_size_bytes_threshold=broadcast_threshold) + finally: + daft.context.set_context(old_context) + + @pytest.mark.parametrize("n_partitions", [1, 2, 4]) -def test_multicol_joins(make_df, n_partitions: int): +def test_multicol_joins(broadcast_join_enabled, make_df, n_partitions: int): df = make_df( { "A": [1, 2, 3], @@ -32,7 +44,7 @@ def test_multicol_joins(make_df, n_partitions: int): @pytest.mark.parametrize("n_partitions", [1, 2, 4]) -def test_limit_after_join(make_df, n_partitions: int): +def test_limit_after_join(broadcast_join_enabled, make_df, n_partitions: int): data = { "A": [1, 2, 3], } @@ -59,7 +71,7 @@ def test_limit_after_join(make_df, n_partitions: int): @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) -def test_inner_join(make_df, repartition_nparts): +def test_inner_join(broadcast_join_enabled, make_df, repartition_nparts): daft_df = make_df( { "id": [1, None, 3], @@ -87,7 +99,7 @@ def test_inner_join(make_df, repartition_nparts): @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) -def test_inner_join_multikey(make_df, repartition_nparts): +def test_inner_join_multikey(broadcast_join_enabled, make_df, repartition_nparts): daft_df = make_df( { "id": [1, None, None], @@ -118,7 +130,7 @@ def test_inner_join_multikey(make_df, repartition_nparts): @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) -def test_inner_join_all_null(make_df, repartition_nparts): +def test_inner_join_all_null(broadcast_join_enabled, make_df, repartition_nparts): daft_df = make_df( { "id": [None, None, None], @@ -145,7 +157,7 @@ def test_inner_join_all_null(make_df, repartition_nparts): ) -def test_inner_join_null_type_column(make_df): +def test_inner_join_null_type_column(broadcast_join_enabled, make_df): daft_df = make_df( { "id": [None, None, None],
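Finally, a self-contained sketch of the memory-request heuristic that `_create_join_step` applies when sizing broadcast-join tasks: unknown broadcaster partition sizes are imputed with the mean of the known ones, and a missing side falls back to the ~4x receiver-vs-broadcaster assumption noted in the comments above. `estimate_join_memory_bytes` is an illustrative helper, not part of the codebase:

```python
import math
from typing import Optional, Sequence


def estimate_join_memory_bytes(
    broadcaster_sizes: Sequence[Optional[int]], receiver_size: Optional[int]
) -> Optional[int]:
    known = [size for size in broadcaster_sizes if size is not None]
    num_unknown = len(broadcaster_sizes) - len(known)
    if not known:
        broadcaster_total = None
    else:
        broadcaster_total = sum(known)
        if num_unknown:
            # Impute unknown partition sizes with the mean of the known estimates.
            broadcaster_total += math.ceil(num_unknown * broadcaster_total / len(known))
    if broadcaster_total is None and receiver_size is None:
        return None
    if broadcaster_total is None:
        return int(1.25 * receiver_size)  # assume the receiver side is ~4x the broadcaster side
    if receiver_size is None:
        return 4 * broadcaster_total  # same ~4x assumption, taken from the broadcaster side
    return broadcaster_total + receiver_size


# Three broadcaster partitions (one of unknown size) joined with a 40 MiB receiver partition:
# 8 MiB + 8 MiB + imputed 8 MiB + 40 MiB = 64 MiB.
assert estimate_join_memory_bytes([8 << 20, None, 8 << 20], 40 << 20) == 64 << 20
```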